ncnn

Форк
0
/
prelu_loongarch.cpp 
193 строки · 6.0 Кб
1
// yala is pleased to support the open source community by making ncnn available.
2
//
3
//
4
// Copyright (C) 2022 yala <zhaojunchao@loongson.cn>;<junchao82@qq.com>. All rights reserved.
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "prelu_loongarch.h"
16

17
#if __loongarch_sx
18
#include <lsxintrin.h>
19
#endif // __loongarch_sx
20

21
#include "loongarch_usability.h"
22

23
namespace ncnn {
24

25
PReLU_loongarch::PReLU_loongarch()
26
{
27
#if __loongarch_sx
28
    support_packing = true;
29
#endif // __loongarch_sx
30
}
31

32
int PReLU_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
33
{
34
    int dims = bottom_top_blob.dims;
35
    int elempack = bottom_top_blob.elempack;
36

37
    if (dims == 1)
38
    {
39
        int w = bottom_top_blob.w * elempack;
40

41
#if __loongarch_sx
42
        int nn_w = w / 4;
43
        int remain_w_start = nn_w * 4;
44
#else
45
        int remain_w_start = 0;
46
#endif // __loongarch_sx
47

48
        float* ptr = bottom_top_blob;
49

50
        if (num_slope > 1)
51
        {
52
            const float* slope = slope_data;
53

54
#if __loongarch_sx
55
            #pragma omp parallel for num_threads(opt.num_threads)
56
            for (int i = 0; i < nn_w; i++)
57
            {
58
                float* ptr0 = ptr + i * 4;
59

60
                __m128 _p = (__m128)__lsx_vld(ptr0, 0);
61
                __m128 _zero = (__m128)__lsx_vreplgr2vr_w(0);
62
                __m128 _slope = (__m128)__lsx_vld(slope + i * 4, 0);
63
                __m128i _lemask = __lsx_vfcmp_cle_s(_p, _zero);
64
                __m128 _ps = __lsx_vfmul_s(_p, _slope);
65
                _p = (__m128)__lsx_vbitsel_v((__m128i)_p, (__m128i)_ps, (__m128i)_lemask);
66
                __lsx_vst(_p, ptr0, 0);
67
            }
68
#endif // __loongarch_sx
69

70
            #pragma omp parallel for num_threads(opt.num_threads)
71
            for (int i = remain_w_start; i < w; i++)
72
            {
73
                float v = ptr[i];
74
                if (v < 0.f)
75
                    ptr[i] = v * slope[i];
76
            }
77
        }
78
        else
79
        {
80
            const float slope = slope_data[0];
81

82
#if __loongarch_sx
83
            #pragma omp parallel for num_threads(opt.num_threads)
84
            for (int i = 0; i < nn_w; i++)
85
            {
86
                float* ptr0 = ptr + i * 4;
87

88
                __m128 _p = (__m128)__lsx_vld(ptr0, 0);
89
                __m128 _zero = (__m128)__lsx_vreplgr2vr_w(0);
90
                __m128 _slope = (__m128)__lsx_vreplfr2vr_s(slope);
91
                __m128i _lemask = __lsx_vfcmp_cle_s(_p, _zero);
92
                __m128 _ps = __lsx_vfmul_s(_p, _slope);
93
                _p = (__m128)__lsx_vbitsel_v((__m128i)_p, (__m128i)_ps, (__m128i)_lemask);
94
                __lsx_vst(_p, ptr0, 0);
95
            }
96
#endif // __loongarch_sx
97

98
            #pragma omp parallel for num_threads(opt.num_threads)
99
            for (int i = remain_w_start; i < w; i++)
100
            {
101
                float v = ptr[i];
102
                if (v < 0.f)
103
                    ptr[i] = v * slope;
104
            }
105
        }
106
    }
107

108
    if (dims == 2)
109
    {
110
        int w = bottom_top_blob.w * elempack;
111
        int h = bottom_top_blob.h;
112

113
        #pragma omp parallel for num_threads(opt.num_threads)
114
        for (int i = 0; i < h; i++)
115
        {
116
            float* ptr = bottom_top_blob.row(i);
117

118
            const float slope = num_slope > 1 ? slope_data[i] : slope_data[0];
119

120
            int j = 0;
121
#if __loongarch_sx
122
            __m128 _zero = (__m128)__lsx_vreplgr2vr_w(0);
123
            __m128 _slope = (elempack == 4 && num_slope > 1) ? (__m128)__lsx_vld((const float*)slope_data + i * 4, 0) : (__m128)__lsx_vreplfr2vr_s(slope);
124

125
            for (; j + 3 < w; j += 4)
126
            {
127
                __builtin_prefetch(ptr + 16);
128
                __m128 _p = (__m128)__lsx_vld(ptr, 0);
129
                __m128i _lemask = __lsx_vfcmp_cle_s(_p, _zero);
130
                __m128 _ps = __lsx_vfmul_s(_p, _slope);
131
                _p = (__m128)__lsx_vbitsel_v((__m128i)_p, (__m128i)_ps, (__m128i)_lemask);
132
                __lsx_vst(_p, ptr, 0);
133

134
                ptr += 4;
135
            }
136
#endif // __loongarch_sx
137
            for (; j < w; j++)
138
            {
139
                float v = *ptr;
140
                if (v < 0.f)
141
                    *ptr = v * slope;
142

143
                ptr++;
144
            }
145
        }
146
    }
147

148
    if (dims == 3)
149
    {
150
        int w = bottom_top_blob.w;
151
        int h = bottom_top_blob.h;
152
        int channels = bottom_top_blob.c;
153
        int size = w * h * elempack;
154

155
        const float* slope_data_ptr = slope_data;
156

157
        #pragma omp parallel for num_threads(opt.num_threads)
158
        for (int q = 0; q < channels; q++)
159
        {
160
            float* ptr = bottom_top_blob.channel(q);
161
            float slope = num_slope > 1 ? slope_data_ptr[q] : slope_data_ptr[0];
162

163
            int i = 0;
164
#if __loongarch_sx
165
            __m128 _zero = (__m128)__lsx_vreplgr2vr_w(0);
166
            __m128 _slope = (elempack == 4 && num_slope > 1) ? (__m128)__lsx_vld((const float*)slope_data + q * 4, 0) : (__m128)__lsx_vreplfr2vr_s(slope);
167

168
            for (; i + 3 < size; i += 4)
169
            {
170
                __builtin_prefetch(ptr + 16);
171
                __m128 _p = (__m128)__lsx_vld(ptr, 0);
172
                __m128i _lemask = __lsx_vfcmp_cle_s(_p, _zero);
173
                __m128 _ps = __lsx_vfmul_s(_p, _slope);
174
                _p = (__m128)__lsx_vbitsel_v((__m128i)_p, (__m128i)_ps, (__m128i)_lemask);
175
                __lsx_vst(_p, ptr, 0);
176

177
                ptr += 4;
178
            }
179
#endif // __loongarch_sx
180
            for (; i < size; i++)
181
            {
182
                if (*ptr < 0)
183
                    *ptr *= slope;
184

185
                ptr++;
186
            }
187
        }
188
    }
189

190
    return 0;
191
}
192

193
} // namespace ncnn
194

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.