// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "innerproduct_mips.h"

#include "layer_type.h"

#if __mips_msa
#include <msa.h>
#include "msa_mathfun.h"
#endif // __mips_msa

#include "mips_activation.h"

namespace ncnn {

InnerProduct_mips::InnerProduct_mips()
{
#if __mips_msa
    support_packing = true;
#endif // __mips_msa

    flatten = 0;
}

int InnerProduct_mips::create_pipeline(const Option& opt)
{
    {
        flatten = ncnn::create_layer_cpu(ncnn::LayerType::Flatten);

        ncnn::ParamDict pd;

        flatten->load_param(pd);

        flatten->create_pipeline(opt);
    }

#if NCNN_INT8
    if (opt.use_int8_inference && weight_data.elemsize == (size_t)1u)
    {
        return create_pipeline_int8_mips(opt);
    }
#endif

#if __mips_msa
    if (opt.use_fp16_storage)
    {
        return create_pipeline_fp16s(opt);
    }
#endif

    const int num_input = weight_data_size / num_output;

    int out_elempack = 1;

#if __mips_msa
    if (opt.use_packing_layout)
    {
        out_elempack = num_output % 4 == 0 ? 4 : 1;
    }
#endif // __mips_msa

    if (out_elempack == 4)
    {
        // src = inch-outch
        // dst = 4-inch-outch/4
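        // e.g. with num_input = 2 and num_output = 8 (hypothetical sizes,
        // for illustration only), row 0 of weight_data_tm becomes
        //   w(0,0) w(1,0) w(2,0) w(3,0)  w(0,1) w(1,1) w(2,1) w(3,1)
        // where w(q,p) is the weight of output channel q and input p:
        // for each input index, the weights of 4 consecutive output
        // channels are stored contiguously to match the pack-4 blob layout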
        {
            Mat weight_data_r2 = weight_data.reshape(num_input, num_output);

            weight_data_tm.create(num_input, num_output / 4, (size_t)4u * 4, 4);

            for (int q = 0; q + 3 < num_output; q += 4)
            {
                float* g0 = weight_data_tm.row(q / 4);

                for (int p = 0; p < num_input; p++)
                {
                    for (int j = 0; j < 4; j++)
                    {
                        *g0++ = weight_data_r2.row(q + j)[p];
                    }
                }
            }
        }
    }
    else
    {
        weight_data_tm = weight_data;
    }

    if (opt.lightmode)
        weight_data.release();

    return 0;
}

int InnerProduct_mips::destroy_pipeline(const Option& opt)
{
    if (flatten)
    {
        flatten->destroy_pipeline(opt);
        delete flatten;
        flatten = 0;
    }

    return 0;
}

int InnerProduct_mips::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
#if NCNN_INT8
    if (opt.use_int8_inference && int8_scale_term)
    {
        return forward_int8_mips(bottom_blob, top_blob, opt);
    }
#endif

#if __mips_msa
    if (opt.use_fp16_storage)
    {
        return forward_fp16s(bottom_blob, top_blob, opt);
    }
#endif

    const int num_input = weight_data_size / num_output;

    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
    {
        // gemm
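        // a 2D input whose width equals num_input is a batch of row
        // vectors; every one of the h rows is multiplied by the same
        // weight matrix, so this path is effectively a small gemm rather
        // than a single inner product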
        int h = bottom_blob.h;
        size_t elemsize = bottom_blob.elemsize;
        int elempack = bottom_blob.elempack;

        top_blob.create(num_output, h, elemsize, elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int num_output_elempack = 1;
#if __mips_msa
        if (opt.use_packing_layout)
        {
            num_output_elempack = num_output % 4 == 0 ? 4 : 1;
        }
#endif

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int j = 0; j < h; j++)
        {
#if __mips_msa
            if (elempack == 4 && num_output_elempack == 4)
            {
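                // packed 4x4 microkernel: m holds 4 interleaved input rows,
                // kptr holds the weights of 4 output channels per input
                // index, so each step does one packed load of m, one of
                // kptr, and four multiply-accumulates, producing 4 output
                // channels for 4 rows at once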
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output / num_output_elempack; p++)
                {
                    const float* kptr = weight_data_tm.row(p);
                    const float* m = bottom_blob.row(j);

                    v4f32 _sum0 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum1 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum2 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum3 = (v4f32)__msa_fill_w(0);

                    if (bias_term)
                    {
                        _sum0 = __msa_fill_w_f32(bias_data[p * 4 + 0]);
                        _sum1 = __msa_fill_w_f32(bias_data[p * 4 + 1]);
                        _sum2 = __msa_fill_w_f32(bias_data[p * 4 + 2]);
                        _sum3 = __msa_fill_w_f32(bias_data[p * 4 + 3]);
                    }

                    int i = 0;
                    for (; i < num_input; i++)
                    {
                        __builtin_prefetch(m + 16);
                        __builtin_prefetch(kptr + 16);
                        v4f32 _val = (v4f32)__msa_ld_w(m, 0);
                        v4i32 _w = __msa_ld_w(kptr, 0);
                        _sum0 = __msa_fmadd_w(_sum0, _val, (v4f32)__msa_splati_w(_w, 0));
                        _sum1 = __msa_fmadd_w(_sum1, _val, (v4f32)__msa_splati_w(_w, 1));
                        _sum2 = __msa_fmadd_w(_sum2, _val, (v4f32)__msa_splati_w(_w, 2));
                        _sum3 = __msa_fmadd_w(_sum3, _val, (v4f32)__msa_splati_w(_w, 3));

                        m += 4;
                        kptr += 4;
                    }

                    _sum0 = activation_ps(_sum0, activation_type, activation_params);
                    _sum1 = activation_ps(_sum1, activation_type, activation_params);
                    _sum2 = activation_ps(_sum2, activation_type, activation_params);
                    _sum3 = activation_ps(_sum3, activation_type, activation_params);

                    __msa_st_w((v4i32)_sum0, outptr, 0);
                    __msa_st_w((v4i32)_sum1, outptr + 4, 0);
                    __msa_st_w((v4i32)_sum2, outptr + 8, 0);
                    __msa_st_w((v4i32)_sum3, outptr + 12, 0);
                    outptr += 16;
                }
            }

            if (elempack == 1 && num_output_elempack == 4)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output / num_output_elempack; p++)
                {
                    const float* kptr = weight_data_tm.row(p);
                    const float* m = bottom_blob.row(j);

                    v4f32 _sum0 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum1 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum2 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum3 = (v4f32)__msa_fill_w(0);

                    if (bias_term)
                    {
                        _sum0 = (v4f32)__msa_ld_w((const float*)bias_data + p * 4, 0);
                    }

                    int i = 0;
                    for (; i + 3 < num_input; i += 4)
                    {
                        __builtin_prefetch(m + 16);
                        __builtin_prefetch(kptr + 64);
                        v4i32 _val = __msa_ld_w(m, 0);
                        v4f32 _w0 = (v4f32)__msa_ld_w(kptr, 0);
                        v4f32 _w1 = (v4f32)__msa_ld_w(kptr + 4, 0);
                        v4f32 _w2 = (v4f32)__msa_ld_w(kptr + 8, 0);
                        v4f32 _w3 = (v4f32)__msa_ld_w(kptr + 12, 0);
                        _sum0 = __msa_fmadd_w(_sum0, (v4f32)__msa_splati_w(_val, 0), _w0);
                        _sum1 = __msa_fmadd_w(_sum1, (v4f32)__msa_splati_w(_val, 1), _w1);
                        _sum2 = __msa_fmadd_w(_sum2, (v4f32)__msa_splati_w(_val, 2), _w2);
                        _sum3 = __msa_fmadd_w(_sum3, (v4f32)__msa_splati_w(_val, 3), _w3);

                        m += 4;
                        kptr += 16;
                    }
                    for (; i < num_input; i++)
                    {
                        v4f32 _val = __msa_fill_w_f32(m[0]);
                        v4f32 _w = (v4f32)__msa_ld_w(kptr, 0);
                        _sum0 = __msa_fmadd_w(_sum0, _val, _w);

                        m += 1;
                        kptr += 4;
                    }

                    _sum0 = __msa_fadd_w(_sum0, _sum1);
                    _sum2 = __msa_fadd_w(_sum2, _sum3);
                    _sum0 = __msa_fadd_w(_sum0, _sum2);

                    _sum0 = activation_ps(_sum0, activation_type, activation_params);

                    __msa_st_w((v4i32)_sum0, outptr, 0);
                    outptr += 4;
                }
            }

            if (elempack == 4 && num_output_elempack == 1)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output; p++)
                {
                    const float* kptr = (const float*)weight_data_tm + num_input * p;
                    const float* m = bottom_blob.row(j);

                    v4f32 _sum = (v4f32)__msa_fill_w(0);

                    if (bias_term)
                    {
                        _sum = __msa_fill_w_f32(bias_data[p]);
                    }

                    for (int i = 0; i < num_input; i++)
                    {
                        __builtin_prefetch(m + 16);
                        __builtin_prefetch(kptr + 4);
                        v4f32 _val = (v4f32)__msa_ld_w(m, 0);
                        v4f32 _k = __msa_fill_w_f32(kptr[0]);
                        _sum = __msa_fmadd_w(_sum, _val, _k);

                        m += 4;
                        kptr += 1;
                    }

                    _sum = activation_ps(_sum, activation_type, activation_params);

                    __msa_st_w((v4i32)_sum, outptr, 0);
                    outptr += 4;
                }
            }
#endif // __mips_msa

            if (elempack == 1 && num_output_elempack == 1)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output; p++)
                {
                    const float* kptr = (const float*)weight_data_tm + num_input * p;
                    const float* m = bottom_blob.row(j);

                    float sum = 0.f;

                    if (bias_term)
                    {
                        sum = bias_data[p];
                    }

                    int i = 0;
#if __mips_msa
                    v4f32 _sum = (v4f32)__msa_fill_w(0);
                    for (; i + 3 < num_input; i += 4)
                    {
                        __builtin_prefetch(m + 16);
                        __builtin_prefetch(kptr + 16);
                        v4f32 _m = (v4f32)__msa_ld_w(m, 0);
                        v4f32 _w = (v4f32)__msa_ld_w(kptr, 0);
                        _sum = __msa_fmadd_w(_sum, _m, _w);

                        m += 4;
                        kptr += 4;
                    }
                    sum += __msa_reduce_fadd_w(_sum);
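                    // __msa_reduce_fadd_w is ncnn's horizontal-add helper:
                    // it folds the four f32 lanes of _sum into one scalar
                    // partial sum; the scalar tail loop below finishes the
                    // leftover elements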
#endif // __mips_msa
                    for (; i < num_input; i++)
                    {
                        sum += *m * *kptr;

                        m += 1;
                        kptr += 1;
                    }

                    sum = activation_ss(sum, activation_type, activation_params);

                    outptr[0] = sum;
                    outptr += 1;
                }
            }
        }

        return 0;
    }

    // flatten
    Mat bottom_blob_flattened = bottom_blob;
    if (bottom_blob.dims != 1)
    {
        Option opt_flatten = opt;
        opt_flatten.blob_allocator = opt.workspace_allocator;

        flatten->forward(bottom_blob, bottom_blob_flattened, opt_flatten);
    }

    size_t elemsize = bottom_blob_flattened.elemsize;
    int elempack = bottom_blob_flattened.elempack;

    int out_elempack = 1;
#if __mips_msa
    if (opt.use_packing_layout)
    {
        out_elempack = num_output % 4 == 0 ? 4 : 1;
    }
#endif // __mips_msa
    size_t out_elemsize = elemsize / elempack * out_elempack;

    top_blob.create(num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

#if __mips_msa
    if (out_elempack == 4)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < num_output / out_elempack; p++)
        {
            v4f32 _sum0 = (v4f32)__msa_fill_w(0);
            v4f32 _sum1 = (v4f32)__msa_fill_w(0);
            v4f32 _sum2 = (v4f32)__msa_fill_w(0);
            v4f32 _sum3 = (v4f32)__msa_fill_w(0);

            if (bias_term)
            {
                _sum0 = (v4f32)__msa_ld_w((const float*)bias_data + p * 4, 0);
            }

            const float* kptr = weight_data_tm.row(p);

            const float* sptr = bottom_blob_flattened;

            int i = 0;
            for (; i + 3 < num_input; i += 4)
            {
                __builtin_prefetch(sptr + 16);
                __builtin_prefetch(kptr + 64);
                v4i32 _val = __msa_ld_w(sptr, 0);
                v4f32 _w0 = (v4f32)__msa_ld_w(kptr, 0);
                v4f32 _w1 = (v4f32)__msa_ld_w(kptr + 4, 0);
                v4f32 _w2 = (v4f32)__msa_ld_w(kptr + 8, 0);
                v4f32 _w3 = (v4f32)__msa_ld_w(kptr + 12, 0);
                _sum0 = __msa_fmadd_w(_sum0, (v4f32)__msa_splati_w(_val, 0), _w0);
                _sum1 = __msa_fmadd_w(_sum1, (v4f32)__msa_splati_w(_val, 1), _w1);
                _sum2 = __msa_fmadd_w(_sum2, (v4f32)__msa_splati_w(_val, 2), _w2);
                _sum3 = __msa_fmadd_w(_sum3, (v4f32)__msa_splati_w(_val, 3), _w3);

                sptr += 4;
                kptr += 16;
            }
            for (; i < num_input; i++)
            {
                v4f32 _val = __msa_fill_w_f32(sptr[0]);
                v4f32 _w = (v4f32)__msa_ld_w(kptr, 0);
                _sum0 = __msa_fmadd_w(_sum0, _val, _w);

                sptr += 1;
                kptr += 4;
            }

            _sum0 = __msa_fadd_w(_sum0, _sum1);
            _sum2 = __msa_fadd_w(_sum2, _sum3);
            _sum0 = __msa_fadd_w(_sum0, _sum2);

            _sum0 = activation_ps(_sum0, activation_type, activation_params);

            float* outptr = top_blob;
            __msa_st_w((v4i32)_sum0, outptr + p * 4, 0);
        }
    }
#endif // __mips_msa

    if (out_elempack == 1)
    {
        int nn_num_output = num_output / 4;
        int remain_num_output_start = nn_num_output * 4;
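        // unrolled by 4 output channels: each iteration accumulates four
        // dot products against the shared input vector; channels that do
        // not fill a group of 4 are handled by the scalar loop starting at
        // remain_num_output_start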

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_num_output; pp++)
        {
            int p = pp * 4;

            float sum0 = 0.f;
            float sum1 = 0.f;
            float sum2 = 0.f;
            float sum3 = 0.f;

            if (bias_term)
            {
                sum0 = bias_data[p];
                sum1 = bias_data[p + 1];
                sum2 = bias_data[p + 2];
                sum3 = bias_data[p + 3];
            }

            const float* w0 = (const float*)weight_data_tm + num_input * p;
            const float* w1 = (const float*)weight_data_tm + num_input * (p + 1);
            const float* w2 = (const float*)weight_data_tm + num_input * (p + 2);
            const float* w3 = (const float*)weight_data_tm + num_input * (p + 3);

            const float* m = bottom_blob_flattened;

            int i = 0;
#if __mips_msa
            v4f32 _sum0 = (v4f32)__msa_fill_w(0);
            v4f32 _sum1 = (v4f32)__msa_fill_w(0);
            v4f32 _sum2 = (v4f32)__msa_fill_w(0);
            v4f32 _sum3 = (v4f32)__msa_fill_w(0);
            for (; i + 3 < num_input; i += 4)
            {
                __builtin_prefetch(m + 16);
                __builtin_prefetch(w0 + 16);
                __builtin_prefetch(w1 + 16);
                __builtin_prefetch(w2 + 16);
                __builtin_prefetch(w3 + 16);
                v4f32 _m = (v4f32)__msa_ld_w(m, 0);
                v4f32 _w0 = (v4f32)__msa_ld_w(w0, 0);
                v4f32 _w1 = (v4f32)__msa_ld_w(w1, 0);
                v4f32 _w2 = (v4f32)__msa_ld_w(w2, 0);
                v4f32 _w3 = (v4f32)__msa_ld_w(w3, 0);
                _sum0 = __msa_fmadd_w(_sum0, _m, _w0);
                _sum1 = __msa_fmadd_w(_sum1, _m, _w1);
                _sum2 = __msa_fmadd_w(_sum2, _m, _w2);
                _sum3 = __msa_fmadd_w(_sum3, _m, _w3);

                m += 4;
                w0 += 4;
                w1 += 4;
                w2 += 4;
                w3 += 4;
            }
#endif // __mips_msa
            for (; i < num_input; i++)
            {
                sum0 += *m * *w0;
                sum1 += *m * *w1;
                sum2 += *m * *w2;
                sum3 += *m * *w3;

                m++;
                w0++;
                w1++;
                w2++;
                w3++;
            }

#if __mips_msa
            sum0 += __msa_reduce_fadd_w(_sum0);
            sum1 += __msa_reduce_fadd_w(_sum1);
            sum2 += __msa_reduce_fadd_w(_sum2);
            sum3 += __msa_reduce_fadd_w(_sum3);
#endif // __mips_msa

            sum0 = activation_ss(sum0, activation_type, activation_params);
            sum1 = activation_ss(sum1, activation_type, activation_params);
            sum2 = activation_ss(sum2, activation_type, activation_params);
            sum3 = activation_ss(sum3, activation_type, activation_params);

            top_blob[p] = sum0;
            top_blob[p + 1] = sum1;
            top_blob[p + 2] = sum2;
            top_blob[p + 3] = sum3;
        }

        // remaining output channels
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = remain_num_output_start; p < num_output; p++)
        {
            float sum = 0.f;

            if (bias_term)
                sum = bias_data[p];

            const float* w = (const float*)weight_data_tm + num_input * p;

            const float* m = bottom_blob_flattened;

            int i = 0;
#if __mips_msa
            v4f32 _sum0 = (v4f32)__msa_fill_w(0);
            for (; i + 3 < num_input; i += 4)
            {
                __builtin_prefetch(m + 16);
                __builtin_prefetch(w + 16);
                v4f32 _m = (v4f32)__msa_ld_w(m, 0);
                v4f32 _w = (v4f32)__msa_ld_w(w, 0);
                _sum0 = __msa_fmadd_w(_sum0, _m, _w);

                m += 4;
                w += 4;
            }
            sum += __msa_reduce_fadd_w(_sum0);
#endif // __mips_msa
            for (; i < num_input; i++)
            {
                sum += *m * *w;

                m++;
                w++;
            }

            sum = activation_ss(sum, activation_type, activation_params);

            top_blob[p] = sum;
        }
    }

    return 0;
}

#if __mips_msa
int InnerProduct_mips::create_pipeline_fp16s(const Option& opt)
{
    const int num_input = weight_data_size / num_output;

    int out_elempack = 1;
    if (opt.use_packing_layout)
    {
        out_elempack = num_output % 4 == 0 ? 4 : 1;
    }

    // src = inch-outch
    // dst = pb-inch-outch/pb
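    // same repack as the fp32 pack-4 path, but each weight is narrowed to
    // fp16 on the way in, halving the transformed weight footprint
    // (pb is the output packing factor, 4 here)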
    if (out_elempack == 4)
    {
        Mat weight_data_r2 = weight_data.reshape(num_input, num_output);

        weight_data_tm.create(num_input, num_output / 4, (size_t)8u, 4);

        for (int q = 0; q + 3 < num_output; q += 4)
        {
            unsigned short* g0 = weight_data_tm.row<unsigned short>(q / 4);

            const float* k0 = weight_data_r2.row(q);
            const float* k1 = weight_data_r2.row(q + 1);
            const float* k2 = weight_data_r2.row(q + 2);
            const float* k3 = weight_data_r2.row(q + 3);

            int p = 0;
            for (; p + 3 < num_input; p += 4)
            {
                // transpose 4x4
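                // MSA has no transpose instruction, so the 4x4 block is
                // transposed with interleaves: ilvr_w/ilvl_w interleave the
                // low/high word pairs of two rows, then ilvr_d/ilvl_d merge
                // those into the transposed rows; fexdo_h finally narrows
                // two f32 vectors into one vector of eight fp16 values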
                v4f32 _r0 = (v4f32)__msa_ld_w(k0, 0);
                v4f32 _r1 = (v4f32)__msa_ld_w(k1, 0);
                v4f32 _r2 = (v4f32)__msa_ld_w(k2, 0);
                v4f32 _r3 = (v4f32)__msa_ld_w(k3, 0);

                v4i32 _r01r = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0);
                v4i32 _r01l = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0);
                v4i32 _r23r = __msa_ilvr_w((v4i32)_r3, (v4i32)_r2);
                v4i32 _r23l = __msa_ilvl_w((v4i32)_r3, (v4i32)_r2);
                v2i64 _r0123_0 = __msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r);
                v2i64 _r0123_1 = __msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r);
                v2i64 _r0123_2 = __msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l);
                v2i64 _r0123_3 = __msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l);

                v8i16 _p0 = __msa_fexdo_h((v4f32)_r0123_1, (v4f32)_r0123_0);
                v8i16 _p1 = __msa_fexdo_h((v4f32)_r0123_3, (v4f32)_r0123_2);

                __msa_st_h(_p0, g0, 0);
                __msa_st_h(_p1, g0 + 8, 0);

                k0 += 4;
                k1 += 4;
                k2 += 4;
                k3 += 4;
                g0 += 16;
            }
            for (; p < num_input; p++)
            {
                g0[0] = float32_to_float16(*k0++);
                g0[1] = float32_to_float16(*k1++);
                g0[2] = float32_to_float16(*k2++);
                g0[3] = float32_to_float16(*k3++);
                g0 += 4;
            }
        }
    }

    if (out_elempack == 1)
    {
        Mat weight_data_r2 = weight_data.reshape(num_input, num_output);
        ncnn::cast_float32_to_float16(weight_data_r2, weight_data_tm, opt);
    }

    if (opt.lightmode)
        weight_data.release();

    return 0;
}

int InnerProduct_mips::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    const int num_input = weight_data_size / num_output;

    if (bottom_blob.dims == 2 && bottom_blob.w == num_input)
    {
        // gemm
        int h = bottom_blob.h;
        size_t elemsize = bottom_blob.elemsize;
        int elempack = bottom_blob.elempack;

        top_blob.create(num_output, h, elemsize, elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int num_output_elempack = 1;
        if (opt.use_packing_layout)
        {
            num_output_elempack = num_output % 4 == 0 ? 4 : 1;
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int j = 0; j < h; j++)
        {
            if (elempack == 4 && num_output_elempack == 4)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output / num_output_elempack; p++)
                {
                    const unsigned short* kptr = weight_data_tm.row<const unsigned short>(p);
                    const float* m = bottom_blob.row(j);

                    v4f32 _sum0 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum1 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum2 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum3 = (v4f32)__msa_fill_w(0);

                    if (bias_term)
                    {
                        _sum0 = __msa_fill_w_f32(bias_data[p * 4 + 0]);
                        _sum1 = __msa_fill_w_f32(bias_data[p * 4 + 1]);
                        _sum2 = __msa_fill_w_f32(bias_data[p * 4 + 2]);
                        _sum3 = __msa_fill_w_f32(bias_data[p * 4 + 3]);
                    }

                    int i = 0;
                    for (; i < num_input; i++)
                    {
                        __builtin_prefetch(m + 16);
                        __builtin_prefetch(kptr + 16);
                        v4f32 _val = (v4f32)__msa_ld_w(m, 0);
                        v4i32 _w = (v4i32)__msa_fexupr_w(__msa_ld_h(kptr, 0));
                        _sum0 = __msa_fmadd_w(_sum0, _val, (v4f32)__msa_splati_w(_w, 0));
                        _sum1 = __msa_fmadd_w(_sum1, _val, (v4f32)__msa_splati_w(_w, 1));
                        _sum2 = __msa_fmadd_w(_sum2, _val, (v4f32)__msa_splati_w(_w, 2));
                        _sum3 = __msa_fmadd_w(_sum3, _val, (v4f32)__msa_splati_w(_w, 3));

                        m += 4;
                        kptr += 4;
                    }

                    _sum0 = activation_ps(_sum0, activation_type, activation_params);
                    _sum1 = activation_ps(_sum1, activation_type, activation_params);
                    _sum2 = activation_ps(_sum2, activation_type, activation_params);
                    _sum3 = activation_ps(_sum3, activation_type, activation_params);

                    __msa_st_w((v4i32)_sum0, outptr, 0);
                    __msa_st_w((v4i32)_sum1, outptr + 4, 0);
                    __msa_st_w((v4i32)_sum2, outptr + 8, 0);
                    __msa_st_w((v4i32)_sum3, outptr + 12, 0);
                    outptr += 16;
                }
            }

            if (elempack == 1 && num_output_elempack == 4)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output / num_output_elempack; p++)
                {
                    const unsigned short* kptr = weight_data_tm.row<const unsigned short>(p);
                    const float* m = bottom_blob.row(j);

                    v4f32 _sum0 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum1 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum2 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum3 = (v4f32)__msa_fill_w(0);

                    if (bias_term)
                    {
                        _sum0 = (v4f32)__msa_ld_w((const float*)bias_data + p * 4, 0);
                    }

                    int i = 0;
                    for (; i + 3 < num_input; i += 4)
                    {
                        __builtin_prefetch(m + 16);
                        __builtin_prefetch(kptr + 64);
                        v4i32 _val = __msa_ld_w(m, 0);
                        v8i16 _w01 = __msa_ld_h(kptr, 0);
                        v8i16 _w23 = __msa_ld_h(kptr + 8, 0);
                        v4f32 _w0 = __msa_fexupr_w(_w01);
                        v4f32 _w1 = __msa_fexupl_w(_w01);
                        v4f32 _w2 = __msa_fexupr_w(_w23);
                        v4f32 _w3 = __msa_fexupl_w(_w23);
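                        // fexupr_w/fexupl_w above widen the right (low) and
                        // left (high) halves of eight packed fp16 values
                        // back to f32, yielding this step's 16 fp16 weights
                        // as four f32 vectors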
                        _sum0 = __msa_fmadd_w(_sum0, (v4f32)__msa_splati_w(_val, 0), _w0);
                        _sum1 = __msa_fmadd_w(_sum1, (v4f32)__msa_splati_w(_val, 1), _w1);
                        _sum2 = __msa_fmadd_w(_sum2, (v4f32)__msa_splati_w(_val, 2), _w2);
                        _sum3 = __msa_fmadd_w(_sum3, (v4f32)__msa_splati_w(_val, 3), _w3);

                        m += 4;
                        kptr += 16;
                    }
                    for (; i < num_input; i++)
                    {
                        v4f32 _val = __msa_fill_w_f32(m[0]);
                        v4f32 _w = __msa_fexupr_w(__msa_ld_h(kptr, 0));
                        _sum0 = __msa_fmadd_w(_sum0, _val, _w);

                        m += 1;
                        kptr += 4;
                    }

                    _sum0 = __msa_fadd_w(_sum0, _sum1);
                    _sum2 = __msa_fadd_w(_sum2, _sum3);
                    _sum0 = __msa_fadd_w(_sum0, _sum2);

                    _sum0 = activation_ps(_sum0, activation_type, activation_params);

                    __msa_st_w((v4i32)_sum0, outptr, 0);
                    outptr += 4;
                }
            }

            if (elempack == 4 && num_output_elempack == 1)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output; p++)
                {
                    const unsigned short* kptr = weight_data_tm.row<const unsigned short>(p);
                    const float* m = bottom_blob.row(j);

                    v4f32 _sum = (v4f32)__msa_fill_w(0);

                    if (bias_term)
                    {
                        _sum = __msa_fill_w_f32(bias_data[p]);
                    }

                    for (int i = 0; i < num_input; i++)
                    {
                        __builtin_prefetch(m + 16);
                        __builtin_prefetch(kptr + 4);
                        v4f32 _val = (v4f32)__msa_ld_w(m, 0);
                        v4f32 _k = __msa_fill_w_f32(float16_to_float32(kptr[0]));
                        _sum = __msa_fmadd_w(_sum, _val, _k);

                        m += 4;
                        kptr += 1;
                    }

                    _sum = activation_ps(_sum, activation_type, activation_params);

                    __msa_st_w((v4i32)_sum, outptr, 0);
                    outptr += 4;
                }
            }

            if (elempack == 1 && num_output_elempack == 1)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output; p++)
                {
                    const unsigned short* kptr = weight_data_tm.row<const unsigned short>(p);
                    const float* m = bottom_blob.row(j);

                    float sum = 0.f;

                    if (bias_term)
                    {
                        sum = bias_data[p];
                    }

                    int i = 0;
                    v4f32 _sum = (v4f32)__msa_fill_w(0);
                    for (; i + 3 < num_input; i += 4)
                    {
                        __builtin_prefetch(m + 16);
                        __builtin_prefetch(kptr + 16);
                        v4f32 _m = (v4f32)__msa_ld_w(m, 0);
                        v4f32 _w = __msa_fexupr_w(__msa_ld_h(kptr, 0));
                        _sum = __msa_fmadd_w(_sum, _m, _w);

                        m += 4;
                        kptr += 4;
                    }
                    sum += __msa_reduce_fadd_w(_sum);
                    for (; i < num_input; i++)
                    {
                        sum += *m * float16_to_float32(*kptr);

                        m += 1;
                        kptr += 1;
                    }

                    sum = activation_ss(sum, activation_type, activation_params);

                    outptr[0] = sum;
                    outptr += 1;
                }
            }
        }

        return 0;
    }

    // flatten
    Mat bottom_blob_flattened = bottom_blob;
    if (bottom_blob.dims != 1)
    {
        Option opt_flatten = opt;
        opt_flatten.blob_allocator = opt.workspace_allocator;

        flatten->forward(bottom_blob, bottom_blob_flattened, opt_flatten);
    }

    size_t elemsize = bottom_blob_flattened.elemsize;
    int elempack = bottom_blob_flattened.elempack;

    int out_elempack = 1;
    if (opt.use_packing_layout)
    {
        out_elempack = num_output % 4 == 0 ? 4 : 1;
    }
    size_t out_elemsize = elemsize / elempack * out_elempack;

    top_blob.create(num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    if (out_elempack == 4)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < num_output / out_elempack; p++)
        {
            v4f32 _sum0 = (v4f32)__msa_fill_w(0);
            v4f32 _sum1 = (v4f32)__msa_fill_w(0);
            v4f32 _sum2 = (v4f32)__msa_fill_w(0);
            v4f32 _sum3 = (v4f32)__msa_fill_w(0);

            if (bias_term)
            {
                _sum0 = (v4f32)__msa_ld_w((const float*)bias_data + p * 4, 0);
            }

            const unsigned short* kptr = weight_data_tm.row<const unsigned short>(p);

            const float* sptr = bottom_blob_flattened;

            int i = 0;
            for (; i + 3 < num_input; i += 4)
            {
                __builtin_prefetch(sptr + 16);
                __builtin_prefetch(kptr + 64);
                v4i32 _val = __msa_ld_w(sptr, 0);
                v8i16 _w01 = __msa_ld_h(kptr, 0);
                v8i16 _w23 = __msa_ld_h(kptr + 8, 0);
                v4f32 _w0 = __msa_fexupr_w(_w01);
                v4f32 _w1 = __msa_fexupl_w(_w01);
                v4f32 _w2 = __msa_fexupr_w(_w23);
                v4f32 _w3 = __msa_fexupl_w(_w23);
                _sum0 = __msa_fmadd_w(_sum0, (v4f32)__msa_splati_w(_val, 0), _w0);
                _sum1 = __msa_fmadd_w(_sum1, (v4f32)__msa_splati_w(_val, 1), _w1);
                _sum2 = __msa_fmadd_w(_sum2, (v4f32)__msa_splati_w(_val, 2), _w2);
                _sum3 = __msa_fmadd_w(_sum3, (v4f32)__msa_splati_w(_val, 3), _w3);

                sptr += 4;
                kptr += 16;
            }
            for (; i < num_input; i++)
            {
                v4f32 _val = __msa_fill_w_f32(sptr[0]);
                v4f32 _w = __msa_fexupr_w(__msa_ld_h(kptr, 0));
                _sum0 = __msa_fmadd_w(_sum0, _val, _w);

                sptr += 1;
                kptr += 4;
            }

            _sum0 = __msa_fadd_w(_sum0, _sum1);
            _sum2 = __msa_fadd_w(_sum2, _sum3);
            _sum0 = __msa_fadd_w(_sum0, _sum2);

            _sum0 = activation_ps(_sum0, activation_type, activation_params);

            float* outptr = top_blob;
            __msa_st_w((v4i32)_sum0, outptr + p * 4, 0);
        }
    }

    if (out_elempack == 1)
    {
        int nn_num_output = num_output / 4;
        int remain_num_output_start = nn_num_output * 4;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_num_output; pp++)
        {
            int p = pp * 4;

            float sum0 = 0.f;
            float sum1 = 0.f;
            float sum2 = 0.f;
            float sum3 = 0.f;

            if (bias_term)
            {
                sum0 = bias_data[p];
                sum1 = bias_data[p + 1];
                sum2 = bias_data[p + 2];
                sum3 = bias_data[p + 3];
            }

            const unsigned short* w0 = weight_data_tm.row<const unsigned short>(p);
            const unsigned short* w1 = weight_data_tm.row<const unsigned short>(p + 1);
            const unsigned short* w2 = weight_data_tm.row<const unsigned short>(p + 2);
            const unsigned short* w3 = weight_data_tm.row<const unsigned short>(p + 3);

            const float* m = bottom_blob_flattened;

            int i = 0;
            v4f32 _sum0 = (v4f32)__msa_fill_w(0);
            v4f32 _sum1 = (v4f32)__msa_fill_w(0);
            v4f32 _sum2 = (v4f32)__msa_fill_w(0);
            v4f32 _sum3 = (v4f32)__msa_fill_w(0);
            for (; i + 3 < num_input; i += 4)
            {
                __builtin_prefetch(m + 16);
                __builtin_prefetch(w0 + 16);
                __builtin_prefetch(w1 + 16);
                __builtin_prefetch(w2 + 16);
                __builtin_prefetch(w3 + 16);
                v4f32 _m = (v4f32)__msa_ld_w(m, 0);
                v4f32 _w0 = __msa_fexupr_w(__msa_ld_h(w0, 0));
                v4f32 _w1 = __msa_fexupr_w(__msa_ld_h(w1, 0));
                v4f32 _w2 = __msa_fexupr_w(__msa_ld_h(w2, 0));
                v4f32 _w3 = __msa_fexupr_w(__msa_ld_h(w3, 0));
                _sum0 = __msa_fmadd_w(_sum0, _m, _w0);
                _sum1 = __msa_fmadd_w(_sum1, _m, _w1);
                _sum2 = __msa_fmadd_w(_sum2, _m, _w2);
                _sum3 = __msa_fmadd_w(_sum3, _m, _w3);

                m += 4;
                w0 += 4;
                w1 += 4;
                w2 += 4;
                w3 += 4;
            }
            for (; i < num_input; i++)
            {
                sum0 += *m * float16_to_float32(*w0);
                sum1 += *m * float16_to_float32(*w1);
                sum2 += *m * float16_to_float32(*w2);
                sum3 += *m * float16_to_float32(*w3);

                m++;
                w0++;
                w1++;
                w2++;
                w3++;
            }

            sum0 += __msa_reduce_fadd_w(_sum0);
            sum1 += __msa_reduce_fadd_w(_sum1);
            sum2 += __msa_reduce_fadd_w(_sum2);
            sum3 += __msa_reduce_fadd_w(_sum3);

            sum0 = activation_ss(sum0, activation_type, activation_params);
            sum1 = activation_ss(sum1, activation_type, activation_params);
            sum2 = activation_ss(sum2, activation_type, activation_params);
            sum3 = activation_ss(sum3, activation_type, activation_params);

            top_blob[p] = sum0;
            top_blob[p + 1] = sum1;
            top_blob[p + 2] = sum2;
            top_blob[p + 3] = sum3;
        }

        // remaining output channels
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = remain_num_output_start; p < num_output; p++)
        {
            float sum = 0.f;

            if (bias_term)
                sum = bias_data[p];

            const unsigned short* w = weight_data_tm.row<const unsigned short>(p);

            const float* m = bottom_blob_flattened;

            int i = 0;
            v4f32 _sum0 = (v4f32)__msa_fill_w(0);
            for (; i + 3 < num_input; i += 4)
            {
                __builtin_prefetch(m + 16);
                __builtin_prefetch(w + 16);
                v4f32 _m = (v4f32)__msa_ld_w(m, 0);
                v4f32 _w = __msa_fexupr_w(__msa_ld_h(w, 0));
                _sum0 = __msa_fmadd_w(_sum0, _m, _w);

                m += 4;
                w += 4;
            }
            sum += __msa_reduce_fadd_w(_sum0);
            for (; i < num_input; i++)
            {
                sum += *m * float16_to_float32(*w);

                m++;
                w++;
            }

            sum = activation_ss(sum, activation_type, activation_params);

            top_blob[p] = sum;
        }
    }

    return 0;
}
#endif // __mips_msa

#if NCNN_INT8
int InnerProduct_mips::create_pipeline_int8_mips(const Option& opt)
{
    const int num_input = weight_data_size / num_output;

    int out_elempack = 1;
#if __mips_msa
    if (opt.use_packing_layout)
    {
        out_elempack = num_output % 8 == 0 ? 8 : 1;
    }
#endif // __mips_msa

    // src = inch-outch
    // dst = pb-inch-outch/pb
    {
        Mat weight_data_r2 = weight_data.reshape(num_input, num_output);

        weight_data_tm.create(num_input, num_output / out_elempack, (size_t)out_elempack, out_elempack);

        for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack)
        {
            signed char* g0 = weight_data_tm.row<signed char>(q / out_elempack);

            for (int p = 0; p < num_input; p++)
            {
                for (int j = 0; j < out_elempack; j++)
                {
                    *g0++ = weight_data_r2.row<signed char>(q + j)[p];
                }
            }
        }
    }

    scale_in_data.create(num_output);
    for (int p = 0; p < num_output; p++)
    {
        // dequantize
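        // the int32 accumulator maps back to fp32 as
        //   sum_fp32 = sum_int32 / (bottom_scale * weight_scale)
        // so the per-channel multiplier 1/(bottom_scale * weight_scale) is
        // precomputed here; a zero weight scale marks an all-zero channel
        // and yields a zero multiplier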
        float scale_in;
        if (weight_data_int8_scales[p] == 0)
            scale_in = 0;
        else
            scale_in = 1.f / (bottom_blob_int8_scales[0] * weight_data_int8_scales[p]);

        scale_in_data[p] = scale_in;
    }

    if (opt.lightmode)
        weight_data.release();

    return 0;
}

int InnerProduct_mips::forward_int8_mips(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    const int num_input = weight_data_size / num_output;

    int elembits = bottom_blob.elembits();

    Mat bottom_blob_int8 = bottom_blob;
    if (elembits != 8)
    {
        Option opt_q = opt;
        opt_q.blob_allocator = opt.workspace_allocator;
        quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_q);
    }

    if (bottom_blob_int8.dims == 2 && bottom_blob_int8.w == num_input)
    {
        // gemm
        Mat bottom_blob_int8_unpacked;
        Option opt_unpack = opt;
        opt_unpack.blob_allocator = opt.workspace_allocator;
        convert_packing(bottom_blob_int8, bottom_blob_int8_unpacked, 1, opt_unpack);

        int h = bottom_blob_int8_unpacked.h;

        int out_elempack = 1;
#if __mips_msa
        if (opt.use_packing_layout)
        {
            out_elempack = h % 4 == 0 ? 4 : 1;
        }
#endif

        int outh = h / out_elempack;

        top_blob.create(num_output, outh, (size_t)(4u * out_elempack), out_elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int num_output_elempack = 1;
#if __mips_msa
        if (opt.use_packing_layout)
        {
            num_output_elempack = num_output % 8 == 0 ? 8 : 1;
        }
#endif

#if __mips_msa
        if (num_output_elempack == 8 && out_elempack == 4)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int j = 0; j < outh; j++)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output / num_output_elempack; p++)
                {
                    const signed char* kptr = weight_data_tm.row<const signed char>(p);
                    const signed char* m0 = bottom_blob_int8_unpacked.row<const signed char>(j * 4);
                    const signed char* m1 = bottom_blob_int8_unpacked.row<const signed char>(j * 4 + 1);
                    const signed char* m2 = bottom_blob_int8_unpacked.row<const signed char>(j * 4 + 2);
                    const signed char* m3 = bottom_blob_int8_unpacked.row<const signed char>(j * 4 + 3);

                    v4i32 _sum00 = __msa_fill_w(0);
                    v4i32 _sum01 = __msa_fill_w(0);
                    v4i32 _sum10 = __msa_fill_w(0);
                    v4i32 _sum11 = __msa_fill_w(0);
                    v4i32 _sum20 = __msa_fill_w(0);
                    v4i32 _sum21 = __msa_fill_w(0);
                    v4i32 _sum30 = __msa_fill_w(0);
                    v4i32 _sum31 = __msa_fill_w(0);

                    int i = 0;
                    for (; i < num_input; i++)
                    {
                        __builtin_prefetch(m0 + 4);
                        __builtin_prefetch(m1 + 4);
                        __builtin_prefetch(m2 + 4);
                        __builtin_prefetch(m3 + 4);
                        __builtin_prefetch(kptr + 32);
                        v8i16 _val0 = __msa_fill_h((short)m0[0]);
                        v8i16 _val1 = __msa_fill_h((short)m1[0]);
                        v8i16 _val2 = __msa_fill_h((short)m2[0]);
                        v8i16 _val3 = __msa_fill_h((short)m3[0]);

                        v16i8 _w = __msa_ld_b(kptr, 0);
                        v8i16 _w16 = (v8i16)__msa_ilvr_b(__msa_clti_s_b(_w, 0), _w);
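                        // sign-extend the low 8 int8 weights to int16:
                        // clti_s_b(_w, 0) is all-ones for negative lanes and
                        // zero otherwise, so interleaving that mask with _w
                        // supplies the correct high byte of each 16-bit lane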

                        v8i16 _s0 = __msa_mulv_h(_val0, _w16);
                        v8i16 _s1 = __msa_mulv_h(_val1, _w16);
                        v8i16 _s2 = __msa_mulv_h(_val2, _w16);
                        v8i16 _s3 = __msa_mulv_h(_val3, _w16);
                        v8i16 _exts0 = __msa_clti_s_h(_s0, 0);
                        v8i16 _exts1 = __msa_clti_s_h(_s1, 0);
                        v8i16 _exts2 = __msa_clti_s_h(_s2, 0);
                        v8i16 _exts3 = __msa_clti_s_h(_s3, 0);
                        v4i32 _s0l = (v4i32)__msa_ilvr_h(_exts0, _s0);
                        v4i32 _s0h = (v4i32)__msa_ilvl_h(_exts0, _s0);
                        v4i32 _s1l = (v4i32)__msa_ilvr_h(_exts1, _s1);
                        v4i32 _s1h = (v4i32)__msa_ilvl_h(_exts1, _s1);
                        v4i32 _s2l = (v4i32)__msa_ilvr_h(_exts2, _s2);
                        v4i32 _s2h = (v4i32)__msa_ilvl_h(_exts2, _s2);
                        v4i32 _s3l = (v4i32)__msa_ilvr_h(_exts3, _s3);
                        v4i32 _s3h = (v4i32)__msa_ilvl_h(_exts3, _s3);

                        _sum00 = __msa_addv_w(_sum00, _s0l);
                        _sum01 = __msa_addv_w(_sum01, _s0h);
                        _sum10 = __msa_addv_w(_sum10, _s1l);
                        _sum11 = __msa_addv_w(_sum11, _s1h);
                        _sum20 = __msa_addv_w(_sum20, _s2l);
                        _sum21 = __msa_addv_w(_sum21, _s2h);
                        _sum30 = __msa_addv_w(_sum30, _s3l);
                        _sum31 = __msa_addv_w(_sum31, _s3h);

                        m0++;
                        m1++;
                        m2++;
                        m3++;
                        kptr += 8;
                    }

                    // dequantize and apply activation
                    v4f32 _scale_in0 = (v4f32)__msa_ld_w((const float*)scale_in_data + p * 8, 0);
                    v4f32 _scale_in1 = (v4f32)__msa_ld_w((const float*)scale_in_data + p * 8 + 4, 0);

                    v4f32 _sumfp32_00 = (v4f32)__msa_ffint_s_w(_sum00);
                    v4f32 _sumfp32_01 = (v4f32)__msa_ffint_s_w(_sum01);
                    v4f32 _sumfp32_10 = (v4f32)__msa_ffint_s_w(_sum10);
                    v4f32 _sumfp32_11 = (v4f32)__msa_ffint_s_w(_sum11);
                    v4f32 _sumfp32_20 = (v4f32)__msa_ffint_s_w(_sum20);
                    v4f32 _sumfp32_21 = (v4f32)__msa_ffint_s_w(_sum21);
                    v4f32 _sumfp32_30 = (v4f32)__msa_ffint_s_w(_sum30);
                    v4f32 _sumfp32_31 = (v4f32)__msa_ffint_s_w(_sum31);
                    if (bias_term)
                    {
                        v4f32 _bias0 = (v4f32)__msa_ld_w((const float*)bias_data + p * 8, 0);
                        v4f32 _bias1 = (v4f32)__msa_ld_w((const float*)bias_data + p * 8 + 4, 0);
                        _sumfp32_00 = __msa_fmadd_w(_bias0, _sumfp32_00, _scale_in0);
                        _sumfp32_01 = __msa_fmadd_w(_bias1, _sumfp32_01, _scale_in1);
                        _sumfp32_10 = __msa_fmadd_w(_bias0, _sumfp32_10, _scale_in0);
                        _sumfp32_11 = __msa_fmadd_w(_bias1, _sumfp32_11, _scale_in1);
                        _sumfp32_20 = __msa_fmadd_w(_bias0, _sumfp32_20, _scale_in0);
                        _sumfp32_21 = __msa_fmadd_w(_bias1, _sumfp32_21, _scale_in1);
                        _sumfp32_30 = __msa_fmadd_w(_bias0, _sumfp32_30, _scale_in0);
                        _sumfp32_31 = __msa_fmadd_w(_bias1, _sumfp32_31, _scale_in1);
                    }
                    else
                    {
                        _sumfp32_00 = __msa_fmul_w(_sumfp32_00, _scale_in0);
                        _sumfp32_01 = __msa_fmul_w(_sumfp32_01, _scale_in1);
                        _sumfp32_10 = __msa_fmul_w(_sumfp32_10, _scale_in0);
                        _sumfp32_11 = __msa_fmul_w(_sumfp32_11, _scale_in1);
                        _sumfp32_20 = __msa_fmul_w(_sumfp32_20, _scale_in0);
                        _sumfp32_21 = __msa_fmul_w(_sumfp32_21, _scale_in1);
                        _sumfp32_30 = __msa_fmul_w(_sumfp32_30, _scale_in0);
                        _sumfp32_31 = __msa_fmul_w(_sumfp32_31, _scale_in1);
                    }

                    _sumfp32_00 = activation_ps(_sumfp32_00, activation_type, activation_params);
                    _sumfp32_01 = activation_ps(_sumfp32_01, activation_type, activation_params);
                    _sumfp32_10 = activation_ps(_sumfp32_10, activation_type, activation_params);
                    _sumfp32_11 = activation_ps(_sumfp32_11, activation_type, activation_params);
                    _sumfp32_20 = activation_ps(_sumfp32_20, activation_type, activation_params);
                    _sumfp32_21 = activation_ps(_sumfp32_21, activation_type, activation_params);
                    _sumfp32_30 = activation_ps(_sumfp32_30, activation_type, activation_params);
                    _sumfp32_31 = activation_ps(_sumfp32_31, activation_type, activation_params);

                    // transpose 4x8
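                    // each accumulator holds one input row x 4 output
                    // channels; interleaving word and doubleword pairs
                    // transposes the 4x8 block so every vector holds one
                    // output channel across the 4 packed rows, matching the
                    // pack-4 layout of top_blob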
1318
                    v4i32 _r01r = __msa_ilvr_w((v4i32)_sumfp32_10, (v4i32)_sumfp32_00);
1319
                    v4i32 _r01l = __msa_ilvl_w((v4i32)_sumfp32_10, (v4i32)_sumfp32_00);
1320
                    v4i32 _r23r = __msa_ilvr_w((v4i32)_sumfp32_30, (v4i32)_sumfp32_20);
1321
                    v4i32 _r23l = __msa_ilvl_w((v4i32)_sumfp32_30, (v4i32)_sumfp32_20);
1322
                    v4i32 _r45r = __msa_ilvr_w((v4i32)_sumfp32_11, (v4i32)_sumfp32_01);
1323
                    v4i32 _r45l = __msa_ilvl_w((v4i32)_sumfp32_11, (v4i32)_sumfp32_01);
1324
                    v4i32 _r67r = __msa_ilvr_w((v4i32)_sumfp32_31, (v4i32)_sumfp32_21);
1325
                    v4i32 _r67l = __msa_ilvl_w((v4i32)_sumfp32_31, (v4i32)_sumfp32_21);
1326
                    _sumfp32_00 = (v4f32)__msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r);
1327
                    _sumfp32_10 = (v4f32)__msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r);
1328
                    _sumfp32_20 = (v4f32)__msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l);
1329
                    _sumfp32_30 = (v4f32)__msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l);
1330
                    _sumfp32_01 = (v4f32)__msa_ilvr_d((v2i64)_r67r, (v2i64)_r45r);
1331
                    _sumfp32_11 = (v4f32)__msa_ilvl_d((v2i64)_r67r, (v2i64)_r45r);
1332
                    _sumfp32_21 = (v4f32)__msa_ilvr_d((v2i64)_r67l, (v2i64)_r45l);
1333
                    _sumfp32_31 = (v4f32)__msa_ilvl_d((v2i64)_r67l, (v2i64)_r45l);
1334

1335
                    __msa_st_w((v4i32)_sumfp32_00, outptr, 0);
1336
                    __msa_st_w((v4i32)_sumfp32_10, outptr + 4, 0);
1337
                    __msa_st_w((v4i32)_sumfp32_20, outptr + 8, 0);
1338
                    __msa_st_w((v4i32)_sumfp32_30, outptr + 12, 0);
1339
                    __msa_st_w((v4i32)_sumfp32_01, outptr + 16, 0);
1340
                    __msa_st_w((v4i32)_sumfp32_11, outptr + 20, 0);
1341
                    __msa_st_w((v4i32)_sumfp32_21, outptr + 24, 0);
1342
                    __msa_st_w((v4i32)_sumfp32_31, outptr + 28, 0);
1343

1344
                    outptr += 32;
1345
                }
1346
            }
1347
        }
1348

1349
        if (num_output_elempack == 1 && out_elempack == 4)
1350
        {
1351
            #pragma omp parallel for num_threads(opt.num_threads)
1352
            for (int j = 0; j < outh; j++)
1353
            {
1354
                float* outptr = top_blob.row(j);
1355

1356
                for (int p = 0; p < num_output; p++)
1357
                {
1358
                    const signed char* kptr = weight_data_tm.row<const signed char>(p);
1359
                    const signed char* m0 = bottom_blob_int8_unpacked.row<const signed char>(j * 4);
1360
                    const signed char* m1 = bottom_blob_int8_unpacked.row<const signed char>(j * 4 + 1);
1361
                    const signed char* m2 = bottom_blob_int8_unpacked.row<const signed char>(j * 4 + 2);
1362
                    const signed char* m3 = bottom_blob_int8_unpacked.row<const signed char>(j * 4 + 3);
1363

1364
                    int sum0 = 0;
1365
                    int sum1 = 0;
1366
                    int sum2 = 0;
1367
                    int sum3 = 0;
1368

1369
                    int i = 0;
1370
                    for (; i < num_input; i++)
1371
                    {
1372
                        sum0 += *m0++ * kptr[0];
1373
                        sum1 += *m1++ * kptr[0];
1374
                        sum2 += *m2++ * kptr[0];
1375
                        sum3 += *m3++ * kptr[0];
1376
                        kptr += 1;
1377
                    }
1378

1379
                    // dequantize and relu
1380
                    float sumfp32_0 = sum0 * scale_in_data[p];
                    float sumfp32_1 = sum1 * scale_in_data[p];
                    float sumfp32_2 = sum2 * scale_in_data[p];
                    float sumfp32_3 = sum3 * scale_in_data[p];

                    if (bias_term)
                    {
                        sumfp32_0 += bias_data[p];
                        sumfp32_1 += bias_data[p];
                        sumfp32_2 += bias_data[p];
                        sumfp32_3 += bias_data[p];
                    }

                    outptr[0] = activation_ss(sumfp32_0, activation_type, activation_params);
                    outptr[1] = activation_ss(sumfp32_1, activation_type, activation_params);
                    outptr[2] = activation_ss(sumfp32_2, activation_type, activation_params);
                    outptr[3] = activation_ss(sumfp32_3, activation_type, activation_params);
                    outptr += 4;
                }
            }
        }

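        // pack8 weights with pack1 output: broadcast one int8 input value at a
        // time, multiply it against eight packed int8 weights, and widen the
        // products to int16 and then int32 before accumulating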
        if (num_output_elempack == 8 && out_elempack == 1)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int j = 0; j < outh; j++)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output / num_output_elempack; p++)
                {
                    const signed char* kptr = weight_data_tm.row<const signed char>(p);
                    const signed char* m = bottom_blob_int8_unpacked.row<const signed char>(j);

                    v4i32 _sum0 = __msa_fill_w(0);
                    v4i32 _sum1 = __msa_fill_w(0);

                    int i = 0;
                    for (; i < num_input; i++)
                    {
                        __builtin_prefetch(m + 4);
                        __builtin_prefetch(kptr + 32);
                        v8i16 _val = __msa_fill_h((short)m[0]);

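                        // __msa_clti_s_b(_w, 0) yields an all-ones mask byte for
                        // every negative weight, so interleaving the mask with the
                        // low eight weight bytes sign-extends them to int16,
                        // i.e. _w16[k] == (short)kptr[k] for k = 0..7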
                        v16i8 _w = __msa_ld_b(kptr, 0);
                        v8i16 _w16 = (v8i16)__msa_ilvr_b(__msa_clti_s_b(_w, 0), _w);

                        v8i16 _s0 = __msa_mulv_h(_val, _w16);
                        v8i16 _exts0 = __msa_clti_s_h(_s0, 0);
                        v4i32 _s0l = (v4i32)__msa_ilvr_h(_exts0, _s0);
                        v4i32 _s0h = (v4i32)__msa_ilvl_h(_exts0, _s0);

                        _sum0 = __msa_addv_w(_sum0, _s0l);
                        _sum1 = __msa_addv_w(_sum1, _s0h);

                        m++;
                        kptr += 8;
                    }

                    // dequantize, add bias, then apply the activation
                    v4f32 _scale_in0 = (v4f32)__msa_ld_w((const float*)scale_in_data + p * 8, 0);
                    v4f32 _scale_in1 = (v4f32)__msa_ld_w((const float*)scale_in_data + p * 8 + 4, 0);

                    v4f32 _sumfp32_0 = (v4f32)__msa_ffint_s_w(_sum0);
                    v4f32 _sumfp32_1 = (v4f32)__msa_ffint_s_w(_sum1);

                    if (bias_term)
                    {
                        v4f32 _bias0 = (v4f32)__msa_ld_w((const float*)bias_data + p * 8, 0);
                        v4f32 _bias1 = (v4f32)__msa_ld_w((const float*)bias_data + p * 8 + 4, 0);
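                        // __msa_fmadd_w(a, b, c) returns a + b * c, fusing the bias
                        // add with the per-channel dequantize scale:
                        // _sumfp32_0[k] = _bias0[k] + _sumfp32_0[k] * _scale_in0[k]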
                        _sumfp32_0 = __msa_fmadd_w(_bias0, _sumfp32_0, _scale_in0);
                        _sumfp32_1 = __msa_fmadd_w(_bias1, _sumfp32_1, _scale_in1);
                    }
                    else
                    {
                        _sumfp32_0 = __msa_fmul_w(_sumfp32_0, _scale_in0);
                        _sumfp32_1 = __msa_fmul_w(_sumfp32_1, _scale_in1);
                    }

                    _sumfp32_0 = activation_ps(_sumfp32_0, activation_type, activation_params);
                    _sumfp32_1 = activation_ps(_sumfp32_1, activation_type, activation_params);

                    __msa_st_w((v4i32)_sumfp32_0, outptr, 0);
                    __msa_st_w((v4i32)_sumfp32_1, outptr + 4, 0);
                    outptr += 8;
                }
            }
        }
#endif // __mips_msa

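        // scalar fallback: a plain int8 dot product for each output element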
        if (num_output_elempack == 1 && out_elempack == 1)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int j = 0; j < outh; j++)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output; p++)
                {
                    const signed char* kptr = weight_data_tm.row<const signed char>(p);
                    const signed char* m = bottom_blob_int8_unpacked.row<const signed char>(j);

                    int sum = 0;

                    int i = 0;
                    for (; i < num_input; i++)
                    {
                        sum += *m++ * *kptr++;
                    }

                    // dequantize, add bias, then apply the activation
                    float sumfp32 = sum * scale_in_data[p];

                    if (bias_term)
                        sumfp32 += bias_data[p];

                    outptr[0] = activation_ss(sumfp32, activation_type, activation_params);
                    outptr += 1;
                }
            }
        }

        return 0;
    }

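    // 1D input path: flatten the blob if needed, then compute a single
    // matrix-vector product against the packed weights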
    Mat bottom_blob_int8_flattened = bottom_blob_int8;
    if (bottom_blob_int8.dims != 1)
    {
        Option opt_flatten = opt;
        opt_flatten.blob_allocator = opt.workspace_allocator;
        flatten->forward(bottom_blob_int8, bottom_blob_int8_flattened, opt_flatten);
    }

    //     int elempack = bottom_blob_int8_flattened.elempack;

    int out_elempack = 1;
#if __mips_msa
    if (opt.use_packing_layout)
    {
        out_elempack = num_output % 8 == 0 ? 8 : 1;
    }
#endif // __mips_msa
    //     size_t out_elemsize = elemsize / elempack * out_elempack;

    top_blob.create(num_output / out_elempack, (size_t)(4u * out_elempack), out_elempack, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

#if __mips_msa
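    // pack8 path: the same broadcast-multiply kernel as the pack8 branch
    // above, producing eight output channels per iteration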
    if (out_elempack == 8)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < num_output / out_elempack; p++)
        {
            v4i32 _sum0 = __msa_fill_w(0);
            v4i32 _sum1 = __msa_fill_w(0);

            const signed char* kptr = weight_data_tm.row<const signed char>(p);
            const signed char* sptr = bottom_blob_int8_flattened;

            int i = 0;
            for (; i < num_input; i++)
            {
                __builtin_prefetch(sptr + 4);
                __builtin_prefetch(kptr + 32);
                v8i16 _val = __msa_fill_h((short)sptr[0]);

                v16i8 _w = __msa_ld_b(kptr, 0);
                v8i16 _w16 = (v8i16)__msa_ilvr_b(__msa_clti_s_b(_w, 0), _w);

                v8i16 _s0 = __msa_mulv_h(_val, _w16);
                v8i16 _exts0 = __msa_clti_s_h(_s0, 0);
                v4i32 _s0l = (v4i32)__msa_ilvr_h(_exts0, _s0);
                v4i32 _s0h = (v4i32)__msa_ilvl_h(_exts0, _s0);

                _sum0 = __msa_addv_w(_sum0, _s0l);
                _sum1 = __msa_addv_w(_sum1, _s0h);

                sptr += 1;
                kptr += 8;
            }

            // dequantize, add bias, then apply the activation
            v4f32 _scale_in0 = (v4f32)__msa_ld_w((const float*)scale_in_data + p * 8, 0);
            v4f32 _scale_in1 = (v4f32)__msa_ld_w((const float*)scale_in_data + p * 8 + 4, 0);

            v4f32 _sumfp32_0 = (v4f32)__msa_ffint_s_w(_sum0);
            v4f32 _sumfp32_1 = (v4f32)__msa_ffint_s_w(_sum1);

            if (bias_term)
            {
                v4f32 _bias0 = (v4f32)__msa_ld_w((const float*)bias_data + p * 8, 0);
                v4f32 _bias1 = (v4f32)__msa_ld_w((const float*)bias_data + p * 8 + 4, 0);
                _sumfp32_0 = __msa_fmadd_w(_bias0, _sumfp32_0, _scale_in0);
                _sumfp32_1 = __msa_fmadd_w(_bias1, _sumfp32_1, _scale_in1);
            }
            else
            {
                _sumfp32_0 = __msa_fmul_w(_sumfp32_0, _scale_in0);
                _sumfp32_1 = __msa_fmul_w(_sumfp32_1, _scale_in1);
            }

            _sumfp32_0 = activation_ps(_sumfp32_0, activation_type, activation_params);
            _sumfp32_1 = activation_ps(_sumfp32_1, activation_type, activation_params);

            float* outptr = (float*)top_blob + p * 8;
            __msa_st_w((v4i32)_sumfp32_0, outptr, 0);
            __msa_st_w((v4i32)_sumfp32_1, outptr + 4, 0);
        }
    }
#endif // __mips_msa

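    // scalar path: one int8 multiply-accumulate per input element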
    if (out_elempack == 1)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < num_output / out_elempack; p++)
        {
            int sum = 0;

            const signed char* kptr = weight_data_tm.row<const signed char>(p);
            const signed char* sptr = bottom_blob_int8_flattened;

            int i = 0;
            for (; i < num_input; i++)
            {
                signed char val = sptr[0];

                signed char w = kptr[0];

                sum += val * w;

                sptr += 1;
                kptr += 1;
            }

            // dequantize, add bias, then apply the activation
            float sumfp32 = sum * scale_in_data[p];

            if (bias_term)
                sumfp32 += bias_data[p];

            sumfp32 = activation_ss(sumfp32, activation_type, activation_params);

            top_blob[p] = sumfp32;
        }
    }

    return 0;
}
#endif // NCNN_INT8

} // namespace ncnn