ncnn

Форк
0
/
multiheadattention_vulkan.cpp 
762 строки · 25.1 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "multiheadattention_vulkan.h"
16

17
#include "layer_shader_type.h"
18
#include "layer_type.h"
19

20
namespace ncnn {
21

22
MultiHeadAttention_vulkan::MultiHeadAttention_vulkan()
{
    // The layer has a Vulkan implementation for both buffer (VkMat) and image (VkImageMat) storage.
    support_vulkan = true;
    support_image_storage = true;

    // Every sub-layer and compute pipeline is created later in create_pipeline();
    // until then all handles are null, which destroy_pipeline() tolerates.
    q_gemm = k_gemm = v_gemm = o_gemm = 0;
    qk_softmax = 0;

    pipeline_multiheadattention_qk_cross_pack4to1 = pipeline_multiheadattention_qk_cross_pack1to4 = pipeline_multiheadattention_qk_cross_pack4 = pipeline_multiheadattention_qk_cross = 0;

    pipeline_multiheadattention_qkv_cross_pack4to1 = pipeline_multiheadattention_qkv_cross_pack1to4 = pipeline_multiheadattention_qkv_cross_pack4 = pipeline_multiheadattention_qkv_cross = 0;
}
45

46
int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
47
{
48
    const int embed_dim_per_head = embed_dim / num_heads;
49
    const int qdim = weight_data_size / embed_dim;
50
    {
51
        q_gemm = ncnn::create_layer_vulkan(ncnn::LayerType::Gemm);
52
        q_gemm->vkdev = vkdev;
53
        ncnn::ParamDict pd;
54
        pd.set(0, scale);
55
        pd.set(1, 1.f);
56
        pd.set(2, 0);         // transA
57
        pd.set(3, 1);         // transB
58
        pd.set(4, 1);         // constantA
59
        pd.set(5, 0);         // constantB
60
        pd.set(6, 1);         // constantC
61
        pd.set(7, embed_dim); // M
62
        pd.set(8, 0);         // N
63
        pd.set(9, qdim);      // K
64
        pd.set(10, 1);        // constant_broadcast_type_C
65
        pd.set(11, 0);        // output_N1M
66
        // pd.set(12, 1);        // output_elempack
67
        pd.set(14, 0); // output_transpose
68
        q_gemm->load_param(pd);
69
        Mat weights[2];
70
        weights[0] = q_weight_data;
71
        weights[1] = q_bias_data;
72
        q_gemm->load_model(ModelBinFromMatArray(weights));
73
        q_gemm->create_pipeline(opt);
74

75
        if (opt.lightmode)
76
        {
77
            q_weight_data.release();
78
            q_bias_data.release();
79
        }
80
    }
81

82
    {
83
        k_gemm = ncnn::create_layer_vulkan(ncnn::LayerType::Gemm);
84
        k_gemm->vkdev = vkdev;
85
        ncnn::ParamDict pd;
86
        pd.set(2, 0);         // transA
87
        pd.set(3, 1);         // transB
88
        pd.set(4, 1);         // constantA
89
        pd.set(5, 0);         // constantB
90
        pd.set(6, 1);         // constantC
91
        pd.set(7, embed_dim); // M
92
        pd.set(8, 0);         // N
93
        pd.set(9, kdim);      // K
94
        pd.set(10, 1);        // constant_broadcast_type_C
95
        pd.set(11, 0);        // output_N1M
96
        // pd.set(12, 1);        // output_elempack
97
        pd.set(14, 0); // output_transpose
98
        k_gemm->load_param(pd);
99
        Mat weights[2];
100
        weights[0] = k_weight_data;
101
        weights[1] = k_bias_data;
102
        k_gemm->load_model(ModelBinFromMatArray(weights));
103
        k_gemm->create_pipeline(opt);
104

105
        if (opt.lightmode)
106
        {
107
            k_weight_data.release();
108
            k_bias_data.release();
109
        }
110
    }
111

112
    {
113
        v_gemm = ncnn::create_layer_vulkan(ncnn::LayerType::Gemm);
114
        v_gemm->vkdev = vkdev;
115
        ncnn::ParamDict pd;
116
        pd.set(2, 0);         // transA
117
        pd.set(3, 1);         // transB
118
        pd.set(4, 1);         // constantA
119
        pd.set(5, 0);         // constantB
120
        pd.set(6, 1);         // constantC
121
        pd.set(7, embed_dim); // M
122
        pd.set(8, 0);         // N
123
        pd.set(9, vdim);      // K
124
        pd.set(10, 1);        // constant_broadcast_type_C
125
        pd.set(11, 0);        // output_N1M
126
        // pd.set(12, 1);        // output_elempack
127
        pd.set(14, 0); // output_transpose
128
        v_gemm->load_param(pd);
129
        Mat weights[2];
130
        weights[0] = v_weight_data;
131
        weights[1] = v_bias_data;
132
        v_gemm->load_model(ModelBinFromMatArray(weights));
133
        v_gemm->create_pipeline(opt);
134

135
        if (opt.lightmode)
136
        {
137
            v_weight_data.release();
138
            v_bias_data.release();
139
        }
140
    }
141

142
    {
143
        std::vector<vk_specialization_type> specializations(6);
144
        specializations[0].i = attn_mask;
145
        specializations[1].i = 0; //constantM;
146
        specializations[2].i = 0; //constantN;
147
        specializations[3].i = 0; //embed_dim_per_head;//constantK;
148
        specializations[4].i = num_heads;
149
        specializations[5].i = 0; //attn_mask.dims;
150

151
        {
152
            pipeline_multiheadattention_qk_cross = new Pipeline(vkdev);
153
            pipeline_multiheadattention_qk_cross->set_local_size_xyz(8, 8, 1);
154
            pipeline_multiheadattention_qk_cross->create(LayerShaderType::multiheadattention_qk_cross, opt, specializations);
155
        }
156
        {
157
            pipeline_multiheadattention_qk_cross_pack4 = new Pipeline(vkdev);
158
            pipeline_multiheadattention_qk_cross_pack4->set_local_size_xyz(8, 8, 1);
159
            pipeline_multiheadattention_qk_cross_pack4->create(LayerShaderType::multiheadattention_qk_cross_pack4, opt, specializations);
160
        }
161
        {
162
            pipeline_multiheadattention_qk_cross_pack1to4 = new Pipeline(vkdev);
163
            pipeline_multiheadattention_qk_cross_pack1to4->set_local_size_xyz(8, 8, 1);
164
            pipeline_multiheadattention_qk_cross_pack1to4->create(LayerShaderType::multiheadattention_qk_cross_pack1to4, opt, specializations);
165
        }
166
        {
167
            pipeline_multiheadattention_qk_cross_pack4to1 = new Pipeline(vkdev);
168
            pipeline_multiheadattention_qk_cross_pack4to1->set_local_size_xyz(8, 8, 1);
169
            pipeline_multiheadattention_qk_cross_pack4to1->create(LayerShaderType::multiheadattention_qk_cross_pack4to1, opt, specializations);
170
        }
171
    }
172
    {
173
        std::vector<vk_specialization_type> specializations(4);
174
        specializations[0].i = 0; //constantM;
175
        specializations[1].i = 0; //embed_dim_per_head;//constantN;
176
        specializations[2].i = 0; //constantK;
177
        specializations[3].i = num_heads;
178

179
        {
180
            pipeline_multiheadattention_qkv_cross = new Pipeline(vkdev);
181
            pipeline_multiheadattention_qkv_cross->set_local_size_xyz(8, 8, 1);
182
            pipeline_multiheadattention_qkv_cross->create(LayerShaderType::multiheadattention_qkv_cross, opt, specializations);
183
        }
184
        {
185
            pipeline_multiheadattention_qkv_cross_pack4 = new Pipeline(vkdev);
186
            pipeline_multiheadattention_qkv_cross_pack4->set_local_size_xyz(8, 8, 1);
187
            pipeline_multiheadattention_qkv_cross_pack4->create(LayerShaderType::multiheadattention_qkv_cross_pack4, opt, specializations);
188
        }
189
        {
190
            pipeline_multiheadattention_qkv_cross_pack1to4 = new Pipeline(vkdev);
191
            pipeline_multiheadattention_qkv_cross_pack1to4->set_local_size_xyz(8, 8, 1);
192
            pipeline_multiheadattention_qkv_cross_pack1to4->create(LayerShaderType::multiheadattention_qkv_cross_pack1to4, opt, specializations);
193
        }
194
        {
195
            pipeline_multiheadattention_qkv_cross_pack4to1 = new Pipeline(vkdev);
196
            pipeline_multiheadattention_qkv_cross_pack4to1->set_local_size_xyz(8, 8, 1);
197
            pipeline_multiheadattention_qkv_cross_pack4to1->create(LayerShaderType::multiheadattention_qkv_cross_pack4to1, opt, specializations);
198
        }
199
    }
200

201
    {
202
        qk_softmax = ncnn::create_layer_vulkan(ncnn::LayerType::Softmax);
203
        qk_softmax->vkdev = vkdev;
204
        ncnn::ParamDict pd;
205
        pd.set(0, -1);
206
        pd.set(1, 1);
207
        qk_softmax->load_param(pd);
208
        qk_softmax->load_model(ModelBinFromMatArray(0));
209
        qk_softmax->create_pipeline(opt);
210
    }
211

212
    {
213
        o_gemm = ncnn::create_layer_vulkan(ncnn::LayerType::Gemm);
214
        o_gemm->vkdev = vkdev;
215
        ncnn::ParamDict pd;
216
        pd.set(2, 1);         // transA
217
        pd.set(3, 1);         // transB
218
        pd.set(4, 0);         // constantA
219
        pd.set(5, 1);         // constantB
220
        pd.set(6, 1);         // constantC
221
        pd.set(7, 0);         // M = outch
222
        pd.set(8, qdim);      // N = size
223
        pd.set(9, embed_dim); // K = maxk*inch
224
        pd.set(10, 4);        // constant_broadcast_type_C
225
        pd.set(11, 0);        // output_N1M
226
        o_gemm->load_param(pd);
227
        Mat weights[2];
228
        weights[0] = out_weight_data;
229
        weights[1] = out_bias_data;
230
        o_gemm->load_model(ModelBinFromMatArray(weights));
231
        o_gemm->create_pipeline(opt);
232

233
        if (opt.lightmode)
234
        {
235
            out_weight_data.release();
236
            out_bias_data.release();
237
        }
238
    }
239

240
    return 0;
241
}
242

243
int MultiHeadAttention_vulkan::destroy_pipeline(const Option& opt)
244
{
245
    if (q_gemm)
246
    {
247
        q_gemm->destroy_pipeline(opt);
248
        delete q_gemm;
249
        q_gemm = 0;
250
    }
251

252
    if (k_gemm)
253
    {
254
        k_gemm->destroy_pipeline(opt);
255
        delete k_gemm;
256
        k_gemm = 0;
257
    }
258

259
    if (v_gemm)
260
    {
261
        v_gemm->destroy_pipeline(opt);
262
        delete v_gemm;
263
        v_gemm = 0;
264
    }
265

266
    delete pipeline_multiheadattention_qk_cross;
267
    pipeline_multiheadattention_qk_cross = 0;
268

269
    delete pipeline_multiheadattention_qk_cross_pack4;
270
    pipeline_multiheadattention_qk_cross_pack4 = 0;
271

272
    delete pipeline_multiheadattention_qk_cross_pack1to4;
273
    pipeline_multiheadattention_qk_cross_pack1to4 = 0;
274

275
    delete pipeline_multiheadattention_qk_cross_pack4to1;
276
    pipeline_multiheadattention_qk_cross_pack4to1 = 0;
277

278
    delete pipeline_multiheadattention_qkv_cross;
279
    pipeline_multiheadattention_qkv_cross = 0;
280

281
    delete pipeline_multiheadattention_qkv_cross_pack4;
282
    pipeline_multiheadattention_qkv_cross_pack4 = 0;
283

284
    delete pipeline_multiheadattention_qkv_cross_pack1to4;
285
    pipeline_multiheadattention_qkv_cross_pack1to4 = 0;
286

287
    delete pipeline_multiheadattention_qkv_cross_pack4to1;
288
    pipeline_multiheadattention_qkv_cross_pack4to1 = 0;
289

290
    if (qk_softmax)
291
    {
292
        qk_softmax->destroy_pipeline(opt);
293
        delete qk_softmax;
294
        qk_softmax = 0;
295
    }
296

297
    if (o_gemm)
298
    {
299
        o_gemm->destroy_pipeline(opt);
300
        delete o_gemm;
301
        o_gemm = 0;
302
    }
303

304
    return 0;
305
}
306

307
// Forward the weight upload to each Gemm sub-layer that exists (q, k, v, then output).
// These are the sub-layers that create_pipeline() loads with weight data.
// Sub-layer return codes are not inspected; this always returns 0.
int MultiHeadAttention_vulkan::upload_model(VkTransfer& cmd, const Option& opt)
{
    Layer* const gemm_layers[4] = {q_gemm, k_gemm, v_gemm, o_gemm};

    for (int i = 0; i < 4; i++)
    {
        Layer* layer = gemm_layers[i];
        if (layer)
            layer->upload_model(cmd, opt);
    }

    return 0;
}
331

332
int MultiHeadAttention_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const
333
{
334
    const VkMat& q_blob = bottom_blobs[0];
335
    const VkMat& k_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : bottom_blobs[1];
336
    const VkMat& v_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : (bottom_blobs.size() == 2 || (bottom_blobs.size() == 3 && attn_mask)) ? k_blob : bottom_blobs[2];
337
    VkMat attn_mask_blob = attn_mask ? bottom_blobs[bottom_blobs.size() - 1] : VkMat();
338

339
    const int embed_dim_per_head = embed_dim / num_heads;
340
    const int src_seqlen = q_blob.h * q_blob.elempack;
341
    const int dst_seqlen = k_blob.h * k_blob.elempack;
342

343
    VkMat q_affine;
344
    q_gemm->forward(q_blob, q_affine, cmd, opt);
345

346
    VkMat k_affine;
347
    k_gemm->forward(k_blob, k_affine, cmd, opt);
348

349
    VkMat qk_cross;
350
    {
351
        int M = q_affine.w;
352
        int N = k_affine.w;
353
        int K = q_affine.h * q_affine.elempack / num_heads;
354
        int B = num_heads;
355

356
        // int K_elempack = opt.use_shader_pack8 && K % 8 == 0 ? 8 : K % 4 == 0 ? 4 : 1;
357
        // int M_elempack = opt.use_shader_pack8 && M % 8 == 0 ? 8 : M % 4 == 0 ? 4 : 1;
358
        // int MB_elempack = opt.use_shader_pack8 && (M * B) % 8 == 0 ? 8 : (M * B) % 4 == 0 ? 4 : 1;
359
        int K_elempack = K % 4 == 0 ? 4 : 1;
360
        int M_elempack = M % 4 == 0 ? 4 : 1;
361
        int MB_elempack = (M * B) % 4 == 0 ? 4 : 1;
362
        size_t M_elemsize = q_affine.elemsize / q_affine.elempack * M_elempack;
363

364
        if (opt.use_fp16_packed && !opt.use_fp16_storage)
365
        {
366
            if (M_elempack == 8) M_elemsize = 8 * 2u;
367
            if (M_elempack == 4) M_elemsize = 4 * 2u;
368
            if (M_elempack == 1) M_elemsize = 4u;
369
        }
370

371
        if (K_elempack < q_affine.elempack)
372
        {
373
            VkMat tmp;
374
            vkdev->convert_packing(q_affine, tmp, K_elempack, cmd, opt);
375
            q_affine = tmp;
376
        }
377
        if (K_elempack < k_affine.elempack)
378
        {
379
            VkMat tmp;
380
            vkdev->convert_packing(k_affine, tmp, K_elempack, cmd, opt);
381
            k_affine = tmp;
382
        }
383
        if (M_elempack < attn_mask_blob.elempack)
384
        {
385
            VkMat tmp;
386
            vkdev->convert_packing(attn_mask_blob, tmp, M_elempack, cmd, opt);
387
            attn_mask_blob = tmp;
388
        }
389

390
        qk_cross.create(N, M / M_elempack * B, M_elemsize, M_elempack, opt.blob_vkallocator);
391
        if (qk_cross.empty())
392
            return -100;
393

394
        std::vector<VkMat> bindings(4);
395
        bindings[0] = q_affine;
396
        bindings[1] = k_affine;
397
        bindings[2] = qk_cross;
398
        bindings[3] = attn_mask_blob;
399

400
        std::vector<vk_constant_type> constants(5);
401
        constants[0].i = M / M_elempack;
402
        constants[1].i = N;
403
        constants[2].i = K / K_elempack;
404
        constants[3].i = B;
405
        constants[4].i = attn_mask_blob.dims;
406

407
        VkMat dispatcher;
408
        dispatcher.w = N;
409
        dispatcher.h = M / M_elempack;
410
        dispatcher.c = B;
411

412
        const Pipeline* pipeline = 0;
413
        if (K_elempack == 1 && M_elempack == 1)
414
        {
415
            pipeline = pipeline_multiheadattention_qk_cross;
416
        }
417
        if (K_elempack == 1 && M_elempack == 4)
418
        {
419
            pipeline = pipeline_multiheadattention_qk_cross_pack1to4;
420
        }
421
        if (K_elempack == 4 && M_elempack == 1)
422
        {
423
            pipeline = pipeline_multiheadattention_qk_cross_pack4to1;
424
        }
425
        if (K_elempack == 4 && M_elempack == 4)
426
        {
427
            pipeline = pipeline_multiheadattention_qk_cross_pack4;
428
        }
429

430
        cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
431

432
        if (MB_elempack > M_elempack)
433
        {
434
            VkMat tmp;
435
            vkdev->convert_packing(qk_cross, tmp, MB_elempack, cmd, opt);
436
            qk_cross = tmp;
437
        }
438
    }
439

440
    q_affine.release();
441
    k_affine.release();
442

443
    qk_softmax->forward_inplace(qk_cross, cmd, opt);
444

445
    if (vkdev->info.vendor_id() == 0x10de)
446
    {
447
        // FIXME softmax produces nan result on nvidia (about 20% chance)
448
        // memory barrier seems to be not enough here
449
        // device copy-to and copy-back is better than queue submit anyway  --- nihui
450

451
        // cmd.submit_and_wait();
452
        // cmd.reset();
453

454
        VkImageMat qk_cross2;
455
        cmd.record_buffer_to_image(qk_cross, qk_cross2, opt);
456
        cmd.record_image_to_buffer(qk_cross2, qk_cross, opt);
457
    }
458

459
    VkMat v_affine;
460
    v_gemm->forward(v_blob, v_affine, cmd, opt);
461

462
    VkMat qkv_cross;
463
    {
464
        int M = qk_cross.h * qk_cross.elempack / num_heads;
465
        int N = v_affine.h * v_affine.elempack / num_heads;
466
        int K = v_affine.w;
467
        int B = num_heads;
468

469
        // int M_elempack = opt.use_shader_pack8 && M % 8 == 0 ? 8 : M % 4 == 0 ? 4 : 1;
470
        // int N_elempack = opt.use_shader_pack8 && N % 8 == 0 ? 8 : N % 4 == 0 ? 4 : 1;
471
        // int NB_elempack = opt.use_shader_pack8 && (N * B) % 8 == 0 ? 8 : (N * B) % 4 == 0 ? 4 : 1;
472
        int M_elempack = M % 4 == 0 ? 4 : 1;
473
        int N_elempack = N % 4 == 0 ? 4 : 1;
474
        int NB_elempack = (N * B) % 4 == 0 ? 4 : 1;
475
        size_t N_elemsize = v_affine.elemsize / v_affine.elempack * N_elempack;
476

477
        if (opt.use_fp16_packed && !opt.use_fp16_storage)
478
        {
479
            if (N_elempack == 8) N_elemsize = 8 * 2u;
480
            if (N_elempack == 4) N_elemsize = 4 * 2u;
481
            if (N_elempack == 1) N_elemsize = 4u;
482
        }
483

484
        if (M_elempack < qk_cross.elempack)
485
        {
486
            VkMat tmp;
487
            vkdev->convert_packing(qk_cross, tmp, M_elempack, cmd, opt);
488
            qk_cross = tmp;
489
        }
490

491
        if (N_elempack < v_affine.elempack)
492
        {
493
            VkMat tmp;
494
            vkdev->convert_packing(v_affine, tmp, N_elempack, cmd, opt);
495
            v_affine = tmp;
496
        }
497

498
        qkv_cross.create(M, N / N_elempack * B, N_elemsize, N_elempack, opt.blob_vkallocator);
499
        if (qkv_cross.empty())
500
            return -100;
501

502
        std::vector<VkMat> bindings(3);
503
        bindings[0] = qk_cross;
504
        bindings[1] = v_affine;
505
        bindings[2] = qkv_cross;
506

507
        std::vector<vk_constant_type> constants(4);
508
        constants[0].i = M / M_elempack;
509
        constants[1].i = N / N_elempack;
510
        constants[2].i = K;
511
        constants[3].i = B;
512

513
        VkMat dispatcher;
514
        dispatcher.w = N / N_elempack;
515
        dispatcher.h = M / M_elempack;
516
        dispatcher.c = B;
517

518
        const Pipeline* pipeline = 0;
519
        if (M_elempack == 1 && N_elempack == 1)
520
        {
521
            pipeline = pipeline_multiheadattention_qkv_cross;
522
        }
523
        if (M_elempack == 1 && N_elempack == 4)
524
        {
525
            pipeline = pipeline_multiheadattention_qkv_cross_pack1to4;
526
        }
527
        if (M_elempack == 4 && N_elempack == 1)
528
        {
529
            pipeline = pipeline_multiheadattention_qkv_cross_pack4to1;
530
        }
531
        if (M_elempack == 4 && N_elempack == 4)
532
        {
533
            pipeline = pipeline_multiheadattention_qkv_cross_pack4;
534
        }
535

536
        cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
537

538
        if (NB_elempack > N_elempack)
539
        {
540
            VkMat tmp;
541
            vkdev->convert_packing(qkv_cross, tmp, NB_elempack, cmd, opt);
542
            qkv_cross = tmp;
543
        }
544
    }
545

546
    qk_cross.release();
547
    v_affine.release();
548

549
    o_gemm->forward(qkv_cross, top_blobs[0], cmd, opt);
550

551
    return 0;
552
}
553

554
int MultiHeadAttention_vulkan::forward(const std::vector<VkImageMat>& bottom_blobs, std::vector<VkImageMat>& top_blobs, VkCompute& cmd, const Option& opt) const
555
{
556
    const VkImageMat& q_blob = bottom_blobs[0];
557
    const VkImageMat& k_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : bottom_blobs[1];
558
    const VkImageMat& v_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : (bottom_blobs.size() == 2 || (bottom_blobs.size() == 3 && attn_mask)) ? k_blob : bottom_blobs[2];
559
    VkImageMat attn_mask_blob = attn_mask ? bottom_blobs[bottom_blobs.size() - 1] : VkImageMat();
560

561
    const int embed_dim_per_head = embed_dim / num_heads;
562
    const int src_seqlen = q_blob.h * q_blob.elempack;
563
    const int dst_seqlen = k_blob.h * k_blob.elempack;
564

565
    VkImageMat q_affine;
566
    q_gemm->forward(q_blob, q_affine, cmd, opt);
567

568
    VkImageMat k_affine;
569
    k_gemm->forward(k_blob, k_affine, cmd, opt);
570

571
    VkImageMat qk_cross;
572
    {
573
        int M = q_affine.w;
574
        int N = k_affine.w;
575
        int K = q_affine.h * q_affine.elempack / num_heads;
576
        int B = num_heads;
577

578
        // int K_elempack = opt.use_shader_pack8 && K % 8 == 0 ? 8 : K % 4 == 0 ? 4 : 1;
579
        // int M_elempack = opt.use_shader_pack8 && M % 8 == 0 ? 8 : M % 4 == 0 ? 4 : 1;
580
        // int MB_elempack = opt.use_shader_pack8 && (M * B) % 8 == 0 ? 8 : (M * B) % 4 == 0 ? 4 : 1;
581
        int K_elempack = K % 4 == 0 ? 4 : 1;
582
        int M_elempack = M % 4 == 0 ? 4 : 1;
583
        int MB_elempack = (M * B) % 4 == 0 ? 4 : 1;
584
        size_t M_elemsize = q_affine.elemsize / q_affine.elempack * M_elempack;
585

586
        if (opt.use_fp16_packed && !opt.use_fp16_storage)
587
        {
588
            if (M_elempack == 8) M_elemsize = 8 * 2u;
589
            if (M_elempack == 4) M_elemsize = 4 * 2u;
590
            if (M_elempack == 1) M_elemsize = 4u;
591
        }
592

593
        if (K_elempack < q_affine.elempack)
594
        {
595
            VkImageMat tmp;
596
            vkdev->convert_packing(q_affine, tmp, K_elempack, cmd, opt);
597
            q_affine = tmp;
598
        }
599
        if (K_elempack < k_affine.elempack)
600
        {
601
            VkImageMat tmp;
602
            vkdev->convert_packing(k_affine, tmp, K_elempack, cmd, opt);
603
            k_affine = tmp;
604
        }
605
        if (M_elempack < attn_mask_blob.elempack)
606
        {
607
            VkImageMat tmp;
608
            vkdev->convert_packing(attn_mask_blob, tmp, M_elempack, cmd, opt);
609
            attn_mask_blob = tmp;
610
        }
611

612
        qk_cross.create(N, M / M_elempack * B, M_elemsize, M_elempack, opt.blob_vkallocator);
613
        if (qk_cross.empty())
614
            return -100;
615

616
        std::vector<VkImageMat> bindings(4);
617
        bindings[0] = q_affine;
618
        bindings[1] = k_affine;
619
        bindings[2] = qk_cross;
620
        bindings[3] = attn_mask_blob;
621

622
        std::vector<vk_constant_type> constants(5);
623
        constants[0].i = M / M_elempack;
624
        constants[1].i = N;
625
        constants[2].i = K / K_elempack;
626
        constants[3].i = B;
627
        constants[4].i = attn_mask_blob.dims;
628

629
        VkImageMat dispatcher;
630
        dispatcher.w = N;
631
        dispatcher.h = M / M_elempack;
632
        dispatcher.c = B;
633

634
        const Pipeline* pipeline = 0;
635
        if (K_elempack == 1 && M_elempack == 1)
636
        {
637
            pipeline = pipeline_multiheadattention_qk_cross;
638
        }
639
        if (K_elempack == 1 && M_elempack == 4)
640
        {
641
            pipeline = pipeline_multiheadattention_qk_cross_pack1to4;
642
        }
643
        if (K_elempack == 4 && M_elempack == 1)
644
        {
645
            pipeline = pipeline_multiheadattention_qk_cross_pack4to1;
646
        }
647
        if (K_elempack == 4 && M_elempack == 4)
648
        {
649
            pipeline = pipeline_multiheadattention_qk_cross_pack4;
650
        }
651

652
        cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
653

654
        if (MB_elempack > M_elempack)
655
        {
656
            VkImageMat tmp;
657
            vkdev->convert_packing(qk_cross, tmp, MB_elempack, cmd, opt);
658
            qk_cross = tmp;
659
        }
660
    }
661

662
    q_affine.release();
663
    k_affine.release();
664

665
    qk_softmax->forward_inplace(qk_cross, cmd, opt);
666

667
    VkImageMat v_affine;
668
    v_gemm->forward(v_blob, v_affine, cmd, opt);
669

670
    VkImageMat qkv_cross;
671
    {
672
        int M = qk_cross.h * qk_cross.elempack / num_heads;
673
        int N = v_affine.h * v_affine.elempack / num_heads;
674
        int K = v_affine.w;
675
        int B = num_heads;
676

677
        // int M_elempack = opt.use_shader_pack8 && M % 8 == 0 ? 8 : M % 4 == 0 ? 4 : 1;
678
        // int N_elempack = opt.use_shader_pack8 && N % 8 == 0 ? 8 : N % 4 == 0 ? 4 : 1;
679
        // int NB_elempack = opt.use_shader_pack8 && (N * B) % 8 == 0 ? 8 : (N * B) % 4 == 0 ? 4 : 1;
680
        int M_elempack = M % 4 == 0 ? 4 : 1;
681
        int N_elempack = N % 4 == 0 ? 4 : 1;
682
        int NB_elempack = (N * B) % 4 == 0 ? 4 : 1;
683
        size_t N_elemsize = v_affine.elemsize / v_affine.elempack * N_elempack;
684

685
        if (opt.use_fp16_packed && !opt.use_fp16_storage)
686
        {
687
            if (N_elempack == 8) N_elemsize = 8 * 2u;
688
            if (N_elempack == 4) N_elemsize = 4 * 2u;
689
            if (N_elempack == 1) N_elemsize = 4u;
690
        }
691

692
        if (M_elempack < qk_cross.elempack)
693
        {
694
            VkImageMat tmp;
695
            vkdev->convert_packing(qk_cross, tmp, M_elempack, cmd, opt);
696
            qk_cross = tmp;
697
        }
698

699
        if (N_elempack < v_affine.elempack)
700
        {
701
            VkImageMat tmp;
702
            vkdev->convert_packing(v_affine, tmp, N_elempack, cmd, opt);
703
            v_affine = tmp;
704
        }
705

706
        qkv_cross.create(M, N / N_elempack * B, N_elemsize, N_elempack, opt.blob_vkallocator);
707
        if (qkv_cross.empty())
708
            return -100;
709

710
        std::vector<VkImageMat> bindings(3);
711
        bindings[0] = qk_cross;
712
        bindings[1] = v_affine;
713
        bindings[2] = qkv_cross;
714

715
        std::vector<vk_constant_type> constants(4);
716
        constants[0].i = M / M_elempack;
717
        constants[1].i = N / N_elempack;
718
        constants[2].i = K;
719
        constants[3].i = B;
720

721
        VkImageMat dispatcher;
722
        dispatcher.w = N / N_elempack;
723
        dispatcher.h = M / M_elempack;
724
        dispatcher.c = B;
725

726
        const Pipeline* pipeline = 0;
727
        if (M_elempack == 1 && N_elempack == 1)
728
        {
729
            pipeline = pipeline_multiheadattention_qkv_cross;
730
        }
731
        if (M_elempack == 1 && N_elempack == 4)
732
        {
733
            pipeline = pipeline_multiheadattention_qkv_cross_pack1to4;
734
        }
735
        if (M_elempack == 4 && N_elempack == 1)
736
        {
737
            pipeline = pipeline_multiheadattention_qkv_cross_pack4to1;
738
        }
739
        if (M_elempack == 4 && N_elempack == 4)
740
        {
741
            pipeline = pipeline_multiheadattention_qkv_cross_pack4;
742
        }
743

744
        cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
745

746
        if (NB_elempack > N_elempack)
747
        {
748
            VkImageMat tmp;
749
            vkdev->convert_packing(qkv_cross, tmp, NB_elempack, cmd, opt);
750
            qkv_cross = tmp;
751
        }
752
    }
753

754
    qk_cross.release();
755
    v_affine.release();
756

757
    o_gemm->forward(qkv_cross, top_blobs[0], cmd, opt);
758

759
    return 0;
760
}
761

762
} // namespace ncnn
763

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.