1
// Tencent is pleased to support the open source community by making ncnn available.
3
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
8
// https://opensource.org/licenses/BSD-3-Clause
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
15
#include "flatten_mips.h"
19
#include "msa_mathfun.h"
24
Flatten_mips::Flatten_mips()
27
support_packing = true;
31
int Flatten_mips::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
33
int elembits = bottom_blob.elembits();
36
return forward_int8(bottom_blob, top_blob, opt);
38
int dims = bottom_blob.dims;
42
top_blob = bottom_blob;
46
int w = bottom_blob.w;
47
int h = bottom_blob.h;
48
int d = bottom_blob.d;
49
int channels = bottom_blob.c;
50
size_t elemsize = bottom_blob.elemsize;
51
int elempack = bottom_blob.elempack;
54
int total = size * channels * elempack;
58
if (opt.use_packing_layout)
60
out_elempack = total % 4 == 0 ? 4 : 1;
63
size_t out_elemsize = elemsize / elempack * out_elempack;
65
if (out_elempack == 1)
67
return Flatten::forward(bottom_blob, top_blob, opt);
70
if (dims == 2 && elempack == 1) // out_elempack == 4
72
top_blob = bottom_blob;
74
top_blob.w = total / out_elempack;
76
top_blob.cstep = top_blob.w;
77
top_blob.elemsize = out_elemsize;
78
top_blob.elempack = out_elempack;
82
top_blob.create(total / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
89
if (elempack == 4) // out_elempack == 4
91
#pragma omp parallel for num_threads(opt.num_threads)
92
for (int i = 0; i < h; i++)
94
const float* ptr = bottom_blob.row(i);
95
float* outptr0 = (float*)top_blob + w * i * 4;
96
float* outptr1 = (float*)top_blob + w * (i * 4 + 1);
97
float* outptr2 = (float*)top_blob + w * (i * 4 + 2);
98
float* outptr3 = (float*)top_blob + w * (i * 4 + 3);
101
for (; j + 3 < w; j += 4)
104
v4f32 _r0 = (v4f32)__msa_ld_w(ptr, 0);
105
v4f32 _r1 = (v4f32)__msa_ld_w(ptr + 4, 0);
106
v4f32 _r2 = (v4f32)__msa_ld_w(ptr + 4 * 2, 0);
107
v4f32 _r3 = (v4f32)__msa_ld_w(ptr + 4 * 3, 0);
109
v4i32 _r01r = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0);
110
v4i32 _r01l = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0);
111
v4i32 _r23r = __msa_ilvr_w((v4i32)_r3, (v4i32)_r2);
112
v4i32 _r23l = __msa_ilvl_w((v4i32)_r3, (v4i32)_r2);
113
v2i64 _r0123_0 = __msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r);
114
v2i64 _r0123_1 = __msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r);
115
v2i64 _r0123_2 = __msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l);
116
v2i64 _r0123_3 = __msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l);
118
__msa_st_w((v4i32)_r0123_0, outptr0, 0);
119
__msa_st_w((v4i32)_r0123_1, outptr1, 0);
120
__msa_st_w((v4i32)_r0123_2, outptr2, 0);
121
__msa_st_w((v4i32)_r0123_3, outptr3, 0);
143
if (dims == 3 || dims == 4)
146
if (elempack == 4) // out_elempack == 4
148
#pragma omp parallel for num_threads(opt.num_threads)
149
for (int q = 0; q < channels; q++)
151
const float* ptr = bottom_blob.channel(q);
152
float* outptr0 = (float*)top_blob + size * q * 4;
153
float* outptr1 = (float*)top_blob + size * (q * 4 + 1);
154
float* outptr2 = (float*)top_blob + size * (q * 4 + 2);
155
float* outptr3 = (float*)top_blob + size * (q * 4 + 3);
158
for (; i + 3 < size; i += 4)
161
v4f32 _r0 = (v4f32)__msa_ld_w(ptr, 0);
162
v4f32 _r1 = (v4f32)__msa_ld_w(ptr + 4, 0);
163
v4f32 _r2 = (v4f32)__msa_ld_w(ptr + 4 * 2, 0);
164
v4f32 _r3 = (v4f32)__msa_ld_w(ptr + 4 * 3, 0);
166
v4i32 _r01r = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0);
167
v4i32 _r01l = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0);
168
v4i32 _r23r = __msa_ilvr_w((v4i32)_r3, (v4i32)_r2);
169
v4i32 _r23l = __msa_ilvl_w((v4i32)_r3, (v4i32)_r2);
170
v2i64 _r0123_0 = __msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r);
171
v2i64 _r0123_1 = __msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r);
172
v2i64 _r0123_2 = __msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l);
173
v2i64 _r0123_3 = __msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l);
175
__msa_st_w((v4i32)_r0123_0, outptr0, 0);
176
__msa_st_w((v4i32)_r0123_1, outptr1, 0);
177
__msa_st_w((v4i32)_r0123_2, outptr2, 0);
178
__msa_st_w((v4i32)_r0123_3, outptr3, 0);
186
for (; i < size; i++)
199
if (elempack == 1) // out_elempack == 4
201
#pragma omp parallel for num_threads(opt.num_threads)
202
for (int q = 0; q < channels; q++)
204
const float* ptr = bottom_blob.channel(q);
205
float* outptr = (float*)top_blob + size * q;
209
for (; i + 3 < size; i += 4)
211
__msa_st_w(__msa_ld_w(ptr, 0), outptr, 0);
216
for (; i < size; i++)
227
int Flatten_mips::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
229
int dims = bottom_blob.dims;
233
top_blob = bottom_blob;
237
int w = bottom_blob.w;
238
int h = bottom_blob.h;
239
int d = bottom_blob.d;
240
int channels = bottom_blob.c;
241
size_t elemsize = bottom_blob.elemsize;
242
int elempack = bottom_blob.elempack;
243
int size = w * h * d;
245
int total = size * channels * elempack;
247
int out_elempack = 1;
249
if (opt.use_packing_layout)
251
out_elempack = total % 8 == 0 ? 8 : 1;
254
size_t out_elemsize = elemsize / elempack * out_elempack;
256
if (out_elempack == 1)
258
return Flatten::forward(bottom_blob, top_blob, opt);
261
if (dims == 2 && elempack == 1) // out_elempack == 8
263
top_blob = bottom_blob;
265
top_blob.w = total / out_elempack;
267
top_blob.cstep = top_blob.w;
268
top_blob.elemsize = out_elemsize;
269
top_blob.elempack = out_elempack;
273
top_blob.create(total / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
274
if (top_blob.empty())
280
if (elempack == 8) // out_elempack == 8
282
#pragma omp parallel for num_threads(opt.num_threads)
283
for (int i = 0; i < h; i++)
285
const signed char* ptr = bottom_blob.row<signed char>(i);
286
signed char* outptr0 = (signed char*)top_blob + w * i * 8;
287
signed char* outptr1 = (signed char*)top_blob + w * (i * 8 + 1);
288
signed char* outptr2 = (signed char*)top_blob + w * (i * 8 + 2);
289
signed char* outptr3 = (signed char*)top_blob + w * (i * 8 + 3);
290
signed char* outptr4 = (signed char*)top_blob + w * (i * 8 + 4);
291
signed char* outptr5 = (signed char*)top_blob + w * (i * 8 + 5);
292
signed char* outptr6 = (signed char*)top_blob + w * (i * 8 + 6);
293
signed char* outptr7 = (signed char*)top_blob + w * (i * 8 + 7);
314
if (dims == 3 || dims == 4)
317
if (elempack == 8) // out_elempack == 8
319
#pragma omp parallel for num_threads(opt.num_threads)
320
for (int q = 0; q < channels; q++)
322
const signed char* ptr = bottom_blob.channel(q);
323
signed char* outptr0 = (signed char*)top_blob + size * q * 8;
324
signed char* outptr1 = (signed char*)top_blob + size * (q * 8 + 1);
325
signed char* outptr2 = (signed char*)top_blob + size * (q * 8 + 2);
326
signed char* outptr3 = (signed char*)top_blob + size * (q * 8 + 3);
327
signed char* outptr4 = (signed char*)top_blob + size * (q * 8 + 4);
328
signed char* outptr5 = (signed char*)top_blob + size * (q * 8 + 5);
329
signed char* outptr6 = (signed char*)top_blob + size * (q * 8 + 6);
330
signed char* outptr7 = (signed char*)top_blob + size * (q * 8 + 7);
333
for (; i < size; i++)
350
if (elempack == 1) // out_elempack == 8
352
#pragma omp parallel for num_threads(opt.num_threads)
353
for (int q = 0; q < channels; q++)
355
const signed char* ptr = bottom_blob.channel(q);
356
signed char* outptr = (signed char*)top_blob + size * q;
359
for (; i < size; i++)