ncnn

Форк
0
/
prelu_riscv.cpp 
617 строк · 18.8 Кб
1
// Xavier Hsinyuan is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2021 Xavier Hsinyuan <me@lstlx.com>. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "prelu_riscv.h"
16

17
#if __riscv_vector
18
#include <riscv_vector.h>
19
#endif // __riscv_vector
20

21
namespace ncnn {
22

23
PReLU_riscv::PReLU_riscv()
{
    // Packed layouts (elempack > 1) are only produced/consumed by the RVV paths.
#if __riscv_vector
    support_packing = true;
#endif

    // fp16 storage needs both the vector extension and Zfh half-precision support.
#if __riscv_vector && __riscv_zfh
    support_fp16_storage = true;
#endif
}
32

33
#if __riscv_vector
// In-place PReLU over n contiguous fp32 values with a per-element slope:
// ptr[k] = ptr[k] < 0 ? ptr[k] * slope[k] : ptr[k], for k in [0, n).
static void prelu_rvv_f32_vector_slope(float* ptr, const float* slope, int n)
{
    while (n > 0)
    {
        size_t vl = vsetvl_e32m8(n);
        vfloat32m8_t _p = vle32_v_f32m8(ptr, vl);
        vfloat32m8_t _slope = vle32_v_f32m8(slope, vl);
        vbool4_t _lower = vmflt_vf_f32m8_b4(_p, .0f, vl);

        // masked multiply: lanes not selected by _lower keep their input value
        // (the second argument is the masked-off source)
        _p = vfmul_vv_f32m8_m(_lower, _p, /*op1*/ _p, _slope, vl);
        vse32_v_f32m8(ptr, _p, vl);

        ptr += vl;
        slope += vl;
        n -= vl;
    }
}

// In-place PReLU over n contiguous fp32 values that all share one slope.
static void prelu_rvv_f32_scalar_slope(float* ptr, float slope, int n)
{
    while (n > 0)
    {
        size_t vl = vsetvl_e32m8(n);
        vfloat32m8_t _p = vle32_v_f32m8(ptr, vl);
        vbool4_t _lower = vmflt_vf_f32m8_b4(_p, .0f, vl);

        _p = vfmul_vf_f32m8_m(_lower, _p, /*op1*/ _p, slope, vl);
        vse32_v_f32m8(ptr, _p, vl);

        ptr += vl;
        n -= vl;
    }
}

// In-place PReLU over `count` consecutive packed pixels of `elempack` lanes each;
// lane k of every pixel is scaled by slope[k].
static void prelu_rvv_f32_packed(float* ptr, const float* slope, int elempack, int count)
{
    for (int i = 0; i < count; i++)
    {
        prelu_rvv_f32_vector_slope(ptr, slope, elempack);
        ptr += elempack;
    }
}
#endif // __riscv_vector

// PReLU applied in place: x = x < 0 ? x * slope : x.
// slope_data holds fp32 slopes: either a single shared slope (num_slope == 1)
// or one slope per element (dims == 1), per row (dims == 2) or per channel (dims == 3).
// fp16 blobs are forwarded to the half-precision kernels when those are compiled in.
// Always returns 0.
int PReLU_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
#if __riscv_vector && __riscv_zfh
    if (opt.use_fp16_storage && bottom_top_blob.elembits() == 16)
    {
        if (opt.use_fp16_arithmetic)
            return forward_inplace_fp16sa(bottom_top_blob, opt);
        else
            return forward_inplace_fp16s(bottom_top_blob, opt);
    }
#endif

    const int dims = bottom_top_blob.dims;
    const int w = bottom_top_blob.w;
    const int h = bottom_top_blob.h;
    const int channels = bottom_top_blob.c;
    const int size = w * h;

#if __riscv_vector
    const int elempack = bottom_top_blob.elempack;

    if (dims == 1)
    {
        // 1-D blobs are processed on the calling thread.
        float* ptr = bottom_top_blob;
        if (num_slope > 1)
            prelu_rvv_f32_vector_slope(ptr, (const float*)slope_data, w * elempack);
        else
            prelu_rvv_f32_scalar_slope(ptr, slope_data[0], w * elempack);
    }

    if (dims == 2)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = 0; i < h; i++)
        {
            float* ptr = bottom_top_blob.row(i);
            if (num_slope > 1 && elempack != 1)
            {
                // a packed row interleaves elempack logical rows, each with its own slope
                prelu_rvv_f32_packed(ptr, (const float*)slope_data + i * elempack, elempack, w);
            }
            else
            {
                // one slope covers the whole row: strip-mine across it in full-width vectors
                // instead of issuing one single-lane vector op per pixel
                const float slope = num_slope > 1 ? slope_data[i] : slope_data[0];
                prelu_rvv_f32_scalar_slope(ptr, slope, w * elempack);
            }
        }
    }

    if (dims == 3)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            float* ptr = bottom_top_blob.channel(q);
            if (num_slope > 1 && elempack != 1)
            {
                prelu_rvv_f32_packed(ptr, (const float*)slope_data + q * elempack, elempack, size);
            }
            else
            {
                // num_slope == 1, or elempack == 1 with one slope per channel
                const float slope = num_slope > 1 ? slope_data[q] : slope_data[0];
                prelu_rvv_f32_scalar_slope(ptr, slope, size * elempack);
            }
        }
    }
#else
    if (dims == 1)
    {
        float* ptr = bottom_top_blob;

        if (num_slope > 1)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < w; i++)
            {
                if (ptr[i] < 0)
                    ptr[i] *= slope_data[i];
            }
        }
        else
        {
            const float slope = slope_data[0];

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < w; i++)
            {
                if (ptr[i] < 0)
                    ptr[i] *= slope;
            }
        }
    }

    if (dims == 2)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = 0; i < h; i++)
        {
            float* ptr = bottom_top_blob.row(i);
            const float slope = num_slope > 1 ? slope_data[i] : slope_data[0];

            for (int j = 0; j < w; j++)
            {
                if (ptr[j] < 0)
                    ptr[j] *= slope;
            }
        }
    }

    if (dims == 3)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            float* ptr = bottom_top_blob.channel(q);
            const float slope = num_slope > 1 ? slope_data[q] : slope_data[0];

            for (int i = 0; i < size; i++)
            {
                if (ptr[i] < 0)
                    ptr[i] *= slope;
            }
        }
    }
#endif

    return 0;
}
281

282
#if __riscv_vector && __riscv_zfh
283
// fp16 storage kernels: fp16s (fp16 storage, fp32 arithmetic) and fp16sa (fp16 storage and arithmetic)
284
// Note: slope_data is always stored as fp32, even when the blob itself is fp16.
285

286
int PReLU_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
287
{
288
    int w = bottom_top_blob.w;
289
    int h = bottom_top_blob.h;
290
    int size = w * h;
291
    int elempack = bottom_top_blob.elempack;
292
    int dims = bottom_top_blob.dims;
293

294
    if (dims == 1)
295
    {
296
        int w = bottom_top_blob.w;
297
        __fp16* ptr = bottom_top_blob;
298
        const float* ptr_slope = slope_data;
299
        if (num_slope > 1)
300
        {
301
            int n = w * elempack;
302

303
            // #pragma omp parallel for num_threads(opt.num_threads)
304
            while (n > 0)
305
            {
306
                size_t vl = vsetvl_e16m4(n);
307

308
                vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr, vl), vl);
309
                vfloat32m8_t _slope = vle32_v_f32m8(ptr_slope, vl);
310
                vbool4_t _lower = vmflt_vf_f32m8_b4(_p, .0f, vl);
311
                _p = vfmul_vv_f32m8_m(_lower, _p, /*op1*/ _p, _slope, vl);
312

313
                vse16_v_f16m4(ptr, vfncvt_f_f_w_f16m4(_p, vl), vl);
314
                ptr += vl;
315
                ptr_slope += vl;
316
                n -= vl;
317
            }
318
        }
319
        else
320
        {
321
            float slope = slope_data[0];
322

323
            int n = w * elempack;
324
            // #pragma omp parallel for num_threads(opt.num_threads)
325
            while (n > 0)
326
            {
327
                size_t vl = vsetvl_e16m4(n);
328
                vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr, vl), vl);
329
                vbool4_t _lower = vmflt_vf_f32m8_b4(_p, .0f, vl);
330

331
                _p = vfmul_vf_f32m8_m(_lower, _p, /*op1*/ _p, slope, vl);
332
                vse16_v_f16m4(ptr, vfncvt_f_f_w_f16m4(_p, vl), vl);
333

334
                ptr += vl;
335
                n -= vl;
336
            }
337
        }
338
    }
339

340
    if (dims == 2)
341
    {
342
        int w = bottom_top_blob.w;
343
        int h = bottom_top_blob.h;
344

345
        #pragma omp parallel for num_threads(opt.num_threads)
346
        for (int i = 0; i < h; i++)
347
        {
348
            __fp16* ptr = bottom_top_blob.row<__fp16>(i);
349
            if (num_slope > 1)
350
            {
351
                for (int j = 0; j < w; j++)
352
                {
353
                    const float* ptr_slope = (const float*)slope_data + i * elempack;
354
                    int n = elempack;
355

356
                    while (n > 0)
357
                    {
358
                        size_t vl = vsetvl_e16m4(n);
359
                        vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr, vl), vl);
360
                        vfloat32m8_t _slope = vle32_v_f32m8(ptr_slope, vl);
361

362
                        vbool4_t _lower = vmflt_vf_f32m8_b4(_p, .0f, vl);
363
                        _p = vfmul_vv_f32m8_m(_lower, _p, /*op1*/ _p, _slope, vl);
364
                        vse16_v_f16m4(ptr, vfncvt_f_f_w_f16m4(_p, vl), vl);
365

366
                        ptr += vl;
367
                        ptr_slope += vl;
368
                        n -= vl;
369
                    }
370
                }
371
            }
372
            else
373
            {
374
                float slope = slope_data[0];
375
                int n = w * elempack;
376
                while (n > 0)
377
                {
378
                    size_t vl = vsetvl_e16m4(n);
379
                    vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr, vl), vl);
380
                    vbool4_t _lower = vmflt_vf_f32m8_b4(_p, .0f, vl);
381

382
                    _p = vfmul_vf_f32m8_m(_lower, _p, /*op1*/ _p, slope, vl);
383
                    vse16_v_f16m4(ptr, vfncvt_f_f_w_f16m4(_p, vl), vl);
384

385
                    ptr += vl;
386
                    n -= vl;
387
                }
388
            }
389
        }
390
    }
391

392
    if (dims == 3)
393
    {
394
        int w = bottom_top_blob.w;
395
        int h = bottom_top_blob.h;
396
        int channels = bottom_top_blob.c;
397
        int size = w * h;
398

399
        #pragma omp parallel for num_threads(opt.num_threads)
400
        for (int q = 0; q < channels; q++)
401
        {
402
            __fp16* ptr = bottom_top_blob.channel(q);
403
            int n = size * elempack;
404

405
            if (num_slope > 1 && elempack != 1)
406
            {
407
                while (n > 0)
408
                {
409
                    int n1 = elempack;
410
                    const float* slope_ptr = (const float*)slope_data + q * elempack;
411
                    while (n1 > 0)
412
                    {
413
                        size_t vl = vsetvl_e16m4(n1);
414
                        vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr, vl), vl);
415
                        vfloat32m8_t _slope = vle32_v_f32m8(slope_ptr, vl);
416

417
                        vbool4_t _lower = vmflt_vf_f32m8_b4(_p, .0f, vl);
418
                        _p = vfmul_vv_f32m8_m(_lower, _p, /*op1*/ _p, _slope, vl);
419
                        vse16_v_f16m4(ptr, vfncvt_f_f_w_f16m4(_p, vl), vl);
420

421
                        ptr += vl;
422
                        slope_ptr += vl;
423
                        n1 -= vl;
424
                    }
425
                    n -= elempack;
426
                }
427
            }
428
            else
429
            {
430
                // num_slope == 1 or elempack ==1
431
                float slope = num_slope > 1 ? slope_data[q] : slope_data[0];
432
                while (n > 0)
433
                {
434
                    size_t vl = vsetvl_e16m4(n);
435
                    vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr, vl), vl);
436

437
                    vbool4_t _lower = vmflt_vf_f32m8_b4(_p, .0f, vl);
438
                    _p = vfmul_vf_f32m8_m(_lower, _p, /*op1*/ _p, slope, vl);
439
                    vse16_v_f16m4(ptr, vfncvt_f_f_w_f16m4(_p, vl), vl);
440

441
                    ptr += vl;
442
                    n -= vl;
443
                }
444
            }
445
        }
446
    }
447

448
    return 0;
449
}
450

451
int PReLU_riscv::forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const
452
{
453
    int w = bottom_top_blob.w;
454
    int h = bottom_top_blob.h;
455
    int size = w * h;
456
    int elempack = bottom_top_blob.elempack;
457
    int dims = bottom_top_blob.dims;
458

459
    if (dims == 1)
460
    {
461
        int w = bottom_top_blob.w;
462
        __fp16* ptr = bottom_top_blob;
463
        const float* ptr_slope = slope_data;
464
        if (num_slope > 1)
465
        {
466
            int n = w * elempack;
467

468
            // #pragma omp parallel for num_threads(opt.num_threads)
469
            while (n > 0)
470
            {
471
                size_t vl = vsetvl_e16m4(n);
472
                vfloat16m4_t _p = vle16_v_f16m4(ptr, vl);
473
                vfloat16m4_t _slope = vfncvt_f_f_w_f16m4(vle32_v_f32m8(ptr_slope, vl), vl);
474
                vbool4_t _lower = vmflt_vf_f16m4_b4(_p, .0f, vl);
475

476
                _p = vfmul_vv_f16m4_m(_lower, _p, /*op1*/ _p, _slope, vl);
477
                vse16_v_f16m4(ptr, _p, vl);
478

479
                ptr += vl;
480
                ptr_slope += vl;
481
                n -= vl;
482
            }
483
        }
484
        else
485
        {
486
            __fp16 slope = slope_data[0];
487

488
            int n = w * elempack;
489
            // #pragma omp parallel for num_threads(opt.num_threads)
490
            while (n > 0)
491
            {
492
                size_t vl = vsetvl_e16m8(n);
493
                vfloat16m8_t _p = vle16_v_f16m8(ptr, vl);
494
                vbool2_t _lower = vmflt_vf_f16m8_b2(_p, .0f, vl);
495

496
                _p = vfmul_vf_f16m8_m(_lower, _p, /*op1*/ _p, slope, vl);
497
                vse16_v_f16m8(ptr, _p, vl);
498

499
                ptr += vl;
500
                n -= vl;
501
            }
502
        }
503
    }
504

505
    if (dims == 2)
506
    {
507
        int w = bottom_top_blob.w;
508
        int h = bottom_top_blob.h;
509

510
        #pragma omp parallel for num_threads(opt.num_threads)
511
        for (int i = 0; i < h; i++)
512
        {
513
            __fp16* ptr = bottom_top_blob.row<__fp16>(i);
514
            if (num_slope > 1)
515
            {
516
                for (int j = 0; j < w; j++)
517
                {
518
                    const float* ptr_slope = (const float*)slope_data + i * elempack;
519
                    int n = elempack;
520

521
                    while (n > 0)
522
                    {
523
                        size_t vl = vsetvl_e16m4(n);
524
                        vfloat16m4_t _p = vle16_v_f16m4(ptr, vl);
525
                        vfloat16m4_t _slope = vfncvt_f_f_w_f16m4(vle32_v_f32m8(ptr_slope, vl), vl);
526

527
                        vbool4_t _lower = vmflt_vf_f16m4_b4(_p, .0f, vl);
528
                        _p = vfmul_vv_f16m4_m(_lower, _p, /*op1*/ _p, _slope, vl);
529
                        vse16_v_f16m4(ptr, _p, vl);
530

531
                        ptr += vl;
532
                        ptr_slope += vl;
533
                        n -= vl;
534
                    }
535
                }
536
            }
537
            else
538
            {
539
                __fp16 slope = slope_data[0];
540
                int n = w * elempack;
541
                while (n > 0)
542
                {
543
                    size_t vl = vsetvl_e16m8(n);
544
                    vfloat16m8_t _p = vle16_v_f16m8(ptr, vl);
545
                    vbool2_t _lower = vmflt_vf_f16m8_b2(_p, .0f, vl);
546

547
                    _p = vfmul_vf_f16m8_m(_lower, _p, /*op1*/ _p, slope, vl);
548
                    vse16_v_f16m8(ptr, _p, vl);
549

550
                    ptr += vl;
551
                    n -= vl;
552
                }
553
            }
554
        }
555
    }
556

557
    if (dims == 3)
558
    {
559
        int w = bottom_top_blob.w;
560
        int h = bottom_top_blob.h;
561
        int channels = bottom_top_blob.c;
562
        int size = w * h;
563

564
        #pragma omp parallel for num_threads(opt.num_threads)
565
        for (int q = 0; q < channels; q++)
566
        {
567
            __fp16* ptr = bottom_top_blob.channel(q);
568
            int n = size * elempack;
569

570
            if (num_slope > 1 && elempack != 1)
571
            {
572
                while (n > 0)
573
                {
574
                    int n1 = elempack;
575
                    const float* slope_ptr = (const float*)slope_data + q * elempack;
576
                    while (n1 > 0)
577
                    {
578
                        size_t vl = vsetvl_e16m4(n1);
579
                        vfloat16m4_t _p = vle16_v_f16m4(ptr, vl);
580
                        vfloat16m4_t _slope = vfncvt_f_f_w_f16m4(vle32_v_f32m8(slope_ptr, vl), vl);
581

582
                        vbool4_t _lower = vmflt_vf_f16m4_b4(_p, .0f, vl);
583
                        _p = vfmul_vv_f16m4_m(_lower, _p, /*op1*/ _p, _slope, vl);
584
                        vse16_v_f16m4(ptr, _p, vl);
585

586
                        ptr += vl;
587
                        slope_ptr += vl;
588
                        n1 -= vl;
589
                    }
590
                    n -= elempack;
591
                }
592
            }
593
            else
594
            {
595
                // num_slope == 1 or elempack ==1
596
                float slope = num_slope > 1 ? slope_data[q] : slope_data[0];
597
                while (n > 0)
598
                {
599
                    size_t vl = vsetvl_e16m8(n);
600
                    vfloat16m8_t _p = vle16_v_f16m8(ptr, vl);
601

602
                    vbool2_t _lower = vmflt_vf_f16m8_b2(_p, .0f, vl);
603
                    _p = vfmul_vf_f16m8_m(_lower, _p, /*op1*/ _p, (__fp16)slope, vl);
604
                    vse16_v_f16m8(ptr, _p, vl);
605

606
                    ptr += vl;
607
                    n -= vl;
608
                }
609
            }
610
        }
611
    }
612

613
    return 0;
614
}
615

616
#endif
617
} // namespace ncnn
618

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.