// Xavier Hsinyuan is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 Xavier Hsinyuan <me@lstlx.com>. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "instancenorm_riscv.h"

#include <math.h>

#if __riscv_vector
#include <riscv_vector.h>
#endif // __riscv_vector

#include "riscv_usability.h"

namespace ncnn {
InstanceNorm_riscv::InstanceNorm_riscv()
{
#if __riscv_vector
    support_packing = true;
#if __riscv_zfh
    support_fp16_storage = true;
#endif // __riscv_zfh
#endif // __riscv_vector
}
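
// Dispatch overview: forward_inplace handles fp32 (scalar elempack == 1 and
// packed elempack == packn); when fp16 storage is enabled the work is routed
// to forward_inplace_fp16s (fp16 storage, fp32 arithmetic) or
// forward_inplace_fp16sa (fp16 storage and arithmetic).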

int InstanceNorm_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
    // x = (x - mean) / (sqrt(var + eps)) * gamma + beta

#if __riscv_vector && __riscv_zfh
    int elembits = bottom_top_blob.elembits();

    if (opt.use_fp16_storage && elembits == 16)
    {
        if (opt.use_fp16_arithmetic)
            return forward_inplace_fp16sa(bottom_top_blob, opt);
        else
            return forward_inplace_fp16s(bottom_top_blob, opt);
    }
#endif // __riscv_vector && __riscv_zfh

    int elempack = bottom_top_blob.elempack;
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int c = bottom_top_blob.c;
    int dims = bottom_top_blob.dims;
    int size = w * h;

    if (elempack == 1)
    {
        size = elempack * size;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < c; q++)
        {
            float* ptr = bottom_top_blob.channel(q);

            // mean and variance of this channel
            float sum = 0.f;
            float sqsum = 0.f;
#if __riscv_vector
            vfloat32m1_t _sum = vfmv_s_f_f32m1(vundefined_f32m1(), 0.f, vsetvlmax_e32m1());
            vfloat32m1_t _sqsum = vfmv_s_f_f32m1(vundefined_f32m1(), 0.f, vsetvlmax_e32m1());
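            // vfmv_s_f writes the scalar into element 0 and vundefined leaves the
            // remaining lanes unspecified; vfredusum/vfredosum only read element 0
            // of the scalar operand and write element 0 of the destination.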
            {
                int n = size;
                float* ptr_sum = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e32m8(n);
                    vfloat32m8_t _p = vle32_v_f32m8(ptr_sum, vl);
                    _sum = vfredusum_vs_f32m8_f32m1(_sum, _p, /* scalar */ _sum, vl);
                    // _sqsum = vfredosum_vs_f32m8_f32m1(_sqsum, vfmul_vv_f32m8(_p, _p, vl), /* scalar */ _sqsum, vl);
                    ptr_sum += vl;
                    n -= vl;
                }
            }
            sum = vfmv_f_s_f32m1_f32(_sum);
#else
            for (int i = 0; i < size; i++)
            {
                sum += ptr[i];
                //sqsum += ptr[i] * ptr[i];
            }
#endif // __riscv_vector
            float mean = sum / size;
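            // second pass: accumulate (x - mean)^2 rather than using
            // E[x^2] - mean^2, which can go negative in floating point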
#if __riscv_vector
            {
                int n = size;
                float* ptr_sqsum = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e32m8(n);
                    vfloat32m8_t _p = vle32_v_f32m8(ptr_sqsum, vl);
                    _p = vfsub_vf_f32m8(_p, mean, vl);
                    _sqsum = vfredosum_vs_f32m8_f32m1(_sqsum, vfmul_vv_f32m8(_p, _p, vl), /* scalar */ _sqsum, vl);
                    ptr_sqsum += vl;
                    n -= vl;
                }
            }
            sqsum = vfmv_f_s_f32m1_f32(_sqsum);
#else
            for (int i = 0; i < size; i++)
            {
                float tmp = ptr[i] - mean;
                sqsum += tmp * tmp;
            }
#endif // __riscv_vector
            float var = sqsum / size;
            // computing var in one pass can yield a small negative value due to
            // rounding, hence the two-pass form above instead of
            //float var = sqsum / size - mean * mean;

            float a;
            float b;
            if (affine)
            {
                float gamma = gamma_data[q];
                float beta = beta_data[q];

                a = gamma / (sqrtf(var + eps));
                b = -mean * a + beta;
            }
            else
            {
                a = 1.f / (sqrtf(var + eps));
                b = -mean * a;
            }
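
            // the whole normalization folds into one affine map per channel:
            //   y = (x - mean) / sqrt(var + eps) * gamma + beta
            //     = x * a + b,  with a = gamma / sqrt(var + eps), b = beta - mean * a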
#if __riscv_vector
            {
                int n = size;
                float* ptr_store = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e32m8(n);
                    vfloat32m8_t _p = vle32_v_f32m8(ptr_store, vl);
                    _p = vfmul_vf_f32m8(_p, a, vl);
                    _p = vfadd_vf_f32m8(_p, b, vl);
                    vse32_v_f32m8(ptr_store, _p, vl);
                    ptr_store += vl;
                    n -= vl;
                }
            }
#else
            for (int i = 0; i < size; i++)
            {
                ptr[i] = ptr[i] * a + b;
            }
#endif // __riscv_vector
        }

        return 0;
    }
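
    // packed path: with elempack == packn each float lane belongs to a different
    // channel, so per-channel statistics are plain per-lane vector adds and no
    // cross-lane reduction is needed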
#if __riscv_vector
    const int packn = csrr_vlenb() / 4;
    if (elempack == packn)
    {
        const size_t vl = vsetvl_e32m1(packn);
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < c; q++)
        {
            float* ptr = bottom_top_blob.channel(q);
            vfloat32m1_t _sum = vfmv_v_f_f32m1(0.f, vl);
            vfloat32m1_t _sqsum = vfmv_v_f_f32m1(0.f, vl);
            for (int i = 0; i < size; i++)
            {
                vfloat32m1_t _p = vle32_v_f32m1(ptr + vl * i, vl);
                _sum = vfadd_vv_f32m1(_p, _sum, vl);
                // _sqsum = vfmadd_vv_f32m1(_p, _p, _sqsum, vl);
            }
            vfloat32m1_t _mean = vfdiv_vf_f32m1(_sum, size, vl);
            for (int i = 0; i < size; i++)
            {
                vfloat32m1_t _p = vle32_v_f32m1(ptr + vl * i, vl);
                _p = vfsub_vv_f32m1(_p, _mean, vl);
                _sqsum = vfmadd_vv_f32m1(_p, _p, _sqsum, vl);
            }
            vfloat32m1_t _var = vfdiv_vf_f32m1(_sqsum, size, vl);
            // var accumulated as E[(x - mean)^2]; the one-pass form below can go negative
            //float var = sqsum / size - mean * mean;

            vfloat32m1_t _a;
            vfloat32m1_t _b;
            if (affine)
            {
                vfloat32m1_t _gamma = vle32_v_f32m1((const float*)gamma_data + q * vl, vl);
                vfloat32m1_t _beta = vle32_v_f32m1((const float*)beta_data + q * vl, vl);
                _a = vfdiv_vv_f32m1(_gamma, vfsqrt_v_f32m1(vfadd_vf_f32m1(_var, eps, vl), vl), vl);
                _b = vfnmsub_vv_f32m1(_a, _mean, _beta, vl);
            }
            else
            {
                _a = vfrdiv_vf_f32m1(vfsqrt_v_f32m1(vfadd_vf_f32m1(_var, eps, vl), vl), 1.f, vl);
                _b = vfmul_vv_f32m1(_a, _mean, vl);
                _b = vfsgnjn_vv_f32m1(_b, _b, vl);
            }
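            // vfnmsub gives _b = -(_a * _mean) + _beta directly; without affine,
            // vfrdiv computes 1 / sqrt(var + eps) and vfsgnjn_vv(x, x) flips the
            // sign bit, yielding _b = -(_a * _mean)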
            for (int i = 0; i < size; i++)
            {
                vfloat32m1_t _p = vle32_v_f32m1(ptr + i * vl, vl);
                _p = vfmadd_vv_f32m1(_p, _a, _b, vl);
                vse32_v_f32m1(ptr + i * vl, _p, vl);
            }
        }

        return 0;
    }
#endif // __riscv_vector

    return 0;
}

#if __riscv_vector && __riscv_zfh
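// fp16s: values are stored as __fp16 but widened to fp32 (vfwcvt) for the
// statistics and the normalization arithmetic, then narrowed (vfncvt) on store;
// this keeps fp16 memory bandwidth while avoiding fp16 accumulation error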
int InstanceNorm_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
{
    // x = (x - mean) / (sqrt(var + eps)) * gamma + beta
    int elempack = bottom_top_blob.elempack;
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int c = bottom_top_blob.c;
    int dims = bottom_top_blob.dims;
    int size = w * h;

    if (elempack == 1)
    {
        size = elempack * size;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < c; q++)
        {
            __fp16* ptr = bottom_top_blob.channel(q);

            float sum = 0.f;
            float sqsum = 0.f;
            vfloat32m1_t _sum = vfmv_s_f_f32m1(vundefined_f32m1(), 0.f, vsetvlmax_e32m1());
            vfloat32m1_t _sqsum = vfmv_s_f_f32m1(vundefined_f32m1(), 0.f, vsetvlmax_e32m1());
            {
                int n = size;
                __fp16* ptr_sum = ptr;
                while (n > 0)
                {
                    // e32m8 and e16m4 yield the same lane count, so one vl serves both
                    size_t vl = vsetvl_e32m8(n);
                    vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr_sum, vl), vl);
                    _sum = vfredusum_vs_f32m8_f32m1(_sum, _p, /* scalar */ _sum, vl);
                    // _sqsum = vfredosum_vs_f32m8_f32m1(_sqsum, vfmul_vv_f32m8(_p, _p, vl), /* scalar */ _sqsum, vl);
                    ptr_sum += vl;
                    n -= vl;
                }
            }
            sum = vfmv_f_s_f32m1_f32(_sum);
            float mean = sum / size;
            {
                int n = size;
                __fp16* ptr_sqsum = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e32m8(n);
                    vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr_sqsum, vl), vl);
                    _p = vfsub_vf_f32m8(_p, mean, vl);
                    _sqsum = vfredosum_vs_f32m8_f32m1(_sqsum, vfmul_vv_f32m8(_p, _p, vl), /* scalar */ _sqsum, vl);
                    ptr_sqsum += vl;
                    n -= vl;
                }
            }
            sqsum = vfmv_f_s_f32m1_f32(_sqsum);
            float var = sqsum / size;
            // var accumulated as E[(x - mean)^2]; the one-pass form below can go negative
            //float var = sqsum / size - mean * mean;

            float a;
            float b;
            if (affine)
            {
                float gamma = gamma_data[q];
                float beta = beta_data[q];

                a = gamma / (sqrtf(var + eps));
                b = -mean * a + beta;
            }
            else
            {
                a = 1.f / (sqrtf(var + eps));
                b = -mean * a;
            }
            {
                int n = size;
                __fp16* ptr_store = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e32m8(n);
                    vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr_store, vl), vl);
                    _p = vfmul_vf_f32m8(_p, a, vl);
                    _p = vfadd_vf_f32m8(_p, b, vl);
                    vse16_v_f16m4(ptr_store, vfncvt_f_f_w_f16m4(_p, vl), vl);
                    ptr_store += vl;
                    n -= vl;
                }
            }
        }

        return 0;
    }
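
    // packed fp16 path: one channel per fp16 lane; widening to fp32 doubles the
    // register group (f16m1 -> f32m2)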
    const int packn = csrr_vlenb() / 2;
    if (elempack == packn)
    {
        const size_t vl = vsetvl_e16m1(packn);
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < c; q++)
        {
            __fp16* ptr = bottom_top_blob.channel(q);
            vfloat32m2_t _sum = vfmv_v_f_f32m2(0.f, vl);
            vfloat32m2_t _sqsum = vfmv_v_f_f32m2(0.f, vl);
            for (int i = 0; i < size; i++)
            {
                vfloat32m2_t _p = vfwcvt_f_f_v_f32m2(vle16_v_f16m1(ptr + vl * i, vl), vl);
                _sum = vfadd_vv_f32m2(_p, _sum, vl);
                // _sqsum = vfmadd_vv_f32m2(_p, _p, _sqsum, vl);
            }
            vfloat32m2_t _mean = vfdiv_vf_f32m2(_sum, size, vl);
            for (int i = 0; i < size; i++)
            {
                vfloat32m2_t _p = vfwcvt_f_f_v_f32m2(vle16_v_f16m1(ptr + vl * i, vl), vl);
                _p = vfsub_vv_f32m2(_p, _mean, vl);
                _sqsum = vfmadd_vv_f32m2(_p, _p, _sqsum, vl);
            }
            vfloat32m2_t _var = vfdiv_vf_f32m2(_sqsum, size, vl);
            // var accumulated as E[(x - mean)^2]; the one-pass form below can go negative
            //float var = sqsum / size - mean * mean;

            vfloat32m2_t _a;
            vfloat32m2_t _b;
            if (affine)
            {
                vfloat32m2_t _gamma = vle32_v_f32m2((const float*)gamma_data + q * vl, vl);
                vfloat32m2_t _beta = vle32_v_f32m2((const float*)beta_data + q * vl, vl);
                _a = vfdiv_vv_f32m2(_gamma, vfsqrt_v_f32m2(vfadd_vf_f32m2(_var, eps, vl), vl), vl);
                _b = vfnmsub_vv_f32m2(_a, _mean, _beta, vl);
            }
            else
            {
                _a = vfrdiv_vf_f32m2(vfsqrt_v_f32m2(vfadd_vf_f32m2(_var, eps, vl), vl), 1.f, vl);
                _b = vfmul_vv_f32m2(_a, _mean, vl);
                _b = vfsgnjn_vv_f32m2(_b, _b, vl);
            }
            for (int i = 0; i < size; i++)
            {
                vfloat32m2_t _p = vfwcvt_f_f_v_f32m2(vle16_v_f16m1(ptr + i * vl, vl), vl);
                _p = vfmadd_vv_f32m2(_p, _a, _b, vl);
                vse16_v_f16m1(ptr + i * vl, vfncvt_f_f_w_f16m1(_p, vl), vl);
            }
        }

        return 0;
    }

    return 0;
}
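
// fp16sa: both storage and arithmetic stay in __fp16, including the mean/var
// accumulation; faster where vector fp16 arithmetic is available, but with
// less precision headroom, since large channels can lose accuracy during
// fp16 summation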
int InstanceNorm_riscv::forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const
{
    // x = (x - mean) / (sqrt(var + eps)) * gamma + beta
    int elempack = bottom_top_blob.elempack;
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int c = bottom_top_blob.c;
    int dims = bottom_top_blob.dims;
    int size = w * h;

    if (elempack == 1)
    {
        size = elempack * size;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < c; q++)
        {
            __fp16* ptr = bottom_top_blob.channel(q);

            __fp16 sum = 0.f;
            __fp16 sqsum = 0.f;
            vfloat16m1_t _sum = vfmv_s_f_f16m1(vundefined_f16m1(), 0.f, vsetvlmax_e16m1());
            vfloat16m1_t _sqsum = vfmv_s_f_f16m1(vundefined_f16m1(), 0.f, vsetvlmax_e16m1());
            {
                int n = size;
                __fp16* ptr_sum = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e16m8(n);
                    vfloat16m8_t _p = vle16_v_f16m8(ptr_sum, vl);
                    _sum = vfredusum_vs_f16m8_f16m1(_sum, _p, /* scalar */ _sum, vl);
                    // _sqsum = vfredosum_vs_f16m8_f16m1(_sqsum, vfmul_vv_f16m8(_p, _p, vl), /* scalar */ _sqsum, vl);
                    ptr_sum += vl;
                    n -= vl;
                }
            }
            sum = vfmv_f_s_f16m1_f16(_sum);
            __fp16 mean = sum / size;
            {
                int n = size;
                __fp16* ptr_sqsum = ptr;
                while (n > 0)
                {
                    size_t vl = vsetvl_e16m8(n);
                    vfloat16m8_t _p = vle16_v_f16m8(ptr_sqsum, vl);
                    _p = vfsub_vf_f16m8(_p, mean, vl);
                    _sqsum = vfredosum_vs_f16m8_f16m1(_sqsum, vfmul_vv_f16m8(_p, _p, vl), /* scalar */ _sqsum, vl);
                    ptr_sqsum += vl;
                    n -= vl;
                }
            }
            sqsum = vfmv_f_s_f16m1_f16(_sqsum);
            __fp16 var = sqsum / size;
            // var accumulated as E[(x - mean)^2]; the one-pass form below can go negative
            //float var = sqsum / size - mean * mean;

            __fp16 a;
            __fp16 b;
            if (affine)
            {
                float gamma = gamma_data[q];
                float beta = beta_data[q];

                a = static_cast<__fp16>(gamma / (sqrt(var + eps)));
                b = static_cast<__fp16>(-mean * a + beta);
            }
            else
            {
                a = static_cast<__fp16>(1.f / (sqrt(var + eps)));
                b = static_cast<__fp16>(-mean * a);
            }
            {
                int n = size;
                __fp16* ptr_store = ptr;
                while (n > 0)
                {
                    // e16m8 matches the fp16 element width of the loads/stores below
                    size_t vl = vsetvl_e16m8(n);
                    vfloat16m8_t _p = vle16_v_f16m8(ptr_store, vl);
                    _p = vfmul_vf_f16m8(_p, a, vl);
                    _p = vfadd_vf_f16m8(_p, b, vl);
                    vse16_v_f16m8(ptr_store, _p, vl);
                    ptr_store += vl;
                    n -= vl;
                }
            }
        }

        return 0;
    }

    const int packn = csrr_vlenb() / 2;
    if (elempack == packn)
    {
        const size_t vl = vsetvl_e16m1(packn);
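        // per-lane fp16 statistics; gamma and beta are stored as fp32 and
        // narrowed to fp16 on load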
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < c; q++)
        {
            __fp16* ptr = bottom_top_blob.channel(q);
            vfloat16m1_t _sum = vfmv_v_f_f16m1(0.f, vl);
            vfloat16m1_t _sqsum = vfmv_v_f_f16m1(0.f, vl);
            for (int i = 0; i < size; i++)
            {
                vfloat16m1_t _p = vle16_v_f16m1(ptr + vl * i, vl);
                _sum = vfadd_vv_f16m1(_p, _sum, vl);
                // _sqsum = vfmadd_vv_f16m1(_p, _p, _sqsum, vl);
            }
            vfloat16m1_t _mean = vfdiv_vf_f16m1(_sum, size, vl);
            for (int i = 0; i < size; i++)
            {
                vfloat16m1_t _p = vle16_v_f16m1(ptr + vl * i, vl);
                _p = vfsub_vv_f16m1(_p, _mean, vl);
                _sqsum = vfmadd_vv_f16m1(_p, _p, _sqsum, vl);
            }
            vfloat16m1_t _var = vfdiv_vf_f16m1(_sqsum, size, vl);
            // var accumulated as E[(x - mean)^2]; the one-pass form below can go negative
            //float var = sqsum / size - mean * mean;

            vfloat16m1_t _a;
            vfloat16m1_t _b;
            if (affine)
            {
                vfloat16m1_t _gamma = vfncvt_f_f_w_f16m1(vle32_v_f32m2((const float*)gamma_data + q * vl, vl), vl);
                vfloat16m1_t _beta = vfncvt_f_f_w_f16m1(vle32_v_f32m2((const float*)beta_data + q * vl, vl), vl);
                _a = vfdiv_vv_f16m1(_gamma, vfsqrt_v_f16m1(vfadd_vf_f16m1(_var, eps, vl), vl), vl);
                _b = vfnmsub_vv_f16m1(_a, _mean, _beta, vl);
            }
            else
            {
                _a = vfrdiv_vf_f16m1(vfsqrt_v_f16m1(vfadd_vf_f16m1(_var, eps, vl), vl), 1.f, vl);
                _b = vfmul_vv_f16m1(_a, _mean, vl);
                _b = vfsgnjn_vv_f16m1(_b, _b, vl);
            }
            for (int i = 0; i < size; i++)
            {
                vfloat16m1_t _p = vle16_v_f16m1(ptr + i * vl, vl);
                _p = vfmadd_vv_f16m1(_p, _a, _b, vl);
                vse16_v_f16m1(ptr + i * vl, _p, vl);
            }
        }

        return 0;
    }

    return 0;
}
#endif // __riscv_vector && __riscv_zfh

} // namespace ncnn