1
// yala is pleased to support the open source community by making ncnn available.
4
// Copyright (C) 2022 yala <zhaojunchao@loongson.cn>;<junchao82@qq.com>. All rights reserved.
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
8
// https://opensource.org/licenses/BSD-3-Clause
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
15
#include "prelu_loongarch.h"
19
#endif // __loongarch_sx
21
#include "loongarch_usability.h"
25
PReLU_loongarch::PReLU_loongarch()
28
support_packing = true;
29
#endif // __loongarch_sx
32
int PReLU_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
34
int dims = bottom_top_blob.dims;
35
int elempack = bottom_top_blob.elempack;
39
int w = bottom_top_blob.w * elempack;
43
int remain_w_start = nn_w * 4;
45
int remain_w_start = 0;
46
#endif // __loongarch_sx
48
float* ptr = bottom_top_blob;
52
const float* slope = slope_data;
55
#pragma omp parallel for num_threads(opt.num_threads)
56
for (int i = 0; i < nn_w; i++)
58
float* ptr0 = ptr + i * 4;
60
__m128 _p = (__m128)__lsx_vld(ptr0, 0);
61
__m128 _zero = (__m128)__lsx_vreplgr2vr_w(0);
62
__m128 _slope = (__m128)__lsx_vld(slope + i * 4, 0);
63
__m128i _lemask = __lsx_vfcmp_cle_s(_p, _zero);
64
__m128 _ps = __lsx_vfmul_s(_p, _slope);
65
_p = (__m128)__lsx_vbitsel_v((__m128i)_p, (__m128i)_ps, (__m128i)_lemask);
66
__lsx_vst(_p, ptr0, 0);
68
#endif // __loongarch_sx
70
#pragma omp parallel for num_threads(opt.num_threads)
71
for (int i = remain_w_start; i < w; i++)
75
ptr[i] = v * slope[i];
80
const float slope = slope_data[0];
83
#pragma omp parallel for num_threads(opt.num_threads)
84
for (int i = 0; i < nn_w; i++)
86
float* ptr0 = ptr + i * 4;
88
__m128 _p = (__m128)__lsx_vld(ptr0, 0);
89
__m128 _zero = (__m128)__lsx_vreplgr2vr_w(0);
90
__m128 _slope = (__m128)__lsx_vreplfr2vr_s(slope);
91
__m128i _lemask = __lsx_vfcmp_cle_s(_p, _zero);
92
__m128 _ps = __lsx_vfmul_s(_p, _slope);
93
_p = (__m128)__lsx_vbitsel_v((__m128i)_p, (__m128i)_ps, (__m128i)_lemask);
94
__lsx_vst(_p, ptr0, 0);
96
#endif // __loongarch_sx
98
#pragma omp parallel for num_threads(opt.num_threads)
99
for (int i = remain_w_start; i < w; i++)
110
int w = bottom_top_blob.w * elempack;
111
int h = bottom_top_blob.h;
113
#pragma omp parallel for num_threads(opt.num_threads)
114
for (int i = 0; i < h; i++)
116
float* ptr = bottom_top_blob.row(i);
118
const float slope = num_slope > 1 ? slope_data[i] : slope_data[0];
122
__m128 _zero = (__m128)__lsx_vreplgr2vr_w(0);
123
__m128 _slope = (elempack == 4 && num_slope > 1) ? (__m128)__lsx_vld((const float*)slope_data + i * 4, 0) : (__m128)__lsx_vreplfr2vr_s(slope);
125
for (; j + 3 < w; j += 4)
127
__builtin_prefetch(ptr + 16);
128
__m128 _p = (__m128)__lsx_vld(ptr, 0);
129
__m128i _lemask = __lsx_vfcmp_cle_s(_p, _zero);
130
__m128 _ps = __lsx_vfmul_s(_p, _slope);
131
_p = (__m128)__lsx_vbitsel_v((__m128i)_p, (__m128i)_ps, (__m128i)_lemask);
132
__lsx_vst(_p, ptr, 0);
136
#endif // __loongarch_sx
150
int w = bottom_top_blob.w;
151
int h = bottom_top_blob.h;
152
int channels = bottom_top_blob.c;
153
int size = w * h * elempack;
155
const float* slope_data_ptr = slope_data;
157
#pragma omp parallel for num_threads(opt.num_threads)
158
for (int q = 0; q < channels; q++)
160
float* ptr = bottom_top_blob.channel(q);
161
float slope = num_slope > 1 ? slope_data_ptr[q] : slope_data_ptr[0];
165
__m128 _zero = (__m128)__lsx_vreplgr2vr_w(0);
166
__m128 _slope = (elempack == 4 && num_slope > 1) ? (__m128)__lsx_vld((const float*)slope_data + q * 4, 0) : (__m128)__lsx_vreplfr2vr_s(slope);
168
for (; i + 3 < size; i += 4)
170
__builtin_prefetch(ptr + 16);
171
__m128 _p = (__m128)__lsx_vld(ptr, 0);
172
__m128i _lemask = __lsx_vfcmp_cle_s(_p, _zero);
173
__m128 _ps = __lsx_vfmul_s(_p, _slope);
174
_p = (__m128)__lsx_vbitsel_v((__m128i)_p, (__m128i)_ps, (__m128i)_lemask);
175
__lsx_vst(_p, ptr, 0);
179
#endif // __loongarch_sx
180
for (; i < size; i++)