// Xavier Hsinyuan is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 Xavier Hsinyuan <me@lstlx.com>. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "instancenorm_riscv.h"

#if __riscv_vector
#include <riscv_vector.h>
#endif // __riscv_vector

#include "riscv_usability.h"

namespace ncnn {
InstanceNorm_riscv::InstanceNorm_riscv()
{
#if __riscv_vector
    support_packing = true;
#if __riscv_zfh
    support_fp16_storage = true;
#endif
#endif // __riscv_vector
}

int InstanceNorm_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
// x = (x - mean) / (sqrt(var + eps)) * gamma + beta
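// fp16 storage is dispatched to forward_inplace_fp16s / forward_inplace_fp16sa below;
// the fp32 path handles elempack == 1 here and elempack == packn further down.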
#if __riscv_vector
    int elembits = bottom_top_blob.elembits();
    if (opt.use_fp16_storage && elembits == 16)
    {
        if (opt.use_fp16_arithmetic)
            return forward_inplace_fp16sa(bottom_top_blob, opt);
        else
            return forward_inplace_fp16s(bottom_top_blob, opt);
    }
    int elempack = bottom_top_blob.elempack;
#endif // __riscv_vector
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int c = bottom_top_blob.c;
    int size = w * h;

    int dims = bottom_top_blob.dims;
#if __riscv_vector
    if (elempack == 1)
#endif // __riscv_vector
    {
#if __riscv_vector
        size = elempack * size;
#endif
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < c; q++)
        {
            float* ptr = bottom_top_blob.channel(q);

            // mean and var
            float sum = 0.f;
            float sqsum = 0.f;
#if __riscv_vector
            vfloat32m1_t _sum = vfmv_s_f_f32m1(vundefined_f32m1(), 0.f, vsetvlmax_e32m1());
            vfloat32m1_t _sqsum = vfmv_s_f_f32m1(vundefined_f32m1(), 0.f, vsetvlmax_e32m1());
            {
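                // Strip-mined reduction: vsetvl_e32m8() clamps vl to what an m8
                // register group holds for the remaining n elements; the pointer
                // advances by vl until the whole channel has been accumulated.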
                int n = size;
                float* ptr_sum = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e32m8(n);
                    vfloat32m8_t _p = vle32_v_f32m8(ptr_sum, vl);
                    _sum = vfredusum_vs_f32m8_f32m1(_sum, _p, /* scalar */ _sum, vl);
                    // _sqsum = vfredosum_vs_f32m8_f32m1(_sqsum, vfmul_vv_f32m8(_p, _p, vl), /* scalar */ _sqsum, vl);
                    ptr_sum += vl;
                    n -= vl;
                }
            }
            sum = vfmv_f_s_f32m1_f32(_sum);
#else
            for (int i = 0; i < size; i++)
            {
                sum += ptr[i];
                //sqsum += ptr[i] * ptr[i];
            }
#endif // __riscv_vector
            float mean = sum / size;
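            // Second pass: accumulate (x - mean)^2 instead of using E[x^2] - mean^2,
            // which can go negative from rounding (see the commented-out variant below).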
#if __riscv_vector
            {
                int n = size;
                float* ptr_sqsum = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e32m8(n);
                    vfloat32m8_t _p = vle32_v_f32m8(ptr_sqsum, vl);
                    _p = vfsub_vf_f32m8(_p, mean, vl);
                    _sqsum = vfredosum_vs_f32m8_f32m1(_sqsum, vfmul_vv_f32m8(_p, _p, vl), /* scalar */ _sqsum, vl);
                    n -= vl;
                    ptr_sqsum += vl;
                }
            }
            sqsum = vfmv_f_s_f32m1_f32(_sqsum);
#else
            float tmp = 0.f;
            for (int i = 0; i < size; i++)
            {
                tmp = ptr[i] - mean;
                sqsum += tmp * tmp;
            }
#endif // __riscv_vector
            float var = sqsum / size;
            // the variance may go negative due to rounding error
            //float var = sqsum / size - mean * mean;
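            // Fold the normalization into one multiply-add per element:
            // a = gamma / sqrt(var + eps), b = beta - mean * a, so x' = x * a + b.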
            float a;
            float b;
            if (affine)
            {
                float gamma = gamma_data[q];
                float beta = beta_data[q];

                a = gamma / (sqrtf(var + eps));
                b = -mean * a + beta;
            }
            else
            {
                a = 1.f / (sqrtf(var + eps));
                b = -mean * a;
            }
#if __riscv_vector
            {
                int n = size;
                float* ptr_store = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e32m8(n);
                    vfloat32m8_t _p = vle32_v_f32m8(ptr_store, vl);
                    _p = vfmul_vf_f32m8(_p, a, vl);
                    _p = vfadd_vf_f32m8(_p, b, vl);
                    vse32_v_f32m8(ptr_store, _p, vl);
                    n -= vl;
                    ptr_store += vl;
                }
            }
#else
            for (int i = 0; i < size; i++)
            {
                ptr[i] = ptr[i] * a + b;
            }
#endif // __riscv_vector
        }
        return 0;
    }

#if __riscv_vector
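    // Packed layout: csrr_vlenb() returns VLEN in bytes, so packn = VLEN / 32 fp32
    // lanes. With elempack == packn each vector lane holds one of the packn
    // interleaved channels, and the mean/var statistics are kept per lane.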
    const int packn = csrr_vlenb() / 4;
    if (elempack == packn)
    {
        const size_t vl = vsetvl_e32m1(packn);
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < c; q++)
        {
            float* ptr = bottom_top_blob.channel(q);
            vfloat32m1_t _sum = vfmv_v_f_f32m1(0.f, vl);
            vfloat32m1_t _sqsum = vfmv_v_f_f32m1(0.f, vl);

            for (int i = 0; i < size; i++)
            {
                vfloat32m1_t _p = vle32_v_f32m1(ptr + vl * i, vl);
                _sum = vfadd_vv_f32m1(_p, _sum, vl);
                // _sqsum = vfmadd_vv_f32m1(_p,_p,_sqsum,vl);
            }
            vfloat32m1_t _mean = vfdiv_vf_f32m1(_sum, size, vl);
            for (int i = 0; i < size; i++)
            {
                vfloat32m1_t _p = vle32_v_f32m1(ptr + vl * i, vl);
                _p = vfsub_vv_f32m1(_p, _mean, vl);
                _sqsum = vfmadd_vv_f32m1(_p, _p, _sqsum, vl);
            }
            vfloat32m1_t _var = vfdiv_vf_f32m1(_sqsum, size, vl);
            // the variance may go negative due to rounding error
            //float var = sqsum / size - mean * mean;

            vfloat32m1_t _a;
            vfloat32m1_t _b;
            if (affine)
            {
                vfloat32m1_t _gamma = vle32_v_f32m1((const float*)gamma_data + q * vl, vl);
                vfloat32m1_t _beta = vle32_v_f32m1((const float*)beta_data + q * vl, vl);
                _a = vfdiv_vv_f32m1(_gamma, vfsqrt_v_f32m1(vfadd_vf_f32m1(_var, eps, vl), vl), vl);
                _b = vfnmsub_vv_f32m1(_a, _mean, _beta, vl);
            }
            else
            {
                _a = vfrdiv_vf_f32m1(vfsqrt_v_f32m1(vfadd_vf_f32m1(_var, eps, vl), vl), 1.f, vl);
                _b = vfmul_vv_f32m1(_a, _mean, vl);
                _b = vfsgnjn_vv_f32m1(_b, _b, vl);
            }
            for (int i = 0; i < size; i++)
            {
                vfloat32m1_t _p = vle32_v_f32m1(ptr + i * vl, vl);
                _p = vfmadd_vv_f32m1(_p, _a, _b, vl);
                vse32_v_f32m1(ptr + i * vl, _p, vl);
            }
        }
        return 0;
    }
#endif // __riscv_vector
    return 0;
}

#if __riscv_vector && __riscv_zfh
int InstanceNorm_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
{
    // x = (x - mean) / (sqrt(var + eps)) * gamma + beta
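    // fp16 storage, fp32 arithmetic: elements are widened with vfwcvt on load,
    // accumulated and normalized in fp32, then narrowed back with vfncvt on store.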
    int elempack = bottom_top_blob.elempack;

    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int c = bottom_top_blob.c;
    int size = w * h;

    int dims = bottom_top_blob.dims;
    if (elempack == 1)
    {
        size = elempack * size;
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < c; q++)
        {
            __fp16* ptr = bottom_top_blob.channel(q);

            // mean and var
            float sum = 0.f;
            float sqsum = 0.f;
            vfloat32m1_t _sum = vfmv_s_f_f32m1(vundefined_f32m1(), 0.f, vsetvlmax_e32m1());
            vfloat32m1_t _sqsum = vfmv_s_f_f32m1(vundefined_f32m1(), 0.f, vsetvlmax_e32m1());
            {
                int n = size;
                __fp16* ptr_sum = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e32m8(n);
                    vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr_sum, vl), vl);
                    _sum = vfredusum_vs_f32m8_f32m1(_sum, _p, /* scalar */ _sum, vl);
                    // _sqsum = vfredosum_vs_f32m8_f32m1(_sqsum, vfmul_vv_f32m8(_p, _p, vl), /* scalar */ _sqsum, vl);
                    ptr_sum += vl;
                    n -= vl;
                }
            }
            sum = vfmv_f_s_f32m1_f32(_sum);
            float mean = sum / size;
            {
                int n = size;
                __fp16* ptr_sqsum = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e32m8(n);
                    vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr_sqsum, vl), vl);
                    _p = vfsub_vf_f32m8(_p, mean, vl);
                    _sqsum = vfredosum_vs_f32m8_f32m1(_sqsum, vfmul_vv_f32m8(_p, _p, vl), /* scalar */ _sqsum, vl);
                    n -= vl;
                    ptr_sqsum += vl;
                }
            }
            sqsum = vfmv_f_s_f32m1_f32(_sqsum);
            float var = sqsum / size;
            // the variance may go negative due to rounding error
            //float var = sqsum / size - mean * mean;

            float a;
            float b;
            if (affine)
            {
                float gamma = gamma_data[q];
                float beta = beta_data[q];

                a = gamma / (sqrtf(var + eps));
                b = -mean * a + beta;
            }
            else
            {
                a = 1.f / (sqrtf(var + eps));
                b = -mean * a;
            }
            {
                int n = size;
                __fp16* ptr_store = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e32m8(n);
                    vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr_store, vl), vl);
                    _p = vfmul_vf_f32m8(_p, a, vl);
                    _p = vfadd_vf_f32m8(_p, b, vl);
                    vse16_v_f16m4(ptr_store, vfncvt_f_f_w_f16m4(_p, vl), vl);
                    n -= vl;
                    ptr_store += vl;
                }
            }
        }
        return 0;
    }

    const int packn = csrr_vlenb() / 2;
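    // Packed fp16 path: packn = VLEN / 16 __fp16 lanes; an f32m2 register group
    // holds the same packn values widened to fp32 for the accumulation.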
    if (elempack == packn)
    {
        const size_t vl = vsetvl_e16m1(packn);
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < c; q++)
        {
            __fp16* ptr = bottom_top_blob.channel(q);
            vfloat32m2_t _sum = vfmv_v_f_f32m2(0.f, vl);
            vfloat32m2_t _sqsum = vfmv_v_f_f32m2(0.f, vl);

            for (int i = 0; i < size; i++)
            {
                vfloat32m2_t _p = vfwcvt_f_f_v_f32m2(vle16_v_f16m1(ptr + vl * i, vl), vl);
                _sum = vfadd_vv_f32m2(_p, _sum, vl);
                // _sqsum = vfmadd_vv_f32m2(_p,_p,_sqsum,vl);
            }
            vfloat32m2_t _mean = vfdiv_vf_f32m2(_sum, size, vl);
            for (int i = 0; i < size; i++)
            {
                vfloat32m2_t _p = vfwcvt_f_f_v_f32m2(vle16_v_f16m1(ptr + vl * i, vl), vl);
                _p = vfsub_vv_f32m2(_p, _mean, vl);
                _sqsum = vfmadd_vv_f32m2(_p, _p, _sqsum, vl);
            }
            vfloat32m2_t _var = vfdiv_vf_f32m2(_sqsum, size, vl);
            // the variance may go negative due to rounding error
            //float var = sqsum / size - mean * mean;

            vfloat32m2_t _a;
            vfloat32m2_t _b;
            if (affine)
            {
                vfloat32m2_t _gamma = vle32_v_f32m2((const float*)gamma_data + q * vl, vl);
                vfloat32m2_t _beta = vle32_v_f32m2((const float*)beta_data + q * vl, vl);
                _a = vfdiv_vv_f32m2(_gamma, vfsqrt_v_f32m2(vfadd_vf_f32m2(_var, eps, vl), vl), vl);
                _b = vfnmsub_vv_f32m2(_a, _mean, _beta, vl);
            }
            else
            {
                _a = vfrdiv_vf_f32m2(vfsqrt_v_f32m2(vfadd_vf_f32m2(_var, eps, vl), vl), 1.f, vl);
                _b = vfmul_vv_f32m2(_a, _mean, vl);
                _b = vfsgnjn_vv_f32m2(_b, _b, vl);
            }
            for (int i = 0; i < size; i++)
            {
                vfloat32m2_t _p = vfwcvt_f_f_v_f32m2(vle16_v_f16m1(ptr + i * vl, vl), vl);
                _p = vfmadd_vv_f32m2(_p, _a, _b, vl);
                vse16_v_f16m1(ptr + i * vl, vfncvt_f_f_w_f16m1(_p, vl), vl);
            }
        }
        return 0;
    }
    return 0;
}

int InstanceNorm_riscv::forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const
{
    // x = (x - mean) / (sqrt(var + eps)) * gamma + beta
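    // fp16 storage and fp16 arithmetic: sums, mean/var and the a/b coefficients
    // all stay in __fp16, trading some accuracy for speed.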
    int elempack = bottom_top_blob.elempack;

    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int c = bottom_top_blob.c;
    int size = w * h;

    int dims = bottom_top_blob.dims;
    if (elempack == 1)
    {
        size = elempack * size;
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < c; q++)
        {
            __fp16* ptr = bottom_top_blob.channel(q);

            // mean and var
            __fp16 sum = 0.f;
            __fp16 sqsum = 0.f;
            vfloat16m1_t _sum = vfmv_s_f_f16m1(vundefined_f16m1(), 0.f, vsetvlmax_e32m1());
            vfloat16m1_t _sqsum = vfmv_s_f_f16m1(vundefined_f16m1(), 0.f, vsetvlmax_e32m1());
            {
                int n = size;
                __fp16* ptr_sum = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e16m8(n);
                    vfloat16m8_t _p = vle16_v_f16m8(ptr_sum, vl);
                    _sum = vfredusum_vs_f16m8_f16m1(_sum, _p, /* scalar */ _sum, vl);
                    // _sqsum = vfredosum_vs_f16m8_f16m1(_sqsum, vfmul_vv_f16m8(_p, _p, vl), /* scalar */ _sqsum, vl);
                    ptr_sum += vl;
                    n -= vl;
                }
            }
            sum = vfmv_f_s_f16m1_f16(_sum);
            __fp16 mean = sum / size;
            {
                int n = size;
                __fp16* ptr_sqsum = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e16m8(n);
                    vfloat16m8_t _p = vle16_v_f16m8(ptr_sqsum, vl);
                    _p = vfsub_vf_f16m8(_p, mean, vl);
                    _sqsum = vfredosum_vs_f16m8_f16m1(_sqsum, vfmul_vv_f16m8(_p, _p, vl), /* scalar */ _sqsum, vl);
                    n -= vl;
                    ptr_sqsum += vl;
                }
            }
            sqsum = vfmv_f_s_f16m1_f16(_sqsum);
            __fp16 var = sqsum / size;
            // the variance may go negative due to rounding error
            //float var = sqsum / size - mean * mean;

            __fp16 a;
            __fp16 b;
            if (affine)
            {
                float gamma = gamma_data[q];
                float beta = beta_data[q];

                a = static_cast<__fp16>(gamma / (sqrt(var + eps)));
                b = static_cast<__fp16>(-mean * a + beta);
            }
            else
            {
                a = static_cast<__fp16>(1.f / (sqrt(var + eps)));
                b = static_cast<__fp16>(-mean * a);
            }
            {
                int n = size;
                __fp16* ptr_store = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e16m8(n);
                    vfloat16m8_t _p = vle16_v_f16m8(ptr_store, vl);
                    _p = vfmul_vf_f16m8(_p, a, vl);
                    _p = vfadd_vf_f16m8(_p, b, vl);
                    vse16_v_f16m8(ptr_store, _p, vl);
                    n -= vl;
                    ptr_store += vl;
                }
            }
        }
        return 0;
    }

    const int packn = csrr_vlenb() / 2;
    if (elempack == packn)
    {
        const size_t vl = vsetvl_e16m1(packn);
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < c; q++)
        {
            __fp16* ptr = bottom_top_blob.channel(q);
            vfloat16m1_t _sum = vfmv_v_f_f16m1(0.f, vl);
            vfloat16m1_t _sqsum = vfmv_v_f_f16m1(0.f, vl);

            for (int i = 0; i < size; i++)
            {
                vfloat16m1_t _p = vle16_v_f16m1(ptr + vl * i, vl);
                _sum = vfadd_vv_f16m1(_p, _sum, vl);
                // _sqsum = vfmadd_vv_f16m1(_p,_p,_sqsum,vl);
            }
            vfloat16m1_t _mean = vfdiv_vf_f16m1(_sum, size, vl);
            for (int i = 0; i < size; i++)
            {
                vfloat16m1_t _p = vle16_v_f16m1(ptr + vl * i, vl);
                _p = vfsub_vv_f16m1(_p, _mean, vl);
                _sqsum = vfmadd_vv_f16m1(_p, _p, _sqsum, vl);
            }
            vfloat16m1_t _var = vfdiv_vf_f16m1(_sqsum, size, vl);
            // the variance may go negative due to rounding error
            //float var = sqsum / size - mean * mean;

            vfloat16m1_t _a;
            vfloat16m1_t _b;
            if (affine)
            {
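                // gamma/beta are stored as fp32: load vl floats as f32m2, then
                // narrow to f16m1 to match the fp16 arithmetic below.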
                vfloat16m1_t _gamma = vfncvt_f_f_w_f16m1(vle32_v_f32m2((const float*)gamma_data + q * vl, vl), vl);
                vfloat16m1_t _beta = vfncvt_f_f_w_f16m1(vle32_v_f32m2((const float*)beta_data + q * vl, vl), vl);
                _a = vfdiv_vv_f16m1(_gamma, vfsqrt_v_f16m1(vfadd_vf_f16m1(_var, eps, vl), vl), vl);
                _b = vfnmsub_vv_f16m1(_a, _mean, _beta, vl);
            }
            else
            {
                _a = vfrdiv_vf_f16m1(vfsqrt_v_f16m1(vfadd_vf_f16m1(_var, eps, vl), vl), 1.f, vl);
                _b = vfmul_vv_f16m1(_a, _mean, vl);
                _b = vfsgnjn_vv_f16m1(_b, _b, vl);
            }
            for (int i = 0; i < size; i++)
            {
                vfloat16m1_t _p = vle16_v_f16m1(ptr + i * vl, vl);
                _p = vfmadd_vv_f16m1(_p, _a, _b, vl);
                vse16_v_f16m1(ptr + i * vl, _p, vl);
            }
        }
        return 0;
    }
    return 0;
}

#endif // __riscv_vector && __riscv_zfh

} // namespace ncnn
