ncnn

Форк
0
/
flatten_mips.cpp 
370 строк · 12.1 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

15
#include "flatten_mips.h"
16

17
#if __mips_msa
18
#include <msa.h>
19
#include "msa_mathfun.h"
20
#endif // __mips_msa
21

22
namespace ncnn {
23

24
Flatten_mips::Flatten_mips()
25
{
26
#if __mips_msa
27
    support_packing = true;
28
#endif // __mips_msa
29
}
30

31
int Flatten_mips::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
32
{
33
    int elembits = bottom_blob.elembits();
34

35
    if (elembits == 8)
36
        return forward_int8(bottom_blob, top_blob, opt);
37

38
    int dims = bottom_blob.dims;
39

40
    if (dims == 1)
41
    {
42
        top_blob = bottom_blob;
43
        return 0;
44
    }
45

46
    int w = bottom_blob.w;
47
    int h = bottom_blob.h;
48
    int d = bottom_blob.d;
49
    int channels = bottom_blob.c;
50
    size_t elemsize = bottom_blob.elemsize;
51
    int elempack = bottom_blob.elempack;
52
    int size = w * h * d;
53

54
    int total = size * channels * elempack;
55

56
    int out_elempack = 1;
57
#if __mips_msa
58
    if (opt.use_packing_layout)
59
    {
60
        out_elempack = total % 4 == 0 ? 4 : 1;
61
    }
62
#endif
63
    size_t out_elemsize = elemsize / elempack * out_elempack;
64

65
    if (out_elempack == 1)
66
    {
67
        return Flatten::forward(bottom_blob, top_blob, opt);
68
    }
69

70
    if (dims == 2 && elempack == 1) // out_elempack == 4
71
    {
72
        top_blob = bottom_blob;
73
        top_blob.dims = 1;
74
        top_blob.w = total / out_elempack;
75
        top_blob.h = 1;
76
        top_blob.cstep = top_blob.w;
77
        top_blob.elemsize = out_elemsize;
78
        top_blob.elempack = out_elempack;
79
        return 0;
80
    }
81

82
    top_blob.create(total / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
83
    if (top_blob.empty())
84
        return -100;
85

86
    if (dims == 2)
87
    {
88
#if __mips_msa
89
        if (elempack == 4) // out_elempack == 4
90
        {
91
            #pragma omp parallel for num_threads(opt.num_threads)
92
            for (int i = 0; i < h; i++)
93
            {
94
                const float* ptr = bottom_blob.row(i);
95
                float* outptr0 = (float*)top_blob + w * i * 4;
96
                float* outptr1 = (float*)top_blob + w * (i * 4 + 1);
97
                float* outptr2 = (float*)top_blob + w * (i * 4 + 2);
98
                float* outptr3 = (float*)top_blob + w * (i * 4 + 3);
99

100
                int j = 0;
101
                for (; j + 3 < w; j += 4)
102
                {
103
                    // transpose 4x4
104
                    v4f32 _r0 = (v4f32)__msa_ld_w(ptr, 0);
105
                    v4f32 _r1 = (v4f32)__msa_ld_w(ptr + 4, 0);
106
                    v4f32 _r2 = (v4f32)__msa_ld_w(ptr + 4 * 2, 0);
107
                    v4f32 _r3 = (v4f32)__msa_ld_w(ptr + 4 * 3, 0);
108

109
                    v4i32 _r01r = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0);
110
                    v4i32 _r01l = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0);
111
                    v4i32 _r23r = __msa_ilvr_w((v4i32)_r3, (v4i32)_r2);
112
                    v4i32 _r23l = __msa_ilvl_w((v4i32)_r3, (v4i32)_r2);
113
                    v2i64 _r0123_0 = __msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r);
114
                    v2i64 _r0123_1 = __msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r);
115
                    v2i64 _r0123_2 = __msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l);
116
                    v2i64 _r0123_3 = __msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l);
117

118
                    __msa_st_w((v4i32)_r0123_0, outptr0, 0);
119
                    __msa_st_w((v4i32)_r0123_1, outptr1, 0);
120
                    __msa_st_w((v4i32)_r0123_2, outptr2, 0);
121
                    __msa_st_w((v4i32)_r0123_3, outptr3, 0);
122

123
                    ptr += 16;
124
                    outptr0 += 4;
125
                    outptr1 += 4;
126
                    outptr2 += 4;
127
                    outptr3 += 4;
128
                }
129
                for (; j < w; j++)
130
                {
131
                    *outptr0++ = ptr[0];
132
                    *outptr1++ = ptr[1];
133
                    *outptr2++ = ptr[2];
134
                    *outptr3++ = ptr[3];
135

136
                    ptr += 4;
137
                }
138
            }
139
        }
140
#endif // __mips_msa
141
    }
142

143
    if (dims == 3 || dims == 4)
144
    {
145
#if __mips_msa
146
        if (elempack == 4) // out_elempack == 4
147
        {
148
            #pragma omp parallel for num_threads(opt.num_threads)
149
            for (int q = 0; q < channels; q++)
150
            {
151
                const float* ptr = bottom_blob.channel(q);
152
                float* outptr0 = (float*)top_blob + size * q * 4;
153
                float* outptr1 = (float*)top_blob + size * (q * 4 + 1);
154
                float* outptr2 = (float*)top_blob + size * (q * 4 + 2);
155
                float* outptr3 = (float*)top_blob + size * (q * 4 + 3);
156

157
                int i = 0;
158
                for (; i + 3 < size; i += 4)
159
                {
160
                    // transpose 4x4
161
                    v4f32 _r0 = (v4f32)__msa_ld_w(ptr, 0);
162
                    v4f32 _r1 = (v4f32)__msa_ld_w(ptr + 4, 0);
163
                    v4f32 _r2 = (v4f32)__msa_ld_w(ptr + 4 * 2, 0);
164
                    v4f32 _r3 = (v4f32)__msa_ld_w(ptr + 4 * 3, 0);
165

166
                    v4i32 _r01r = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0);
167
                    v4i32 _r01l = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0);
168
                    v4i32 _r23r = __msa_ilvr_w((v4i32)_r3, (v4i32)_r2);
169
                    v4i32 _r23l = __msa_ilvl_w((v4i32)_r3, (v4i32)_r2);
170
                    v2i64 _r0123_0 = __msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r);
171
                    v2i64 _r0123_1 = __msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r);
172
                    v2i64 _r0123_2 = __msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l);
173
                    v2i64 _r0123_3 = __msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l);
174

175
                    __msa_st_w((v4i32)_r0123_0, outptr0, 0);
176
                    __msa_st_w((v4i32)_r0123_1, outptr1, 0);
177
                    __msa_st_w((v4i32)_r0123_2, outptr2, 0);
178
                    __msa_st_w((v4i32)_r0123_3, outptr3, 0);
179

180
                    ptr += 16;
181
                    outptr0 += 4;
182
                    outptr1 += 4;
183
                    outptr2 += 4;
184
                    outptr3 += 4;
185
                }
186
                for (; i < size; i++)
187
                {
188
                    *outptr0++ = ptr[0];
189
                    *outptr1++ = ptr[1];
190
                    *outptr2++ = ptr[2];
191
                    *outptr3++ = ptr[3];
192

193
                    ptr += 4;
194
                }
195
            }
196
        }
197
#endif // __mips_msa
198

199
        if (elempack == 1) // out_elempack == 4
200
        {
201
            #pragma omp parallel for num_threads(opt.num_threads)
202
            for (int q = 0; q < channels; q++)
203
            {
204
                const float* ptr = bottom_blob.channel(q);
205
                float* outptr = (float*)top_blob + size * q;
206

207
                int i = 0;
208
#if __mips_msa
209
                for (; i + 3 < size; i += 4)
210
                {
211
                    __msa_st_w(__msa_ld_w(ptr, 0), outptr, 0);
212
                    ptr += 4;
213
                    outptr += 4;
214
                }
215
#endif // __mips_msa
216
                for (; i < size; i++)
217
                {
218
                    *outptr++ = *ptr++;
219
                }
220
            }
221
        }
222
    }
223

224
    return 0;
225
}
226

227
int Flatten_mips::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
228
{
229
    int dims = bottom_blob.dims;
230

231
    if (dims == 1)
232
    {
233
        top_blob = bottom_blob;
234
        return 0;
235
    }
236

237
    int w = bottom_blob.w;
238
    int h = bottom_blob.h;
239
    int d = bottom_blob.d;
240
    int channels = bottom_blob.c;
241
    size_t elemsize = bottom_blob.elemsize;
242
    int elempack = bottom_blob.elempack;
243
    int size = w * h * d;
244

245
    int total = size * channels * elempack;
246

247
    int out_elempack = 1;
248
#if __mips_msa
249
    if (opt.use_packing_layout)
250
    {
251
        out_elempack = total % 8 == 0 ? 8 : 1;
252
    }
253
#endif
254
    size_t out_elemsize = elemsize / elempack * out_elempack;
255

256
    if (out_elempack == 1)
257
    {
258
        return Flatten::forward(bottom_blob, top_blob, opt);
259
    }
260

261
    if (dims == 2 && elempack == 1) // out_elempack == 8
262
    {
263
        top_blob = bottom_blob;
264
        top_blob.dims = 1;
265
        top_blob.w = total / out_elempack;
266
        top_blob.h = 1;
267
        top_blob.cstep = top_blob.w;
268
        top_blob.elemsize = out_elemsize;
269
        top_blob.elempack = out_elempack;
270
        return 0;
271
    }
272

273
    top_blob.create(total / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
274
    if (top_blob.empty())
275
        return -100;
276

277
    if (dims == 2)
278
    {
279
#if __mips_msa
280
        if (elempack == 8) // out_elempack == 8
281
        {
282
            #pragma omp parallel for num_threads(opt.num_threads)
283
            for (int i = 0; i < h; i++)
284
            {
285
                const signed char* ptr = bottom_blob.row<signed char>(i);
286
                signed char* outptr0 = (signed char*)top_blob + w * i * 8;
287
                signed char* outptr1 = (signed char*)top_blob + w * (i * 8 + 1);
288
                signed char* outptr2 = (signed char*)top_blob + w * (i * 8 + 2);
289
                signed char* outptr3 = (signed char*)top_blob + w * (i * 8 + 3);
290
                signed char* outptr4 = (signed char*)top_blob + w * (i * 8 + 4);
291
                signed char* outptr5 = (signed char*)top_blob + w * (i * 8 + 5);
292
                signed char* outptr6 = (signed char*)top_blob + w * (i * 8 + 6);
293
                signed char* outptr7 = (signed char*)top_blob + w * (i * 8 + 7);
294

295
                int j = 0;
296
                for (; j < w; j++)
297
                {
298
                    *outptr0++ = ptr[0];
299
                    *outptr1++ = ptr[1];
300
                    *outptr2++ = ptr[2];
301
                    *outptr3++ = ptr[3];
302
                    *outptr4++ = ptr[4];
303
                    *outptr5++ = ptr[5];
304
                    *outptr6++ = ptr[6];
305
                    *outptr7++ = ptr[7];
306

307
                    ptr += 8;
308
                }
309
            }
310
        }
311
#endif // __mips_msa
312
    }
313

314
    if (dims == 3 || dims == 4)
315
    {
316
#if __mips_msa
317
        if (elempack == 8) // out_elempack == 8
318
        {
319
            #pragma omp parallel for num_threads(opt.num_threads)
320
            for (int q = 0; q < channels; q++)
321
            {
322
                const signed char* ptr = bottom_blob.channel(q);
323
                signed char* outptr0 = (signed char*)top_blob + size * q * 8;
324
                signed char* outptr1 = (signed char*)top_blob + size * (q * 8 + 1);
325
                signed char* outptr2 = (signed char*)top_blob + size * (q * 8 + 2);
326
                signed char* outptr3 = (signed char*)top_blob + size * (q * 8 + 3);
327
                signed char* outptr4 = (signed char*)top_blob + size * (q * 8 + 4);
328
                signed char* outptr5 = (signed char*)top_blob + size * (q * 8 + 5);
329
                signed char* outptr6 = (signed char*)top_blob + size * (q * 8 + 6);
330
                signed char* outptr7 = (signed char*)top_blob + size * (q * 8 + 7);
331

332
                int i = 0;
333
                for (; i < size; i++)
334
                {
335
                    *outptr0++ = ptr[0];
336
                    *outptr1++ = ptr[1];
337
                    *outptr2++ = ptr[2];
338
                    *outptr3++ = ptr[3];
339
                    *outptr4++ = ptr[4];
340
                    *outptr5++ = ptr[5];
341
                    *outptr6++ = ptr[6];
342
                    *outptr7++ = ptr[7];
343

344
                    ptr += 8;
345
                }
346
            }
347
        }
348
#endif // __mips_msa
349

350
        if (elempack == 1) // out_elempack == 8
351
        {
352
            #pragma omp parallel for num_threads(opt.num_threads)
353
            for (int q = 0; q < channels; q++)
354
            {
355
                const signed char* ptr = bottom_blob.channel(q);
356
                signed char* outptr = (signed char*)top_blob + size * q;
357

358
                int i = 0;
359
                for (; i < size; i++)
360
                {
361
                    *outptr++ = *ptr++;
362
                }
363
            }
364
        }
365
    }
366

367
    return 0;
368
}
369

370
} // namespace ncnn
371

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.