// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "quantize_mips.h"

#if __mips_msa
#include <msa.h>
#endif // __mips_msa

#include "mips_usability.h"

namespace ncnn {

Quantize_mips::Quantize_mips()
{
#if __mips_msa
    support_packing = true;
#endif
}

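// Quantize fp32 input to int8: out = float2int8(in * scale). The scale is a single
// per-tensor value when scale_data_size == 1; otherwise it is read from scale_data
// along the outermost dimension (per element / per row / per channel for dims 1/2/3).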
int Quantize_mips::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int dims = bottom_blob.dims;
    int elempack = bottom_blob.elempack;

#if __mips_msa
    if (elempack == 4)
    {
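        // Pack-4 fp32 input: quantize and repack to pack-8 int8 when the flattened
        // outer dimension is a multiple of 8 and packing is enabled, otherwise to pack-1.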
        if (dims == 1)
        {
            int w = bottom_blob.w;
            int out_elempack = opt.use_packing_layout && w * elempack % 8 == 0 ? 8 : 1;
            int outw = w * elempack / out_elempack;

            top_blob.create(outw, (size_t)out_elempack, out_elempack, opt.blob_allocator);
            if (top_blob.empty())
                return -100;

            if (scale_data_size == 1)
            {
                const float scale = scale_data[0];

                #pragma omp parallel for num_threads(opt.num_threads)
                for (int i = 0; i < w; i++)
                {
                    const float* ptr0 = (const float*)bottom_blob + i * 4;
                    signed char* outptr = (signed char*)top_blob + i * 4;

                    outptr[0] = float2int8(ptr0[0] * scale);
                    outptr[1] = float2int8(ptr0[1] * scale);
                    outptr[2] = float2int8(ptr0[2] * scale);
                    outptr[3] = float2int8(ptr0[3] * scale);
                }
            }
            else
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int i = 0; i < w; i++)
                {
                    const float* ptr0 = (const float*)bottom_blob + i * 4;
                    signed char* outptr = (signed char*)top_blob + i * 4;

                    outptr[0] = float2int8(ptr0[0] * scale_data[i * 4]);
                    outptr[1] = float2int8(ptr0[1] * scale_data[i * 4 + 1]);
                    outptr[2] = float2int8(ptr0[2] * scale_data[i * 4 + 2]);
                    outptr[3] = float2int8(ptr0[3] * scale_data[i * 4 + 3]);
                }
            }
        }

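        // dims == 2: for pack-8 output fuse two pack-4 input rows into one int8 row,
        // for pack-1 output split each pack-4 input row into four int8 rows.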
        if (dims == 2)
        {
            int w = bottom_blob.w;
            int h = bottom_blob.h;
            int out_elempack = opt.use_packing_layout && h * elempack % 8 == 0 ? 8 : 1;
            int outh = h * elempack / out_elempack;

            top_blob.create(w, outh, (size_t)out_elempack, out_elempack, opt.blob_allocator);
            if (top_blob.empty())
                return -100;

            if (out_elempack == 8)
            {
                if (scale_data_size == 1)
                {
                    v4f32 _scale = (v4f32)__msa_fill_w_f32(scale_data[0]);

                    #pragma omp parallel for num_threads(opt.num_threads)
                    for (int i = 0; i < outh; i++)
                    {
                        const float* ptr0 = bottom_blob.row(i * 2);
                        const float* ptr1 = bottom_blob.row(i * 2 + 1);
                        signed char* outptr = top_blob.row<signed char>(i);

                        for (int j = 0; j < w; j++)
                        {
                            __builtin_prefetch(ptr0 + 16);
                            __builtin_prefetch(ptr1 + 16);
                            v4f32 _vlow = (v4f32)__msa_ld_w(ptr0, 0);
                            v4f32 _vhigh = (v4f32)__msa_ld_w(ptr1, 0);
                            _vlow = __msa_fmul_w(_vlow, _scale);
                            _vhigh = __msa_fmul_w(_vhigh, _scale);
                            *((int64_t*)outptr) = float2int8(_vlow, _vhigh);

                            ptr0 += 4;
                            ptr1 += 4;
                            outptr += 8;
                        }
                    }
                }
                else
                {
                    #pragma omp parallel for num_threads(opt.num_threads)
                    for (int i = 0; i < outh; i++)
                    {
                        const float* ptr0 = bottom_blob.row(i * 2);
                        const float* ptr1 = bottom_blob.row(i * 2 + 1);
                        signed char* outptr = top_blob.row<signed char>(i);

                        v4f32 _scale0 = (v4f32)__msa_ld_w((const float*)scale_data + i * 8, 0);
                        v4f32 _scale1 = (v4f32)__msa_ld_w((const float*)scale_data + i * 8 + 4, 0);

                        for (int j = 0; j < w; j++)
                        {
                            __builtin_prefetch(ptr0 + 16);
                            __builtin_prefetch(ptr1 + 16);
                            v4f32 _vlow = (v4f32)__msa_ld_w(ptr0, 0);
                            v4f32 _vhigh = (v4f32)__msa_ld_w(ptr1, 0);
                            _vlow = __msa_fmul_w(_vlow, _scale0);
                            _vhigh = __msa_fmul_w(_vhigh, _scale1);
                            *((int64_t*)outptr) = float2int8(_vlow, _vhigh);

                            ptr0 += 4;
                            ptr1 += 4;
                            outptr += 8;
                        }
                    }
                }
            }
            if (out_elempack == 1)
            {
                if (scale_data_size == 1)
                {
                    const float scale = scale_data[0];

                    #pragma omp parallel for num_threads(opt.num_threads)
                    for (int i = 0; i < h; i++)
                    {
                        const float* ptr0 = bottom_blob.row(i);
                        signed char* outptr0 = top_blob.row<signed char>(i * 4);
                        signed char* outptr1 = top_blob.row<signed char>(i * 4 + 1);
                        signed char* outptr2 = top_blob.row<signed char>(i * 4 + 2);
                        signed char* outptr3 = top_blob.row<signed char>(i * 4 + 3);

                        for (int j = 0; j < w; j++)
                        {
                            outptr0[0] = float2int8(ptr0[0] * scale);
                            outptr1[0] = float2int8(ptr0[1] * scale);
                            outptr2[0] = float2int8(ptr0[2] * scale);
                            outptr3[0] = float2int8(ptr0[3] * scale);

                            ptr0 += 4;
                            outptr0 += 1;
                            outptr1 += 1;
                            outptr2 += 1;
                            outptr3 += 1;
                        }
                    }
                }
                else
                {
                    #pragma omp parallel for num_threads(opt.num_threads)
                    for (int i = 0; i < h; i++)
                    {
                        const float* ptr0 = bottom_blob.row(i);
                        signed char* outptr0 = top_blob.row<signed char>(i * 4);
                        signed char* outptr1 = top_blob.row<signed char>(i * 4 + 1);
                        signed char* outptr2 = top_blob.row<signed char>(i * 4 + 2);
                        signed char* outptr3 = top_blob.row<signed char>(i * 4 + 3);

                        const float s0 = scale_data[i * 4];
                        const float s1 = scale_data[i * 4 + 1];
                        const float s2 = scale_data[i * 4 + 2];
                        const float s3 = scale_data[i * 4 + 3];

                        for (int j = 0; j < w; j++)
                        {
                            outptr0[0] = float2int8(ptr0[0] * s0);
                            outptr1[0] = float2int8(ptr0[1] * s1);
                            outptr2[0] = float2int8(ptr0[2] * s2);
                            outptr3[0] = float2int8(ptr0[3] * s3);

                            ptr0 += 4;
                            outptr0 += 1;
                            outptr1 += 1;
                            outptr2 += 1;
                            outptr3 += 1;
                        }
                    }
                }
            }
        }

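        // dims == 3: same repacking as above, applied per channel; the per-tensor-scale
        // pack-8 loop is additionally unrolled to quantize 16 floats per iteration.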
        if (dims == 3)
        {
            int w = bottom_blob.w;
            int h = bottom_blob.h;
            int channels = bottom_blob.c;
            int size = w * h;
            int out_elempack = opt.use_packing_layout && channels * elempack % 8 == 0 ? 8 : 1;
            int outc = channels * elempack / out_elempack;

            top_blob.create(w, h, outc, (size_t)out_elempack, out_elempack, opt.blob_allocator);
            if (top_blob.empty())
                return -100;

            if (out_elempack == 8)
            {
                if (scale_data_size == 1)
                {
                    v4f32 _scale = (v4f32)__msa_fill_w_f32(scale_data[0]);

                    #pragma omp parallel for num_threads(opt.num_threads)
                    for (int q = 0; q < outc; q++)
                    {
                        const float* ptr0 = bottom_blob.channel(q * 2);
                        const float* ptr1 = bottom_blob.channel(q * 2 + 1);
                        signed char* outptr = top_blob.channel(q);

                        int i = 0;
                        for (; i + 1 < size; i += 2)
                        {
                            __builtin_prefetch(ptr0 + 32);
                            __builtin_prefetch(ptr1 + 32);
                            v4f32 _v0 = (v4f32)__msa_ld_w(ptr0, 0);
                            v4f32 _v1 = (v4f32)__msa_ld_w(ptr0 + 4, 0);
                            v4f32 _v2 = (v4f32)__msa_ld_w(ptr1, 0);
                            v4f32 _v3 = (v4f32)__msa_ld_w(ptr1 + 4, 0);
                            _v0 = __msa_fmul_w(_v0, _scale);
                            _v1 = __msa_fmul_w(_v1, _scale);
                            _v2 = __msa_fmul_w(_v2, _scale);
                            _v3 = __msa_fmul_w(_v3, _scale);
                            *((int64_t*)outptr) = float2int8(_v0, _v2);
                            *((int64_t*)(outptr + 8)) = float2int8(_v1, _v3);

                            ptr0 += 8;
                            ptr1 += 8;
                            outptr += 16;
                        }
                        for (; i < size; i++)
                        {
                            __builtin_prefetch(ptr0 + 16);
                            __builtin_prefetch(ptr1 + 16);
                            v4f32 _vlow = (v4f32)__msa_ld_w(ptr0, 0);
                            v4f32 _vhigh = (v4f32)__msa_ld_w(ptr1, 0);
                            _vlow = __msa_fmul_w(_vlow, _scale);
                            _vhigh = __msa_fmul_w(_vhigh, _scale);
                            *((int64_t*)outptr) = float2int8(_vlow, _vhigh);

                            ptr0 += 4;
                            ptr1 += 4;
                            outptr += 8;
                        }
                    }
                }
                else
                {
                    #pragma omp parallel for num_threads(opt.num_threads)
                    for (int q = 0; q < outc; q++)
                    {
                        const float* ptr0 = bottom_blob.channel(q * 2);
                        const float* ptr1 = bottom_blob.channel(q * 2 + 1);
                        signed char* outptr = top_blob.channel(q);

                        v4f32 _scale0 = (v4f32)__msa_ld_w((const float*)scale_data + q * 8, 0);
                        v4f32 _scale1 = (v4f32)__msa_ld_w((const float*)scale_data + q * 8 + 4, 0);

                        int i = 0;
                        for (; i < size; i++)
                        {
                            __builtin_prefetch(ptr0 + 16);
                            __builtin_prefetch(ptr1 + 16);
                            v4f32 _vlow = (v4f32)__msa_ld_w(ptr0, 0);
                            v4f32 _vhigh = (v4f32)__msa_ld_w(ptr1, 0);
                            _vlow = __msa_fmul_w(_vlow, _scale0);
                            _vhigh = __msa_fmul_w(_vhigh, _scale1);
                            *((int64_t*)outptr) = float2int8(_vlow, _vhigh);

                            ptr0 += 4;
                            ptr1 += 4;
                            outptr += 8;
                        }
                    }
                }
            }
            if (out_elempack == 1)
            {
                if (scale_data_size == 1)
                {
                    const float scale = scale_data[0];

                    #pragma omp parallel for num_threads(opt.num_threads)
                    for (int q = 0; q < channels; q++)
                    {
                        const float* ptr0 = bottom_blob.channel(q);
                        signed char* outptr0 = top_blob.channel(q * 4);
                        signed char* outptr1 = top_blob.channel(q * 4 + 1);
                        signed char* outptr2 = top_blob.channel(q * 4 + 2);
                        signed char* outptr3 = top_blob.channel(q * 4 + 3);

                        for (int i = 0; i < size; i++)
                        {
                            outptr0[0] = float2int8(ptr0[0] * scale);
                            outptr1[0] = float2int8(ptr0[1] * scale);
                            outptr2[0] = float2int8(ptr0[2] * scale);
                            outptr3[0] = float2int8(ptr0[3] * scale);

                            ptr0 += 4;
                            outptr0 += 1;
                            outptr1 += 1;
                            outptr2 += 1;
                            outptr3 += 1;
                        }
                    }
                }
                else
                {
                    #pragma omp parallel for num_threads(opt.num_threads)
                    for (int q = 0; q < channels; q++)
                    {
                        const float* ptr0 = bottom_blob.channel(q);
                        signed char* outptr0 = top_blob.channel(q * 4);
                        signed char* outptr1 = top_blob.channel(q * 4 + 1);
                        signed char* outptr2 = top_blob.channel(q * 4 + 2);
                        signed char* outptr3 = top_blob.channel(q * 4 + 3);

                        const float s0 = scale_data[q * 4];
                        const float s1 = scale_data[q * 4 + 1];
                        const float s2 = scale_data[q * 4 + 2];
                        const float s3 = scale_data[q * 4 + 3];

                        for (int i = 0; i < size; i++)
                        {
                            outptr0[0] = float2int8(ptr0[0] * s0);
                            outptr1[0] = float2int8(ptr0[1] * s1);
                            outptr2[0] = float2int8(ptr0[2] * s2);
                            outptr3[0] = float2int8(ptr0[3] * s3);

                            ptr0 += 4;
                            outptr0 += 1;
                            outptr1 += 1;
                            outptr2 += 1;
                            outptr3 += 1;
                        }
                    }
                }
            }
        }

        return 0;
    }
#endif // __mips_msa

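    // Pack-1 path (and the only path on builds without MSA): element-wise quantization,
    // with an MSA-vectorized inner loop for the dims == 3 case below.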
    if (dims == 1)
    {
        int w = bottom_blob.w;

        top_blob.create(w, (size_t)1u, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        const float* ptr = bottom_blob;
        signed char* outptr = top_blob;

        if (scale_data_size == 1)
        {
            const float scale = scale_data[0];

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < w; i++)
            {
                outptr[i] = float2int8(ptr[i] * scale);
            }
        }
        else
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < w; i++)
            {
                outptr[i] = float2int8(ptr[i] * scale_data[i]);
            }
        }
    }

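    // dims == 2: one scale per row, or the shared per-tensor scale.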
    if (dims == 2)
    {
        int w = bottom_blob.w;
        int h = bottom_blob.h;

        top_blob.create(w, h, (size_t)1u, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = 0; i < h; i++)
        {
            const float* ptr0 = bottom_blob.row(i);
            signed char* outptr0 = top_blob.row<signed char>(i);

            const float scale = scale_data_size == 1 ? scale_data[0] : scale_data[i];

            for (int j = 0; j < w; j++)
            {
                *outptr0++ = float2int8(*ptr0++ * scale);
            }
        }
    }

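    // dims == 3: one scale per channel, or the shared per-tensor scale.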
    if (dims == 3)
    {
        int w = bottom_blob.w;
        int h = bottom_blob.h;
        int channels = bottom_blob.c;
        int size = w * h;

        top_blob.create(w, h, channels, (size_t)1u, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            const float* ptr = bottom_blob.channel(q);
            signed char* outptr = top_blob.channel(q);

            const float scale = scale_data_size == 1 ? scale_data[0] : scale_data[q];

            int i = 0;
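            // With MSA, quantize 16 and then 8 floats per iteration; the scalar loop
            // below handles the remaining elements (and everything on non-MSA builds).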
#if __mips_msa
            v4f32 _scale = (v4f32)__msa_fill_w_f32(scale);
            for (; i + 15 < size; i += 16)
            {
                __builtin_prefetch(ptr + 64);
                v4f32 _v0 = (v4f32)__msa_ld_w(ptr, 0);
                v4f32 _v1 = (v4f32)__msa_ld_w(ptr + 4, 0);
                v4f32 _v2 = (v4f32)__msa_ld_w(ptr + 8, 0);
                v4f32 _v3 = (v4f32)__msa_ld_w(ptr + 12, 0);
                _v0 = __msa_fmul_w(_v0, _scale);
                _v1 = __msa_fmul_w(_v1, _scale);
                _v2 = __msa_fmul_w(_v2, _scale);
                _v3 = __msa_fmul_w(_v3, _scale);
                *((int64_t*)outptr) = float2int8(_v0, _v1);
                *((int64_t*)(outptr + 8)) = float2int8(_v2, _v3);

                ptr += 16;
                outptr += 16;
            }
            for (; i + 7 < size; i += 8)
            {
                __builtin_prefetch(ptr + 32);
                v4f32 _v0 = (v4f32)__msa_ld_w(ptr, 0);
                v4f32 _v1 = (v4f32)__msa_ld_w(ptr + 4, 0);
                _v0 = __msa_fmul_w(_v0, _scale);
                _v1 = __msa_fmul_w(_v1, _scale);
                *((int64_t*)outptr) = float2int8(_v0, _v1);

                ptr += 8;
                outptr += 8;
            }
#endif // __mips_msa
            for (; i < size; i++)
            {
                *outptr++ = float2int8(*ptr++ * scale);
            }
        }
    }

    return 0;
}

} // namespace ncnn