15
#include "slice_loongarch.h"
19
// Constructor. The only visible statement sets the support_packing flag to true.
// NOTE(review): presumably this tells the framework that forward() accepts blobs
// with elempack > 1 (forward() does read bottom_blob.elempack) - confirm against
// the base layer's documentation.
Slice_loongarch::Slice_loongarch()
22
support_packing = true;
26
// Splits bottom_blobs[0] into top_blobs.size() output blobs along `axis`.
//
// bottom_blobs : only element 0 is read.
// top_blobs    : each element is (re)created here and filled with one slice.
// opt          : supplies use_packing_layout, blob_allocator and num_threads.
//
// NOTE(review): this listing is an excerpt - brace lines, else-arms, return
// statements and the declarations of `q` and `slice` are not visible, so the
// comments below describe only what the visible statements do.
int Slice_loongarch::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
28
const Mat& bottom_blob = bottom_blobs[0];
29
int dims = bottom_blob.dims;
30
size_t elemsize = bottom_blob.elemsize;
31
int elempack = bottom_blob.elempack;
32
const int* slices_ptr = slices;
33
const int* indices_ptr = indices;
34
// A negative axis counts back from the last dimension.
int positive_axis = axis < 0 ? dims + axis : axis;
39
// First branch: outputs are created with the 1-D create() overload below.
// NOTE(review): the guarding condition is not visible in this excerpt.
// Width expressed in scalar elements (packed width times elempack).
int w = bottom_blob.w * elempack;
41
for (size_t i = 0; i < top_blobs.size(); i++)
46
if (i == top_blobs.size() - 1)
52
// Index-based extent: a negative index is taken relative to w, and the
// extent is the distance from q to that index.
// NOTE(review): q is declared outside the visible lines; it appears to be
// the start offset of the current slice - confirm.
int indice = indices_ptr[i];
53
int positive_indice = indice < 0 ? w + indice : indice;
54
slice = positive_indice - q;
59
// Explicit extent taken from the slices array.
slice = slices_ptr[i];
62
// Even split of what is left (w - q) over the outputs not yet assigned.
// NOTE(review): the conditions selecting between these three computations
// are not visible here.
slice = static_cast<int>((w - q) / (top_blobs.size() - i));
68
// With packing enabled, use pack-4 output when the extent is a multiple of 4.
if (opt.use_packing_layout)
69
out_elempack = slice % 4 == 0 ? 4 : 1;
71
// Per-scalar byte size (elemsize / elempack) scaled to the output packing.
size_t out_elemsize = elemsize / elempack * out_elempack;
73
Mat& top_blob = top_blobs[i];
74
top_blob.create(slice / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
78
// Copy top_blob.w * top_blob.elemsize bytes starting q floats into the input.
const float* ptr = (const float*)bottom_blob + q;
79
float* outptr = top_blob;
80
memcpy(outptr, ptr, top_blob.w * top_blob.elemsize);
86
// 2-D input sliced along axis 0: the extent is measured against h.
// NOTE(review): excerpt only - braces, else-arms, returns and the declarations
// of `q` and `slice` are not visible.
if (dims == 2 && positive_axis == 0)
89
int w = bottom_blob.w;
90
// Row count expressed in scalar rows (packed h times elempack).
int h = bottom_blob.h * elempack;
93
// Pass 1: compute each output's extent and allocate it.
for (size_t i = 0; i < top_blobs.size(); i++)
98
if (i == top_blobs.size() - 1)
104
// Index-based extent; a negative index is taken relative to h.
int indice = indices_ptr[i];
105
int positive_indice = indice < 0 ? h + indice : indice;
106
slice = positive_indice - q;
111
// Explicit extent from the slices array.
slice = slices_ptr[i];
114
// Even split of the remaining rows over the outputs not yet assigned.
slice = static_cast<int>((h - q) / (top_blobs.size() - i));
118
int out_elempack = 1;
120
// With packing enabled, use pack-4 output when the extent is a multiple of 4.
if (opt.use_packing_layout)
121
out_elempack = slice % 4 == 0 ? 4 : 1;
123
size_t out_elemsize = elemsize / elempack * out_elempack;
125
Mat& top_blob = top_blobs[i];
126
top_blob.create(w, slice / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
127
if (top_blob.empty())
133
// Find the smallest elemsize / elempack among all outputs.
size_t out_elemsize = top_blobs[0].elemsize;
134
int out_elempack = top_blobs[0].elempack;
135
for (size_t i = 0; i < top_blobs.size(); i++)
137
out_elemsize = std::min(out_elemsize, top_blobs[i].elemsize);
138
out_elempack = std::min(out_elempack, top_blobs[i].elempack);
141
// Repack the input down to that smallest packing when the input is more packed.
Mat bottom_blob_unpacked = bottom_blob;
142
if (elempack > out_elempack)
144
convert_packing(bottom_blob, bottom_blob_unpacked, out_elempack, opt);
145
if (bottom_blob_unpacked.empty())
149
// Pass 2: walk the (possibly repacked) input once, filling each output in turn.
const float* ptr = bottom_blob_unpacked;
150
for (size_t i = 0; i < top_blobs.size(); i++)
152
Mat& top_blob = top_blobs[i];
154
// Source is pack-1 but this output is pack-4: each output row reads from
// four consecutive source rows (r0..r3, spaced w floats apart).
if (out_elempack == 1 && top_blob.elempack == 4)
156
for (int j = 0; j < top_blob.h; j++)
158
const float* r0 = ptr;
159
const float* r1 = ptr + w;
160
const float* r2 = ptr + w * 2;
161
const float* r3 = ptr + w * 3;
163
float* outptr0 = top_blob.row(j);
165
// NOTE(review): this loop variable is also named j; if it is nested in the
// row loop above (nesting is not visible here) it shadows the outer j.
// The loop body is not visible in this excerpt.
for (int j = 0; j < w; j++)
180
// Same packing on both sides: copy w * top_blob.h elements in one memcpy.
int size = w * top_blob.h;
182
float* outptr = top_blob;
183
memcpy(outptr, ptr, size * top_blob.elemsize);
185
// Advance the source by the number of floats just consumed.
ptr += size * top_blob.elempack;
190
// 2-D input sliced along axis 1: the extent is measured against w, and every
// output keeps the input's h, elemsize and elempack.
// NOTE(review): excerpt only - braces, else-arms, returns and the declarations
// of `q` and `slice` are not visible.
if (dims == 2 && positive_axis == 1)
193
int w = bottom_blob.w;
194
int h = bottom_blob.h;
197
// Pass 1: compute each output's width and allocate it.
for (size_t i = 0; i < top_blobs.size(); i++)
202
if (i == top_blobs.size() - 1)
208
// Index-based extent; a negative index is taken relative to w.
int indice = indices_ptr[i];
209
int positive_indice = indice < 0 ? w + indice : indice;
210
slice = positive_indice - q;
215
// Explicit extent from the slices array.
slice = slices_ptr[i];
218
// Even split of the remaining width over the outputs not yet assigned.
slice = static_cast<int>((w - q) / (top_blobs.size() - i));
222
Mat& top_blob = top_blobs[i];
223
top_blob.create(slice, h, elemsize, elempack, opt.blob_allocator);
224
if (top_blob.empty())
230
// Pass 2: each input row is handled independently, so rows are spread over threads.
#pragma omp parallel for num_threads(opt.num_threads)
231
for (int j = 0; j < h; j++)
233
const float* ptr = bottom_blob.row(j);
234
// Walk along the input row, handing each output its run of top_blob.w elements.
for (size_t i = 0; i < top_blobs.size(); i++)
236
Mat& top_blob = top_blobs[i];
238
float* outptr = top_blob.row(j);
239
memcpy(outptr, ptr, top_blob.w * elemsize);
241
// elemsize is in bytes per packed element; the pointer step is in floats.
ptr += top_blob.w * elempack;
246
// 3-D or 4-D input sliced along axis 0: the extent is measured against the
// channel count.
// NOTE(review): excerpt only - braces, else-arms, returns and the declarations
// of the outer `q`, `slice` and `p` are not visible.
if ((dims == 3 || dims == 4) && positive_axis == 0)
249
int w = bottom_blob.w;
250
int h = bottom_blob.h;
251
int d = bottom_blob.d;
252
// Channel count expressed in scalar channels (packed c times elempack).
int channels = bottom_blob.c * elempack;
255
// Pass 1: compute each output's channel extent and allocate it.
for (size_t i = 0; i < top_blobs.size(); i++)
260
// The last output takes whatever channels remain after offset q.
if (i == top_blobs.size() - 1)
262
slice = channels - q;
266
// Index-based extent; a negative index is taken relative to channels.
int indice = indices_ptr[i];
267
int positive_indice = indice < 0 ? channels + indice : indice;
268
slice = positive_indice - q;
273
// Explicit extent from the slices array.
slice = slices_ptr[i];
276
// Even split of the remaining channels over the outputs not yet assigned.
slice = static_cast<int>((channels - q) / (top_blobs.size() - i));
280
int out_elempack = 1;
282
// With packing enabled, use pack-4 output when the extent is a multiple of 4.
if (opt.use_packing_layout)
283
out_elempack = slice % 4 == 0 ? 4 : 1;
285
size_t out_elemsize = elemsize / elempack * out_elempack;
287
Mat& top_blob = top_blobs[i];
288
top_blob.create(w, h, d, slice / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
289
if (top_blob.empty())
292
// The 4-D create() overload is used for both ranks, then dims is set to the
// input's rank. NOTE(review): presumably so a 3-D input yields 3-D outputs - confirm.
top_blob.dims = dims;
297
// Find the smallest elemsize / elempack among all outputs.
size_t out_elemsize = top_blobs[0].elemsize;
298
int out_elempack = top_blobs[0].elempack;
299
for (size_t i = 0; i < top_blobs.size(); i++)
301
out_elemsize = std::min(out_elemsize, top_blobs[i].elemsize);
302
out_elempack = std::min(out_elempack, top_blobs[i].elempack);
305
// Repack the input down to that smallest packing when the input is more packed.
Mat bottom_blob_unpacked = bottom_blob;
306
if (elempack > out_elempack)
308
convert_packing(bottom_blob, bottom_blob_unpacked, out_elempack, opt);
309
if (bottom_blob_unpacked.empty())
314
// Pass 2: fill each output from the (possibly repacked) input.
for (size_t i = 0; i < top_blobs.size(); i++)
316
Mat& top_blob = top_blobs[i];
318
// Source is pack-1 but this output is pack-4: each output channel reads
// from four consecutive source channels p .. p+3.
if (out_elempack == 1 && top_blob.elempack == 4)
320
int size = top_blob.w * top_blob.h * top_blob.d;
322
// Here q is a loop-local output-channel index.
for (int q = 0; q < top_blob.c; q++)
324
// NOTE(review): p (source channel index) is declared and advanced outside
// the visible lines.
const float* r0 = bottom_blob_unpacked.channel(p);
325
const float* r1 = bottom_blob_unpacked.channel(p + 1);
326
const float* r2 = bottom_blob_unpacked.channel(p + 2);
327
const float* r3 = bottom_blob_unpacked.channel(p + 3);
329
float* outptr0 = top_blob.channel(q);
331
// Loop body is not visible in this excerpt.
for (int j = 0; j < size; j++)
346
// Same packing on both sides: copy top_blob.total() elements in one memcpy,
// starting at source channel p.
int size = top_blob.total();
348
const float* ptr = bottom_blob_unpacked.channel(p);
349
float* outptr = top_blob;
350
memcpy(outptr, ptr, size * top_blob.elemsize);
357
// 3-D input on axis 1 or 4-D input on axis 2: the extent is measured against h.
// Outputs keep the input's w, d, channel count, elemsize and elempack.
// NOTE(review): excerpt only - braces, else-arms, returns and the declarations
// of `q` and `slice` are not visible.
if ((dims == 3 && positive_axis == 1) || (dims == 4 && positive_axis == 2))
360
int w = bottom_blob.w;
361
int h = bottom_blob.h;
362
int d = bottom_blob.d;
363
int channels = bottom_blob.c;
366
// Pass 1: compute each output's height and allocate it.
for (size_t i = 0; i < top_blobs.size(); i++)
371
if (i == top_blobs.size() - 1)
377
// Index-based extent; a negative index is taken relative to h.
int indice = indices_ptr[i];
378
int positive_indice = indice < 0 ? h + indice : indice;
379
slice = positive_indice - q;
384
// Explicit extent from the slices array.
slice = slices_ptr[i];
387
// Even split of the remaining height over the outputs not yet assigned.
slice = static_cast<int>((h - q) / (top_blobs.size() - i));
391
Mat& top_blob = top_blobs[i];
392
top_blob.create(w, slice, d, channels, elemsize, elempack, opt.blob_allocator);
393
if (top_blob.empty())
396
// dims is set to the input's rank after the 4-D create().
top_blob.dims = dims;
401
// Pass 2: channels are independent, so they are spread over threads.
#pragma omp parallel for num_threads(opt.num_threads)
402
for (int p = 0; p < channels; p++)
404
const float* ptr = bottom_blob.channel(p);
406
for (int j = 0; j < d; j++)
408
// Within one depth plane, hand each output its run of w * h elements.
for (size_t i = 0; i < top_blobs.size(); i++)
410
Mat& top_blob = top_blobs[i];
412
int size = top_blob.w * top_blob.h;
414
float* outptr = top_blob.channel(p).depth(j);
415
memcpy(outptr, ptr, size * elemsize);
417
// elemsize is in bytes per packed element; the pointer step is in floats.
ptr += size * elempack;
423
// 3-D input on axis 2 or 4-D input on axis 3: the extent is measured against w.
// Outputs keep the input's h, d, channel count, elemsize and elempack.
// NOTE(review): excerpt only - braces, else-arms, returns and the declarations
// of `q` and `slice` are not visible.
if ((dims == 3 && positive_axis == 2) || (dims == 4 && positive_axis == 3))
426
int w = bottom_blob.w;
427
int h = bottom_blob.h;
428
int d = bottom_blob.d;
429
int channels = bottom_blob.c;
432
// Pass 1: compute each output's width and allocate it.
for (size_t i = 0; i < top_blobs.size(); i++)
437
if (i == top_blobs.size() - 1)
443
// Index-based extent; a negative index is taken relative to w.
int indice = indices_ptr[i];
444
int positive_indice = indice < 0 ? w + indice : indice;
445
slice = positive_indice - q;
450
// Explicit extent from the slices array.
slice = slices_ptr[i];
453
// Even split of the remaining width over the outputs not yet assigned.
slice = static_cast<int>((w - q) / (top_blobs.size() - i));
457
Mat& top_blob = top_blobs[i];
458
top_blob.create(slice, h, d, channels, elemsize, elempack, opt.blob_allocator);
459
if (top_blob.empty())
462
// dims is set to the input's rank after the 4-D create().
top_blob.dims = dims;
467
// Pass 2: channels are independent, so they are spread over threads.
#pragma omp parallel for num_threads(opt.num_threads)
468
for (int p = 0; p < channels; p++)
470
const float* ptr = bottom_blob.channel(p);
472
for (int j = 0; j < d; j++)
474
for (int k = 0; k < h; k++)
476
// Within one input row, hand each output its run of top_blob.w elements.
for (size_t i = 0; i < top_blobs.size(); i++)
478
Mat& top_blob = top_blobs[i];
480
float* outptr = top_blob.channel(p).depth(j).row(k);
481
memcpy(outptr, ptr, top_blob.w * elemsize);
483
// elemsize is in bytes per packed element; the pointer step is in floats.
ptr += top_blob.w * elempack;
490
// 4-D input sliced along axis 1: the extent is measured against d.
// Outputs keep the input's w, h, channel count, elemsize and elempack.
// NOTE(review): excerpt only - braces, else-arms, returns and the declarations
// of `q` and `slice` are not visible.
if (dims == 4 && positive_axis == 1)
492
int w = bottom_blob.w;
493
int h = bottom_blob.h;
494
int d = bottom_blob.d;
495
int channels = bottom_blob.c;
498
// Pass 1: compute each output's depth and allocate it.
for (size_t i = 0; i < top_blobs.size(); i++)
503
if (i == top_blobs.size() - 1)
509
// Index-based extent; a negative index is taken relative to d.
int indice = indices_ptr[i];
510
int positive_indice = indice < 0 ? d + indice : indice;
511
slice = positive_indice - q;
516
// Explicit extent from the slices array.
slice = slices_ptr[i];
519
// Even split of the remaining depth over the outputs not yet assigned.
slice = static_cast<int>((d - q) / (top_blobs.size() - i));
523
Mat& top_blob = top_blobs[i];
524
top_blob.create(w, h, slice, channels, elemsize, elempack, opt.blob_allocator);
525
if (top_blob.empty())
531
// Pass 2: channels are independent, so they are spread over threads.
#pragma omp parallel for num_threads(opt.num_threads)
532
for (int p = 0; p < channels; p++)
534
const float* ptr = bottom_blob.channel(p);
536
// Within one input channel, hand each output its run of w * h * d elements.
for (size_t i = 0; i < top_blobs.size(); i++)
538
Mat& top_blob = top_blobs[i];
540
int size = top_blob.w * top_blob.h * top_blob.d;
542
float* outptr = top_blob.channel(p);
543
memcpy(outptr, ptr, size * elemsize);
545
// elemsize is in bytes per packed element; the pointer step is in floats.
ptr += size * elempack;