ncnn

Форк
0
/
permute_vulkan.cpp 
1013 строк · 27.7 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "permute_vulkan.h"
16

17
#include "layer_shader_type.h"
18

19
namespace ncnn {
20

21
Permute_vulkan::Permute_vulkan()
{
    // the layer has a vulkan path and advertises image storage support;
    // create_pipeline() may withdraw image support for unsupported shapes
    support_vulkan = true;
    support_image_storage = true;

    // no pipeline exists until create_pipeline() builds the variants it needs;
    // destroy_pipeline() deletes them and nulls these members again
    pipeline_permute = 0;
    pipeline_permute_pack4 = pipeline_permute_pack8 = 0;
    pipeline_permute_pack1to4 = pipeline_permute_pack1to8 = 0;
    pipeline_permute_pack4to1 = pipeline_permute_pack4to8 = 0;
    pipeline_permute_pack8to1 = pipeline_permute_pack8to4 = 0;
}
36

37
int Permute_vulkan::create_pipeline(const Option& _opt)
38
{
39
    Option opt = _opt;
40
    const Mat& shape = bottom_shapes.empty() ? Mat() : bottom_shapes[0];
41
    const Mat& out_shape = top_shapes.empty() ? Mat() : top_shapes[0];
42

43
    int elempack = 1;
44
    if (shape.dims == 1) elempack = opt.use_shader_pack8 && shape.w % 8 == 0 ? 8 : shape.w % 4 == 0 ? 4 : 1;
45
    if (shape.dims == 2) elempack = opt.use_shader_pack8 && shape.h % 8 == 0 ? 8 : shape.h % 4 == 0 ? 4 : 1;
46
    if (shape.dims == 3 || shape.dims == 4) elempack = opt.use_shader_pack8 && shape.c % 8 == 0 ? 8 : shape.c % 4 == 0 ? 4 : 1;
47

48
    int out_elempack = 1;
49
    if (out_shape.dims == 1) out_elempack = opt.use_shader_pack8 && out_shape.w % 8 == 0 ? 8 : out_shape.w % 4 == 0 ? 4 : 1;
50
    if (out_shape.dims == 2) out_elempack = opt.use_shader_pack8 && out_shape.h % 8 == 0 ? 8 : out_shape.h % 4 == 0 ? 4 : 1;
51
    if (out_shape.dims == 3 || out_shape.dims == 4) out_elempack = opt.use_shader_pack8 && out_shape.c % 8 == 0 ? 8 : out_shape.c % 4 == 0 ? 4 : 1;
52

53
    size_t elemsize;
54
    size_t out_elemsize;
55
    if (opt.use_fp16_storage)
56
    {
57
        elemsize = elempack * 2u;
58
        out_elemsize = out_elempack * 2u;
59
    }
60
    else if (opt.use_fp16_packed)
61
    {
62
        elemsize = elempack == 1 ? 4u : elempack * 2u;
63
        out_elemsize = out_elempack == 1 ? 4u : out_elempack * 2u;
64
    }
65
    else
66
    {
67
        elemsize = elempack * 4u;
68
        out_elemsize = out_elempack * 4u;
69
    }
70

71
    Mat shape_packed;
72
    if (shape.dims == 1) shape_packed = Mat(shape.w / elempack, (void*)0, elemsize, elempack);
73
    if (shape.dims == 2) shape_packed = Mat(shape.w, shape.h / elempack, (void*)0, elemsize, elempack);
74
    if (shape.dims == 3) shape_packed = Mat(shape.w, shape.h, shape.c / elempack, (void*)0, elemsize, elempack);
75
    if (shape.dims == 4) shape_packed = Mat(shape.w, shape.h, shape.d, shape.c / elempack, (void*)0, elemsize, elempack);
76

77
    Mat out_shape_packed;
78
    if (out_shape.dims == 1) out_shape_packed = Mat(out_shape.w / out_elempack, (void*)0, out_elemsize, out_elempack);
79
    if (out_shape.dims == 2) out_shape_packed = Mat(out_shape.w, out_shape.h / out_elempack, (void*)0, out_elemsize, out_elempack);
80
    if (out_shape.dims == 3) out_shape_packed = Mat(out_shape.w, out_shape.h, out_shape.c / out_elempack, (void*)0, out_elemsize, out_elempack);
81
    if (out_shape.dims == 4) out_shape_packed = Mat(out_shape.w, out_shape.h, out_shape.d, out_shape.c / out_elempack, (void*)0, out_elemsize, out_elempack);
82

83
    // check blob shape
84
    if (!vkdev->shape_support_image_storage(shape_packed) || !vkdev->shape_support_image_storage(out_shape_packed))
85
    {
86
        support_image_storage = false;
87
        opt.use_image_storage = false;
88
    }
89

90
    std::vector<vk_specialization_type> specializations(2 + 12);
91
    specializations[0].i = order_type;
92
    specializations[1].i = vkdev->info.bug_implicit_fp16_arithmetic();
93
    specializations[2 + 0].i = shape_packed.dims;
94
    specializations[2 + 1].i = shape_packed.w;
95
    specializations[2 + 2].i = shape_packed.h;
96
    specializations[2 + 3].i = shape_packed.d;
97
    specializations[2 + 4].i = shape_packed.c;
98
    specializations[2 + 5].i = shape_packed.cstep;
99
    specializations[2 + 6].i = out_shape_packed.dims;
100
    specializations[2 + 7].i = out_shape_packed.w;
101
    specializations[2 + 8].i = out_shape_packed.h;
102
    specializations[2 + 9].i = out_shape_packed.d;
103
    specializations[2 + 10].i = out_shape_packed.c;
104
    specializations[2 + 11].i = out_shape_packed.cstep;
105

106
    Mat local_size_xyz_bottom; // pack4to1 and pack8to1
107
    if (shape_packed.dims == 2)
108
    {
109
        local_size_xyz_bottom.w = std::min(8, shape_packed.w);
110
        local_size_xyz_bottom.h = std::min(8, shape_packed.h);
111
        local_size_xyz_bottom.c = 1;
112
    }
113
    if (shape_packed.dims == 3)
114
    {
115
        local_size_xyz_bottom.w = std::min(4, shape_packed.w);
116
        local_size_xyz_bottom.h = std::min(4, shape_packed.h);
117
        local_size_xyz_bottom.c = std::min(4, shape_packed.c);
118
    }
119
    if (shape_packed.dims == 4)
120
    {
121
        local_size_xyz_bottom.w = std::min(4, shape_packed.w);
122
        local_size_xyz_bottom.h = std::min(4, shape_packed.h * shape_packed.d);
123
        local_size_xyz_bottom.c = std::min(4, shape_packed.c);
124
    }
125

126
    Mat local_size_xyz;
127
    if (out_shape_packed.dims == 2)
128
    {
129
        local_size_xyz.w = std::min(8, out_shape_packed.w);
130
        local_size_xyz.h = std::min(8, out_shape_packed.h);
131
        local_size_xyz.c = 1;
132
    }
133
    if (out_shape_packed.dims == 3)
134
    {
135
        local_size_xyz.w = std::min(4, out_shape_packed.w);
136
        local_size_xyz.h = std::min(4, out_shape_packed.h);
137
        local_size_xyz.c = std::min(4, out_shape_packed.c);
138
    }
139
    if (out_shape_packed.dims == 4)
140
    {
141
        local_size_xyz.w = std::min(4, out_shape_packed.w);
142
        local_size_xyz.h = std::min(4, out_shape_packed.h * out_shape_packed.d);
143
        local_size_xyz.c = std::min(4, out_shape_packed.c);
144
    }
145

146
    // pack1
147
    if (shape.dims == 0 || (elempack == 1 && out_elempack == 1))
148
    {
149
        pipeline_permute = new Pipeline(vkdev);
150
        pipeline_permute->set_optimal_local_size_xyz(local_size_xyz);
151
        pipeline_permute->create(LayerShaderType::permute, opt, specializations);
152
    }
153

154
    // pack4
155
    if (shape.dims == 0 || (elempack == 4 && out_elempack == 4))
156
    {
157
        pipeline_permute_pack4 = new Pipeline(vkdev);
158
        pipeline_permute_pack4->set_optimal_local_size_xyz(local_size_xyz);
159
        pipeline_permute_pack4->create(LayerShaderType::permute_pack4, opt, specializations);
160
    }
161

162
    // pack1to4
163
    if (shape.dims == 0 || (elempack == 1 && out_elempack == 4))
164
    {
165
        pipeline_permute_pack1to4 = new Pipeline(vkdev);
166
        pipeline_permute_pack1to4->set_optimal_local_size_xyz(local_size_xyz);
167
        pipeline_permute_pack1to4->create(LayerShaderType::permute_pack1to4, opt, specializations);
168
    }
169

170
    // pack4to1
171
    if (shape.dims == 0 || (elempack == 4 && out_elempack == 1))
172
    {
173
        pipeline_permute_pack4to1 = new Pipeline(vkdev);
174
        pipeline_permute_pack4to1->set_optimal_local_size_xyz(local_size_xyz_bottom);
175
        pipeline_permute_pack4to1->create(LayerShaderType::permute_pack4to1, opt, specializations);
176
    }
177

178
    // pack8
179
    if ((opt.use_shader_pack8 && shape.dims == 0) || (elempack == 8 && out_elempack == 8))
180
    {
181
        pipeline_permute_pack8 = new Pipeline(vkdev);
182
        pipeline_permute_pack8->set_optimal_local_size_xyz(local_size_xyz);
183
        pipeline_permute_pack8->create(LayerShaderType::permute_pack8, opt, specializations);
184
    }
185

186
    // pack1to8
187
    if ((opt.use_shader_pack8 && shape.dims == 0) || (elempack == 1 && out_elempack == 8))
188
    {
189
        pipeline_permute_pack1to8 = new Pipeline(vkdev);
190
        pipeline_permute_pack1to8->set_optimal_local_size_xyz(local_size_xyz);
191
        pipeline_permute_pack1to8->create(LayerShaderType::permute_pack1to8, opt, specializations);
192
    }
193

194
    // pack4to8
195
    if ((opt.use_shader_pack8 && shape.dims == 0) || (elempack == 4 && out_elempack == 8))
196
    {
197
        pipeline_permute_pack4to8 = new Pipeline(vkdev);
198
        pipeline_permute_pack4to8->set_optimal_local_size_xyz(local_size_xyz);
199
        pipeline_permute_pack4to8->create(LayerShaderType::permute_pack4to8, opt, specializations);
200
    }
201

202
    // pack8to4
203
    if ((opt.use_shader_pack8 && shape.dims == 0) || (elempack == 8 && out_elempack == 4))
204
    {
205
        pipeline_permute_pack8to4 = new Pipeline(vkdev);
206
        pipeline_permute_pack8to4->set_optimal_local_size_xyz(local_size_xyz);
207
        pipeline_permute_pack8to4->create(LayerShaderType::permute_pack8to4, opt, specializations);
208
    }
209

210
    // pack8to1
211
    if ((opt.use_shader_pack8 && shape.dims == 0) || (elempack == 8 && out_elempack == 1))
212
    {
213
        pipeline_permute_pack8to1 = new Pipeline(vkdev);
214
        pipeline_permute_pack8to1->set_optimal_local_size_xyz(local_size_xyz_bottom);
215
        pipeline_permute_pack8to1->create(LayerShaderType::permute_pack8to1, opt, specializations);
216
    }
217

218
    return 0;
219
}
220

221
int Permute_vulkan::destroy_pipeline(const Option& /*opt*/)
222
{
223
    delete pipeline_permute;
224
    pipeline_permute = 0;
225

226
    delete pipeline_permute_pack4;
227
    pipeline_permute_pack4 = 0;
228

229
    delete pipeline_permute_pack1to4;
230
    pipeline_permute_pack1to4 = 0;
231

232
    delete pipeline_permute_pack4to1;
233
    pipeline_permute_pack4to1 = 0;
234

235
    delete pipeline_permute_pack8;
236
    pipeline_permute_pack8 = 0;
237

238
    delete pipeline_permute_pack1to8;
239
    pipeline_permute_pack1to8 = 0;
240

241
    delete pipeline_permute_pack4to8;
242
    pipeline_permute_pack4to8 = 0;
243

244
    delete pipeline_permute_pack8to4;
245
    pipeline_permute_pack8to4 = 0;
246

247
    delete pipeline_permute_pack8to1;
248
    pipeline_permute_pack8to1 = 0;
249

250
    return 0;
251
}
252

253
int Permute_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, const Option& opt) const
254
{
255
    int w = bottom_blob.w;
256
    int h = bottom_blob.h;
257
    int d = bottom_blob.d;
258
    int channels = bottom_blob.c;
259
    size_t elemsize = bottom_blob.elemsize;
260
    int elempack = bottom_blob.elempack;
261

262
    int dims = bottom_blob.dims;
263

264
    if (dims == 1 || order_type == 0)
265
    {
266
        top_blob = bottom_blob;
267
        return 0;
268
    }
269

270
    int out_elempack;
271
    size_t out_elemsize;
272

273
    if (dims == 2)
274
    {
275
        // order_type
276
        // 0 = w h
277
        // 1 = h w
278

279
        int outw;
280
        int outh;
281

282
        // if (order_type == 1)
283
        {
284
            outw = h * elempack;
285
            outh = w;
286
        }
287

288
        out_elempack = opt.use_shader_pack8 && outh % 8 == 0 ? 8 : outh % 4 == 0 ? 4 : 1;
289
        out_elemsize = elemsize / elempack * out_elempack;
290

291
        if (opt.use_fp16_packed && !opt.use_fp16_storage)
292
        {
293
            if (out_elempack == 8) out_elemsize = 8 * 2u;
294
            if (out_elempack == 4) out_elemsize = 4 * 2u;
295
            if (out_elempack == 1) out_elemsize = 4u;
296
        }
297

298
        top_blob.create(outw, outh / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
299
        if (top_blob.empty())
300
            return -100;
301
    }
302
    else if (dims == 3)
303
    {
304
        // order_type
305
        // 0 = w h c
306
        // 1 = h w c
307
        // 2 = w c h
308
        // 3 = c w h
309
        // 4 = h c w
310
        // 5 = c h w
311

312
        const int c = channels * elempack;
313

314
        int outw;
315
        int outh;
316
        int outc;
317

318
        if (order_type == 1)
319
        {
320
            outw = h;
321
            outh = w;
322
            outc = c;
323
        }
324
        else if (order_type == 2)
325
        {
326
            outw = w;
327
            outh = c;
328
            outc = h;
329
        }
330
        else if (order_type == 3)
331
        {
332
            outw = c;
333
            outh = w;
334
            outc = h;
335
        }
336
        else if (order_type == 4)
337
        {
338
            outw = h;
339
            outh = c;
340
            outc = w;
341
        }
342
        else // if (order_type == 5)
343
        {
344
            outw = c;
345
            outh = h;
346
            outc = w;
347
        }
348

349
        out_elempack = opt.use_shader_pack8 && outc % 8 == 0 ? 8 : outc % 4 == 0 ? 4 : 1;
350
        out_elemsize = elemsize / elempack * out_elempack;
351

352
        if (opt.use_fp16_packed && !opt.use_fp16_storage)
353
        {
354
            if (out_elempack == 8) out_elemsize = 8 * 2u;
355
            if (out_elempack == 4) out_elemsize = 4 * 2u;
356
            if (out_elempack == 1) out_elemsize = 4u;
357
        }
358

359
        top_blob.create(outw, outh, outc / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
360
        if (top_blob.empty())
361
            return -100;
362
    }
363
    else // if (dims == 4)
364
    {
365
        // order_type
366
        // 0 = w h d c
367
        // 1 = h w d c
368
        // 2 = w d h c
369
        // 3 = d w h c
370
        // 4 = h d w c
371
        // 5 = d h w c
372
        // 6 = w h c d
373
        // 7 = h w c d
374
        // 8 = w c h d
375
        // 9 = c w h d
376
        //10 = h c w d
377
        //11 = c h w d
378
        //12 = w d c h
379
        //13 = d w c h
380
        //14 = w c d h
381
        //15 = c w d h
382
        //16 = d c w h
383
        //17 = c d w h
384
        //18 = h d c w
385
        //19 = d h c w
386
        //20 = h c d w
387
        //21 = c h d w
388
        //22 = d c h w
389
        //23 = c d h w
390

391
        const int c = channels * elempack;
392

393
        int outw;
394
        int outh;
395
        int outd;
396
        int outc;
397

398
        if (order_type == 1)
399
        {
400
            outw = h;
401
            outh = w;
402
            outd = d;
403
            outc = c;
404
        }
405
        else if (order_type == 2)
406
        {
407
            outw = w;
408
            outh = d;
409
            outd = h;
410
            outc = c;
411
        }
412
        else if (order_type == 3)
413
        {
414
            outw = d;
415
            outh = w;
416
            outd = h;
417
            outc = c;
418
        }
419
        else if (order_type == 4)
420
        {
421
            outw = h;
422
            outh = d;
423
            outd = w;
424
            outc = c;
425
        }
426
        else if (order_type == 5)
427
        {
428
            outw = d;
429
            outh = h;
430
            outd = w;
431
            outc = c;
432
        }
433
        else if (order_type == 6)
434
        {
435
            outw = w;
436
            outh = h;
437
            outd = c;
438
            outc = d;
439
        }
440
        else if (order_type == 7)
441
        {
442
            outw = h;
443
            outh = w;
444
            outd = c;
445
            outc = d;
446
        }
447
        else if (order_type == 8)
448
        {
449
            outw = w;
450
            outh = c;
451
            outd = h;
452
            outc = d;
453
        }
454
        else if (order_type == 9)
455
        {
456
            outw = c;
457
            outh = w;
458
            outd = h;
459
            outc = d;
460
        }
461
        else if (order_type == 10)
462
        {
463
            outw = h;
464
            outh = c;
465
            outd = w;
466
            outc = d;
467
        }
468
        else if (order_type == 11)
469
        {
470
            outw = c;
471
            outh = h;
472
            outd = w;
473
            outc = d;
474
        }
475
        else if (order_type == 12)
476
        {
477
            outw = w;
478
            outh = d;
479
            outd = c;
480
            outc = h;
481
        }
482
        else if (order_type == 13)
483
        {
484
            outw = d;
485
            outh = w;
486
            outd = c;
487
            outc = h;
488
        }
489
        else if (order_type == 14)
490
        {
491
            outw = w;
492
            outh = c;
493
            outd = d;
494
            outc = h;
495
        }
496
        else if (order_type == 15)
497
        {
498
            outw = c;
499
            outh = w;
500
            outd = d;
501
            outc = h;
502
        }
503
        else if (order_type == 16)
504
        {
505
            outw = d;
506
            outh = c;
507
            outd = w;
508
            outc = h;
509
        }
510
        else if (order_type == 17)
511
        {
512
            outw = c;
513
            outh = d;
514
            outd = w;
515
            outc = h;
516
        }
517
        else if (order_type == 18)
518
        {
519
            outw = h;
520
            outh = d;
521
            outd = c;
522
            outc = w;
523
        }
524
        else if (order_type == 19)
525
        {
526
            outw = d;
527
            outh = h;
528
            outd = c;
529
            outc = w;
530
        }
531
        else if (order_type == 20)
532
        {
533
            outw = h;
534
            outh = c;
535
            outd = d;
536
            outc = w;
537
        }
538
        else if (order_type == 21)
539
        {
540
            outw = c;
541
            outh = h;
542
            outd = d;
543
            outc = w;
544
        }
545
        else if (order_type == 22)
546
        {
547
            outw = d;
548
            outh = c;
549
            outd = h;
550
            outc = w;
551
        }
552
        else // if (order_type == 23)
553
        {
554
            outw = c;
555
            outh = d;
556
            outd = h;
557
            outc = w;
558
        }
559

560
        out_elempack = opt.use_shader_pack8 && outc % 8 == 0 ? 8 : outc % 4 == 0 ? 4 : 1;
561
        out_elemsize = elemsize / elempack * out_elempack;
562

563
        if (opt.use_fp16_packed && !opt.use_fp16_storage)
564
        {
565
            if (out_elempack == 8) out_elemsize = 8 * 2u;
566
            if (out_elempack == 4) out_elemsize = 4 * 2u;
567
            if (out_elempack == 1) out_elemsize = 4u;
568
        }
569

570
        top_blob.create(outw, outh, outd, outc / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
571
        if (top_blob.empty())
572
            return -100;
573
    }
574

575
    std::vector<VkMat> bindings(2);
576
    bindings[0] = bottom_blob;
577
    bindings[1] = top_blob;
578

579
    std::vector<vk_constant_type> constants(12);
580
    constants[0].i = bottom_blob.dims;
581
    constants[1].i = bottom_blob.w;
582
    constants[2].i = bottom_blob.h;
583
    constants[3].i = bottom_blob.d;
584
    constants[4].i = bottom_blob.c;
585
    constants[5].i = bottom_blob.cstep;
586
    constants[6].i = top_blob.dims;
587
    constants[7].i = top_blob.w;
588
    constants[8].i = top_blob.h;
589
    constants[9].i = top_blob.d;
590
    constants[10].i = top_blob.c;
591
    constants[11].i = top_blob.cstep;
592

593
    if (elempack == 1 && out_elempack == 1)
594
    {
595
        cmd.record_pipeline(pipeline_permute, bindings, constants, top_blob);
596
    }
597
    else if (elempack == 4 && out_elempack == 4)
598
    {
599
        cmd.record_pipeline(pipeline_permute_pack4, bindings, constants, top_blob);
600
    }
601
    else if (elempack == 1 && out_elempack == 4)
602
    {
603
        cmd.record_pipeline(pipeline_permute_pack1to4, bindings, constants, top_blob);
604
    }
605
    else if (elempack == 4 && out_elempack == 1)
606
    {
607
        cmd.record_pipeline(pipeline_permute_pack4to1, bindings, constants, bottom_blob);
608
    }
609
    else if (elempack == 8 && out_elempack == 8)
610
    {
611
        cmd.record_pipeline(pipeline_permute_pack8, bindings, constants, top_blob);
612
    }
613
    else if (elempack == 1 && out_elempack == 8)
614
    {
615
        cmd.record_pipeline(pipeline_permute_pack1to8, bindings, constants, top_blob);
616
    }
617
    else if (elempack == 4 && out_elempack == 8)
618
    {
619
        cmd.record_pipeline(pipeline_permute_pack4to8, bindings, constants, top_blob);
620
    }
621
    else if (elempack == 8 && out_elempack == 4)
622
    {
623
        cmd.record_pipeline(pipeline_permute_pack8to4, bindings, constants, top_blob);
624
    }
625
    else if (elempack == 8 && out_elempack == 1)
626
    {
627
        cmd.record_pipeline(pipeline_permute_pack8to1, bindings, constants, bottom_blob);
628
    }
629

630
    return 0;
631
}
632

633
int Permute_vulkan::forward(const VkImageMat& bottom_blob, VkImageMat& top_blob, VkCompute& cmd, const Option& opt) const
634
{
635
    int w = bottom_blob.w;
636
    int h = bottom_blob.h;
637
    int d = bottom_blob.d;
638
    int channels = bottom_blob.c;
639
    size_t elemsize = bottom_blob.elemsize;
640
    int elempack = bottom_blob.elempack;
641

642
    int dims = bottom_blob.dims;
643

644
    if (dims == 1 || order_type == 0)
645
    {
646
        top_blob = bottom_blob;
647
        return 0;
648
    }
649

650
    int out_elempack;
651
    size_t out_elemsize;
652

653
    if (dims == 2)
654
    {
655
        // order_type
656
        // 0 = w h
657
        // 1 = h w
658

659
        int outw;
660
        int outh;
661

662
        // if (order_type == 1)
663
        {
664
            outw = h * elempack;
665
            outh = w;
666
        }
667

668
        out_elempack = opt.use_shader_pack8 && outh % 8 == 0 ? 8 : outh % 4 == 0 ? 4 : 1;
669
        out_elemsize = elemsize / elempack * out_elempack;
670

671
        if (opt.use_fp16_packed && !opt.use_fp16_storage)
672
        {
673
            if (out_elempack == 8) out_elemsize = 8 * 2u;
674
            if (out_elempack == 4) out_elemsize = 4 * 2u;
675
            if (out_elempack == 1) out_elemsize = 4u;
676
        }
677

678
        top_blob.create(outw, outh / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
679
        if (top_blob.empty())
680
            return -100;
681
    }
682
    else if (dims == 3)
683
    {
684
        // order_type
685
        // 0 = w h c
686
        // 1 = h w c
687
        // 2 = w c h
688
        // 3 = c w h
689
        // 4 = h c w
690
        // 5 = c h w
691

692
        const int c = channels * elempack;
693

694
        int outw;
695
        int outh;
696
        int outc;
697

698
        if (order_type == 1)
699
        {
700
            outw = h;
701
            outh = w;
702
            outc = c;
703
        }
704
        else if (order_type == 2)
705
        {
706
            outw = w;
707
            outh = c;
708
            outc = h;
709
        }
710
        else if (order_type == 3)
711
        {
712
            outw = c;
713
            outh = w;
714
            outc = h;
715
        }
716
        else if (order_type == 4)
717
        {
718
            outw = h;
719
            outh = c;
720
            outc = w;
721
        }
722
        else // if (order_type == 5)
723
        {
724
            outw = c;
725
            outh = h;
726
            outc = w;
727
        }
728

729
        out_elempack = opt.use_shader_pack8 && outc % 8 == 0 ? 8 : outc % 4 == 0 ? 4 : 1;
730
        out_elemsize = elemsize / elempack * out_elempack;
731

732
        if (opt.use_fp16_packed && !opt.use_fp16_storage)
733
        {
734
            if (out_elempack == 8) out_elemsize = 8 * 2u;
735
            if (out_elempack == 4) out_elemsize = 4 * 2u;
736
            if (out_elempack == 1) out_elemsize = 4u;
737
        }
738

739
        top_blob.create(outw, outh, outc / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
740
        if (top_blob.empty())
741
            return -100;
742
    }
743
    else // if (dims == 4)
744
    {
745
        // order_type
746
        // 0 = w h d c
747
        // 1 = h w d c
748
        // 2 = w d h c
749
        // 3 = d w h c
750
        // 4 = h d w c
751
        // 5 = d h w c
752
        // 6 = w h c d
753
        // 7 = h w c d
754
        // 8 = w c h d
755
        // 9 = c w h d
756
        //10 = h c w d
757
        //11 = c h w d
758
        //12 = w d c h
759
        //13 = d w c h
760
        //14 = w c d h
761
        //15 = c w d h
762
        //16 = d c w h
763
        //17 = c d w h
764
        //18 = h d c w
765
        //19 = d h c w
766
        //20 = h c d w
767
        //21 = c h d w
768
        //22 = d c h w
769
        //23 = c d h w
770

771
        const int c = channels * elempack;
772

773
        int outw;
774
        int outh;
775
        int outd;
776
        int outc;
777

778
        if (order_type == 1)
779
        {
780
            outw = h;
781
            outh = w;
782
            outd = d;
783
            outc = c;
784
        }
785
        else if (order_type == 2)
786
        {
787
            outw = w;
788
            outh = d;
789
            outd = h;
790
            outc = c;
791
        }
792
        else if (order_type == 3)
793
        {
794
            outw = d;
795
            outh = w;
796
            outd = h;
797
            outc = c;
798
        }
799
        else if (order_type == 4)
800
        {
801
            outw = h;
802
            outh = d;
803
            outd = w;
804
            outc = c;
805
        }
806
        else if (order_type == 5)
807
        {
808
            outw = d;
809
            outh = h;
810
            outd = w;
811
            outc = c;
812
        }
813
        else if (order_type == 6)
814
        {
815
            outw = w;
816
            outh = h;
817
            outd = c;
818
            outc = d;
819
        }
820
        else if (order_type == 7)
821
        {
822
            outw = h;
823
            outh = w;
824
            outd = c;
825
            outc = d;
826
        }
827
        else if (order_type == 8)
828
        {
829
            outw = w;
830
            outh = c;
831
            outd = h;
832
            outc = d;
833
        }
834
        else if (order_type == 9)
835
        {
836
            outw = c;
837
            outh = w;
838
            outd = h;
839
            outc = d;
840
        }
841
        else if (order_type == 10)
842
        {
843
            outw = h;
844
            outh = c;
845
            outd = w;
846
            outc = d;
847
        }
848
        else if (order_type == 11)
849
        {
850
            outw = c;
851
            outh = h;
852
            outd = w;
853
            outc = d;
854
        }
855
        else if (order_type == 12)
856
        {
857
            outw = w;
858
            outh = d;
859
            outd = c;
860
            outc = h;
861
        }
862
        else if (order_type == 13)
863
        {
864
            outw = d;
865
            outh = w;
866
            outd = c;
867
            outc = h;
868
        }
869
        else if (order_type == 14)
870
        {
871
            outw = w;
872
            outh = c;
873
            outd = d;
874
            outc = h;
875
        }
876
        else if (order_type == 15)
877
        {
878
            outw = c;
879
            outh = w;
880
            outd = d;
881
            outc = h;
882
        }
883
        else if (order_type == 16)
884
        {
885
            outw = d;
886
            outh = c;
887
            outd = w;
888
            outc = h;
889
        }
890
        else if (order_type == 17)
891
        {
892
            outw = c;
893
            outh = d;
894
            outd = w;
895
            outc = h;
896
        }
897
        else if (order_type == 18)
898
        {
899
            outw = h;
900
            outh = d;
901
            outd = c;
902
            outc = w;
903
        }
904
        else if (order_type == 19)
905
        {
906
            outw = d;
907
            outh = h;
908
            outd = c;
909
            outc = w;
910
        }
911
        else if (order_type == 20)
912
        {
913
            outw = h;
914
            outh = c;
915
            outd = d;
916
            outc = w;
917
        }
918
        else if (order_type == 21)
919
        {
920
            outw = c;
921
            outh = h;
922
            outd = d;
923
            outc = w;
924
        }
925
        else if (order_type == 22)
926
        {
927
            outw = d;
928
            outh = c;
929
            outd = h;
930
            outc = w;
931
        }
932
        else // if (order_type == 23)
933
        {
934
            outw = c;
935
            outh = d;
936
            outd = h;
937
            outc = w;
938
        }
939

940
        out_elempack = opt.use_shader_pack8 && outc % 8 == 0 ? 8 : outc % 4 == 0 ? 4 : 1;
941
        out_elemsize = elemsize / elempack * out_elempack;
942

943
        if (opt.use_fp16_packed && !opt.use_fp16_storage)
944
        {
945
            if (out_elempack == 8) out_elemsize = 8 * 2u;
946
            if (out_elempack == 4) out_elemsize = 4 * 2u;
947
            if (out_elempack == 1) out_elemsize = 4u;
948
        }
949

950
        top_blob.create(outw, outh, outd, outc / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
951
        if (top_blob.empty())
952
            return -100;
953
    }
954

955
    std::vector<VkImageMat> bindings(2);
956
    bindings[0] = bottom_blob;
957
    bindings[1] = top_blob;
958

959
    std::vector<vk_constant_type> constants(12);
960
    constants[0].i = bottom_blob.dims;
961
    constants[1].i = bottom_blob.w;
962
    constants[2].i = bottom_blob.h;
963
    constants[3].i = bottom_blob.d;
964
    constants[4].i = bottom_blob.c;
965
    constants[5].i = 0; //bottom_blob.cstep;
966
    constants[6].i = top_blob.dims;
967
    constants[7].i = top_blob.w;
968
    constants[8].i = top_blob.h;
969
    constants[9].i = top_blob.d;
970
    constants[10].i = top_blob.c;
971
    constants[11].i = 0; //top_blob.cstep;
972

973
    if (elempack == 1 && out_elempack == 1)
974
    {
975
        cmd.record_pipeline(pipeline_permute, bindings, constants, top_blob);
976
    }
977
    else if (elempack == 4 && out_elempack == 4)
978
    {
979
        cmd.record_pipeline(pipeline_permute_pack4, bindings, constants, top_blob);
980
    }
981
    else if (elempack == 1 && out_elempack == 4)
982
    {
983
        cmd.record_pipeline(pipeline_permute_pack1to4, bindings, constants, top_blob);
984
    }
985
    else if (elempack == 4 && out_elempack == 1)
986
    {
987
        cmd.record_pipeline(pipeline_permute_pack4to1, bindings, constants, bottom_blob);
988
    }
989
    else if (elempack == 8 && out_elempack == 8)
990
    {
991
        cmd.record_pipeline(pipeline_permute_pack8, bindings, constants, top_blob);
992
    }
993
    else if (elempack == 1 && out_elempack == 8)
994
    {
995
        cmd.record_pipeline(pipeline_permute_pack1to8, bindings, constants, top_blob);
996
    }
997
    else if (elempack == 4 && out_elempack == 8)
998
    {
999
        cmd.record_pipeline(pipeline_permute_pack4to8, bindings, constants, top_blob);
1000
    }
1001
    else if (elempack == 8 && out_elempack == 4)
1002
    {
1003
        cmd.record_pipeline(pipeline_permute_pack8to4, bindings, constants, top_blob);
1004
    }
1005
    else if (elempack == 8 && out_elempack == 1)
1006
    {
1007
        cmd.record_pipeline(pipeline_permute_pack8to1, bindings, constants, bottom_blob);
1008
    }
1009

1010
    return 0;
1011
}
1012

1013
} // namespace ncnn
1014

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.