1
// Tencent is pleased to support the open source community by making ncnn available.
3
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
8
// https://opensource.org/licenses/BSD-3-Clause
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
15
#include "multiheadattention_vulkan.h"
17
#include "layer_shader_type.h"
18
#include "layer_type.h"
22
// MultiHeadAttention_vulkan constructor.
// Declares Vulkan support (including image storage) and nulls every qk/qkv
// cross-product pipeline pointer. destroy_pipeline() deletes these pointers,
// and deleting a null pointer is a no-op, so teardown is safe even if
// create_pipeline() never ran.
// NOTE(review): this excerpt is interleaved with bare line-number residue and
// omits lines of the original (e.g. braces; original lines 26-34 are absent, so
// any initialisation of the sub-layer pointers such as q_gemm is not visible
// here) -- TODO reconcile with the full source file.
MultiHeadAttention_vulkan::MultiHeadAttention_vulkan()
24
support_vulkan = true;
25
support_image_storage = true;
35
// Pipelines for the query/key cross-product stage, one per packing combination.
pipeline_multiheadattention_qk_cross = 0;
36
pipeline_multiheadattention_qk_cross_pack4 = 0;
37
pipeline_multiheadattention_qk_cross_pack1to4 = 0;
38
pipeline_multiheadattention_qk_cross_pack4to1 = 0;
40
// Pipelines for the stage that combines softmaxed qk_cross with the projected
// value, one per packing combination.
pipeline_multiheadattention_qkv_cross = 0;
41
pipeline_multiheadattention_qkv_cross_pack4 = 0;
42
pipeline_multiheadattention_qkv_cross_pack1to4 = 0;
43
pipeline_multiheadattention_qkv_cross_pack4to1 = 0;
46
// Creates the GPU resources used by forward():
//   - q_gemm / k_gemm / v_gemm: Gemm sub-layers that project the query, key and
//     value inputs. Each is given its weight and bias Mats and builds its
//     pipeline; the host copies of those Mats are then released.
//   - the qk_cross and qkv_cross pipelines, each in four packing variants
//     (pack1, pack4, pack1to4, pack4to1), all with an 8x8x1 local size.
//   - qk_softmax: a Softmax sub-layer created with no weights.
//   - o_gemm: the output Gemm projection.
// NOTE(review): this excerpt is interleaved with bare line-number residue and
// omits lines of the original (braces, the declarations of `pd` and `weights`,
// several pd.set() calls, and any return statement) -- TODO reconcile with the
// full source before relying on these comments.
int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
48
// NOTE(review): not referenced by any visible statement below (it only appears
// inside trailing comments); may be used in lines absent from this excerpt.
const int embed_dim_per_head = embed_dim / num_heads;
49
// Derived from the total weight count; used below as the output Gemm's N.
// Presumably the query/output feature size -- confirm.
const int qdim = weight_data_size / embed_dim;
51
// Query projection. With constantA=1, constantB=0 and constantC=1, the two Mats
// loaded below presumably map to A (weight) and C (bias); B is the runtime input.
// NOTE(review): N (param 8) and K (param 9) are not set in the visible lines
// (original lines 62-63 absent).
q_gemm = ncnn::create_layer_vulkan(ncnn::LayerType::Gemm);
52
q_gemm->vkdev = vkdev;
56
pd.set(2, 0); // transA
57
pd.set(3, 1); // transB
58
pd.set(4, 1); // constantA
59
pd.set(5, 0); // constantB
60
pd.set(6, 1); // constantC
61
pd.set(7, embed_dim); // M
64
pd.set(10, 1); // constant_broadcast_type_C
65
pd.set(11, 0); // output_N1M
66
// pd.set(12, 1); // output_elempack
67
pd.set(14, 0); // output_transpose
68
q_gemm->load_param(pd);
70
weights[0] = q_weight_data;
71
weights[1] = q_bias_data;
72
q_gemm->load_model(ModelBinFromMatArray(weights));
73
q_gemm->create_pipeline(opt);
77
// Drop this layer's host copies now that they have been handed to q_gemm.
// NOTE(review): any condition guarding these releases (original lines 74-76
// absent) is not visible here.
q_weight_data.release();
78
q_bias_data.release();
83
// Key projection, configured like q_gemm.
// NOTE(review): original lines 92-93 (presumably N/K) are absent.
k_gemm = ncnn::create_layer_vulkan(ncnn::LayerType::Gemm);
84
k_gemm->vkdev = vkdev;
86
pd.set(2, 0); // transA
87
pd.set(3, 1); // transB
88
pd.set(4, 1); // constantA
89
pd.set(5, 0); // constantB
90
pd.set(6, 1); // constantC
91
pd.set(7, embed_dim); // M
94
pd.set(10, 1); // constant_broadcast_type_C
95
pd.set(11, 0); // output_N1M
96
// pd.set(12, 1); // output_elempack
97
pd.set(14, 0); // output_transpose
98
k_gemm->load_param(pd);
100
weights[0] = k_weight_data;
101
weights[1] = k_bias_data;
102
k_gemm->load_model(ModelBinFromMatArray(weights));
103
k_gemm->create_pipeline(opt);
107
k_weight_data.release();
108
k_bias_data.release();
113
// Value projection, configured like q_gemm, with K explicitly set to vdim.
v_gemm = ncnn::create_layer_vulkan(ncnn::LayerType::Gemm);
114
v_gemm->vkdev = vkdev;
116
pd.set(2, 0); // transA
117
pd.set(3, 1); // transB
118
pd.set(4, 1); // constantA
119
pd.set(5, 0); // constantB
120
pd.set(6, 1); // constantC
121
pd.set(7, embed_dim); // M
123
pd.set(9, vdim); // K
124
pd.set(10, 1); // constant_broadcast_type_C
125
pd.set(11, 0); // output_N1M
126
// pd.set(12, 1); // output_elempack
127
pd.set(14, 0); // output_transpose
128
v_gemm->load_param(pd);
130
weights[0] = v_weight_data;
131
weights[1] = v_bias_data;
132
v_gemm->load_model(ModelBinFromMatArray(weights));
133
v_gemm->create_pipeline(opt);
137
v_weight_data.release();
138
v_bias_data.release();
143
// Specialization constants shared by the four qk_cross variants:
// [0] is the attn_mask flag and [4] is num_heads. Entries [1], [2], [3] and [5]
// are 0. Their trailing comments name constantM/N/K and attn_mask.dims; forward()
// passes M, K and the mask dims as runtime constants to record_pipeline().
// NOTE(review): 0 meaning "not specialized" is an assumption -- confirm in the
// corresponding shaders.
// NOTE(review): `specializations` is declared again for the qkv pipelines below;
// presumably each lives in its own brace scope, which this excerpt does not show.
std::vector<vk_specialization_type> specializations(6);
144
specializations[0].i = attn_mask;
145
specializations[1].i = 0; //constantM;
146
specializations[2].i = 0; //constantN;
147
specializations[3].i = 0; //embed_dim_per_head;//constantK;
148
specializations[4].i = num_heads;
149
specializations[5].i = 0; //attn_mask.dims;
152
pipeline_multiheadattention_qk_cross = new Pipeline(vkdev);
153
pipeline_multiheadattention_qk_cross->set_local_size_xyz(8, 8, 1);
154
pipeline_multiheadattention_qk_cross->create(LayerShaderType::multiheadattention_qk_cross, opt, specializations);
157
pipeline_multiheadattention_qk_cross_pack4 = new Pipeline(vkdev);
158
pipeline_multiheadattention_qk_cross_pack4->set_local_size_xyz(8, 8, 1);
159
pipeline_multiheadattention_qk_cross_pack4->create(LayerShaderType::multiheadattention_qk_cross_pack4, opt, specializations);
162
pipeline_multiheadattention_qk_cross_pack1to4 = new Pipeline(vkdev);
163
pipeline_multiheadattention_qk_cross_pack1to4->set_local_size_xyz(8, 8, 1);
164
pipeline_multiheadattention_qk_cross_pack1to4->create(LayerShaderType::multiheadattention_qk_cross_pack1to4, opt, specializations);
167
pipeline_multiheadattention_qk_cross_pack4to1 = new Pipeline(vkdev);
168
pipeline_multiheadattention_qk_cross_pack4to1->set_local_size_xyz(8, 8, 1);
169
pipeline_multiheadattention_qk_cross_pack4to1->create(LayerShaderType::multiheadattention_qk_cross_pack4to1, opt, specializations);
173
// Specialization constants shared by the four qkv_cross variants:
// [3] is num_heads. Entries [0]-[2] (constantM/N/K per the trailing comments)
// are 0, while forward() passes M and N as runtime constants.
std::vector<vk_specialization_type> specializations(4);
174
specializations[0].i = 0; //constantM;
175
specializations[1].i = 0; //embed_dim_per_head;//constantN;
176
specializations[2].i = 0; //constantK;
177
specializations[3].i = num_heads;
180
pipeline_multiheadattention_qkv_cross = new Pipeline(vkdev);
181
pipeline_multiheadattention_qkv_cross->set_local_size_xyz(8, 8, 1);
182
pipeline_multiheadattention_qkv_cross->create(LayerShaderType::multiheadattention_qkv_cross, opt, specializations);
185
pipeline_multiheadattention_qkv_cross_pack4 = new Pipeline(vkdev);
186
pipeline_multiheadattention_qkv_cross_pack4->set_local_size_xyz(8, 8, 1);
187
pipeline_multiheadattention_qkv_cross_pack4->create(LayerShaderType::multiheadattention_qkv_cross_pack4, opt, specializations);
190
pipeline_multiheadattention_qkv_cross_pack1to4 = new Pipeline(vkdev);
191
pipeline_multiheadattention_qkv_cross_pack1to4->set_local_size_xyz(8, 8, 1);
192
pipeline_multiheadattention_qkv_cross_pack1to4->create(LayerShaderType::multiheadattention_qkv_cross_pack1to4, opt, specializations);
195
pipeline_multiheadattention_qkv_cross_pack4to1 = new Pipeline(vkdev);
196
pipeline_multiheadattention_qkv_cross_pack4to1->set_local_size_xyz(8, 8, 1);
197
pipeline_multiheadattention_qkv_cross_pack4to1->create(LayerShaderType::multiheadattention_qkv_cross_pack4to1, opt, specializations);
202
// Softmax sub-layer, applied in place to qk_cross in forward(). Its ParamDict
// settings are not visible in this excerpt (original lines 204-206 absent). It
// is loaded with an empty weight array.
qk_softmax = ncnn::create_layer_vulkan(ncnn::LayerType::Softmax);
203
qk_softmax->vkdev = vkdev;
207
qk_softmax->load_param(pd);
208
qk_softmax->load_model(ModelBinFromMatArray(0));
209
qk_softmax->create_pipeline(opt);
213
// Output projection. Unlike q/k/v, the runtime operand is A (constantA=0,
// transA=1) and the constant operands are B (constantB=1, transB=1) and C. The
// two loaded Mats presumably map to B (weight) and C (bias). M is passed as 0
// (presumably taken from the runtime input), N is qdim and K is embed_dim.
// NOTE(review): the trailing "outch" / "size" / "maxk*inch" labels read like
// convolution terminology -- verify their meaning for this Gemm.
o_gemm = ncnn::create_layer_vulkan(ncnn::LayerType::Gemm);
214
o_gemm->vkdev = vkdev;
216
pd.set(2, 1); // transA
217
pd.set(3, 1); // transB
218
pd.set(4, 0); // constantA
219
pd.set(5, 1); // constantB
220
pd.set(6, 1); // constantC
221
pd.set(7, 0); // M = outch
222
pd.set(8, qdim); // N = size
223
pd.set(9, embed_dim); // K = maxk*inch
224
pd.set(10, 4); // constant_broadcast_type_C
225
pd.set(11, 0); // output_N1M
226
o_gemm->load_param(pd);
228
weights[0] = out_weight_data;
229
weights[1] = out_bias_data;
230
o_gemm->load_model(ModelBinFromMatArray(weights));
231
o_gemm->create_pipeline(opt);
235
out_weight_data.release();
236
out_bias_data.release();
243
// Tears down what create_pipeline() built:
//   - destroys the pipelines of the Gemm and Softmax sub-layers;
//   - deletes each of the eight qk/qkv cross-product pipelines and resets the
//     pointer to 0, so a repeated call deletes nothing twice.
// NOTE(review): original lines between the visible calls (e.g. 248-253) are
// absent from this excerpt. Whether the sub-layer objects themselves are deleted
// or null-checked is not visible here -- TODO confirm.
int MultiHeadAttention_vulkan::destroy_pipeline(const Option& opt)
247
q_gemm->destroy_pipeline(opt);
254
k_gemm->destroy_pipeline(opt);
261
v_gemm->destroy_pipeline(opt);
266
delete pipeline_multiheadattention_qk_cross;
267
pipeline_multiheadattention_qk_cross = 0;
269
delete pipeline_multiheadattention_qk_cross_pack4;
270
pipeline_multiheadattention_qk_cross_pack4 = 0;
272
delete pipeline_multiheadattention_qk_cross_pack1to4;
273
pipeline_multiheadattention_qk_cross_pack1to4 = 0;
275
delete pipeline_multiheadattention_qk_cross_pack4to1;
276
pipeline_multiheadattention_qk_cross_pack4to1 = 0;
278
delete pipeline_multiheadattention_qkv_cross;
279
pipeline_multiheadattention_qkv_cross = 0;
281
delete pipeline_multiheadattention_qkv_cross_pack4;
282
pipeline_multiheadattention_qkv_cross_pack4 = 0;
284
delete pipeline_multiheadattention_qkv_cross_pack1to4;
285
pipeline_multiheadattention_qkv_cross_pack1to4 = 0;
287
delete pipeline_multiheadattention_qkv_cross_pack4to1;
288
pipeline_multiheadattention_qkv_cross_pack4to1 = 0;
292
qk_softmax->destroy_pipeline(opt);
299
o_gemm->destroy_pipeline(opt);
307
// Forwards upload_model() to the four Gemm sub-layers (q, k, v and output),
// recording their weight uploads on the given VkTransfer.
// qk_softmax is not uploaded in the visible lines; create_pipeline() loads it
// with an empty weight array.
// NOTE(review): this excerpt omits lines between the calls (e.g. 312-315) and
// any return statement -- TODO confirm against the full file.
int MultiHeadAttention_vulkan::upload_model(VkTransfer& cmd, const Option& opt)
311
q_gemm->upload_model(cmd, opt);
316
k_gemm->upload_model(cmd, opt);
321
v_gemm->upload_model(cmd, opt);
326
o_gemm->upload_model(cmd, opt);
332
// Buffer-storage (VkMat) forward pass of the multi-head attention layer.
//
// Input blobs:
//   - bottom_blobs[0] is the query.
//   - The key is the query itself when there is one blob, or two blobs with
//     attn_mask set; otherwise it is bottom_blobs[1].
//   - The value is the query in that same case; the key when there are two
//     blobs (no mask) or three blobs with a mask; otherwise bottom_blobs[2].
//   - When attn_mask is set, the last blob is the attention mask.
//
// Recorded stages: q/k projections -> qk_cross pipeline -> in-place softmax ->
// v projection -> qkv_cross pipeline -> output projection into top_blobs[0].
//
// NOTE(review): this excerpt is interleaved with bare line-number residue and
// omits lines. Missing pieces include braces; the declarations of M/N/B,
// q_affine, k_affine, v_affine, qk_cross, qkv_cross, tmp and the first
// dispatcher; the statements that presumably copy `tmp` back after each
// convert_packing; and presumably error returns after the empty() checks.
// TODO reconcile with the full file.
int MultiHeadAttention_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const
334
const VkMat& q_blob = bottom_blobs[0];
335
const VkMat& k_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : bottom_blobs[1];
336
const VkMat& v_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : (bottom_blobs.size() == 2 || (bottom_blobs.size() == 3 && attn_mask)) ? k_blob : bottom_blobs[2];
337
// Empty VkMat when there is no mask. Its elempack then never exceeds
// M_elempack, so the mask repack below is skipped.
VkMat attn_mask_blob = attn_mask ? bottom_blobs[bottom_blobs.size() - 1] : VkMat();
339
// NOTE(review): embed_dim_per_head, src_seqlen and dst_seqlen are not
// referenced by any visible line of this function.
const int embed_dim_per_head = embed_dim / num_heads;
340
const int src_seqlen = q_blob.h * q_blob.elempack;
341
const int dst_seqlen = k_blob.h * k_blob.elempack;
344
// Project query and key.
q_gemm->forward(q_blob, q_affine, cmd, opt);
347
k_gemm->forward(k_blob, k_affine, cmd, opt);
353
// qk stage. K is the projected query's unpacked height divided by num_heads,
// i.e. a per-head count.
// NOTE(review): M, N and B are used below but declared in absent lines.
int K = q_affine.h * q_affine.elempack / num_heads;
356
// int K_elempack = opt.use_shader_pack8 && K % 8 == 0 ? 8 : K % 4 == 0 ? 4 : 1;
357
// int M_elempack = opt.use_shader_pack8 && M % 8 == 0 ? 8 : M % 4 == 0 ? 4 : 1;
358
// int MB_elempack = opt.use_shader_pack8 && (M * B) % 8 == 0 ? 8 : (M * B) % 4 == 0 ? 4 : 1;
359
// pack8 selection above is disabled: every elempack here is 4 or 1.
int K_elempack = K % 4 == 0 ? 4 : 1;
360
int M_elempack = M % 4 == 0 ? 4 : 1;
361
int MB_elempack = (M * B) % 4 == 0 ? 4 : 1;
362
size_t M_elemsize = q_affine.elemsize / q_affine.elempack * M_elempack;
364
// fp16-packed without fp16-storage: packed elements use 2 bytes per lane and
// elempack 1 uses 4 bytes. The elempack == 8 branch is currently unreachable
// because M_elempack is only ever 4 or 1.
if (opt.use_fp16_packed && !opt.use_fp16_storage)
366
if (M_elempack == 8) M_elemsize = 8 * 2u;
367
if (M_elempack == 4) M_elemsize = 4 * 2u;
368
if (M_elempack == 1) M_elemsize = 4u;
371
// Narrow the packing of inputs that are packed wider than the selected shader
// variant expects.
if (K_elempack < q_affine.elempack)
374
vkdev->convert_packing(q_affine, tmp, K_elempack, cmd, opt);
377
if (K_elempack < k_affine.elempack)
380
vkdev->convert_packing(k_affine, tmp, K_elempack, cmd, opt);
383
if (M_elempack < attn_mask_blob.elempack)
386
vkdev->convert_packing(attn_mask_blob, tmp, M_elempack, cmd, opt);
387
attn_mask_blob = tmp;
390
// qk_cross: width N, height (M / M_elempack) * B, packed by M_elempack.
qk_cross.create(N, M / M_elempack * B, M_elemsize, M_elempack, opt.blob_vkallocator);
391
if (qk_cross.empty())
394
std::vector<VkMat> bindings(4);
395
bindings[0] = q_affine;
396
bindings[1] = k_affine;
397
bindings[2] = qk_cross;
398
bindings[3] = attn_mask_blob;
400
// Runtime constants: [0] M in packs, [2] K in packs, [4] mask dims.
// Entries [1] and [3] are set in lines absent from this excerpt.
std::vector<vk_constant_type> constants(5);
401
constants[0].i = M / M_elempack;
403
constants[2].i = K / K_elempack;
405
constants[4].i = attn_mask_blob.dims;
409
dispatcher.h = M / M_elempack;
412
// Select the variant for (K_elempack, M_elempack). Both are 1 or 4, so exactly
// one branch matches and pipeline is never left null.
const Pipeline* pipeline = 0;
413
if (K_elempack == 1 && M_elempack == 1)
415
pipeline = pipeline_multiheadattention_qk_cross;
417
if (K_elempack == 1 && M_elempack == 4)
419
pipeline = pipeline_multiheadattention_qk_cross_pack1to4;
421
if (K_elempack == 4 && M_elempack == 1)
423
pipeline = pipeline_multiheadattention_qk_cross_pack4to1;
425
if (K_elempack == 4 && M_elempack == 4)
427
pipeline = pipeline_multiheadattention_qk_cross_pack4;
430
cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
432
// If the combined M*B extent allows wider packing than M alone, repack
// qk_cross before the softmax.
if (MB_elempack > M_elempack)
435
vkdev->convert_packing(qk_cross, tmp, MB_elempack, cmd, opt);
443
qk_softmax->forward_inplace(qk_cross, cmd, opt);
445
// PCI vendor id 0x10de is NVIDIA (see FIXME). Work around NaN softmax output
// by round-tripping qk_cross through an image and back into the buffer.
if (vkdev->info.vendor_id() == 0x10de)
447
// FIXME softmax produces nan result on nvidia (about 20% chance)
448
// memory barrier seems to be not enough here
449
// device copy-to and copy-back is better than queue submit anyway --- nihui
451
// cmd.submit_and_wait();
454
VkImageMat qk_cross2;
455
cmd.record_buffer_to_image(qk_cross, qk_cross2, opt);
456
cmd.record_image_to_buffer(qk_cross2, qk_cross, opt);
460
// Project value.
v_gemm->forward(v_blob, v_affine, cmd, opt);
464
// qkv stage. M and N are the unpacked heights of qk_cross and v_affine divided
// by num_heads.
// NOTE(review): these redeclare M/N/M_elempack, presumably in a separate brace
// scope that this excerpt does not show.
int M = qk_cross.h * qk_cross.elempack / num_heads;
465
int N = v_affine.h * v_affine.elempack / num_heads;
469
// int M_elempack = opt.use_shader_pack8 && M % 8 == 0 ? 8 : M % 4 == 0 ? 4 : 1;
470
// int N_elempack = opt.use_shader_pack8 && N % 8 == 0 ? 8 : N % 4 == 0 ? 4 : 1;
471
// int NB_elempack = opt.use_shader_pack8 && (N * B) % 8 == 0 ? 8 : (N * B) % 4 == 0 ? 4 : 1;
472
int M_elempack = M % 4 == 0 ? 4 : 1;
473
int N_elempack = N % 4 == 0 ? 4 : 1;
474
int NB_elempack = (N * B) % 4 == 0 ? 4 : 1;
475
size_t N_elemsize = v_affine.elemsize / v_affine.elempack * N_elempack;
477
// Same fp16-packed element-size override as in the qk stage. The
// N_elempack == 8 branch is likewise unreachable.
if (opt.use_fp16_packed && !opt.use_fp16_storage)
479
if (N_elempack == 8) N_elemsize = 8 * 2u;
480
if (N_elempack == 4) N_elemsize = 4 * 2u;
481
if (N_elempack == 1) N_elemsize = 4u;
484
if (M_elempack < qk_cross.elempack)
487
vkdev->convert_packing(qk_cross, tmp, M_elempack, cmd, opt);
491
if (N_elempack < v_affine.elempack)
494
vkdev->convert_packing(v_affine, tmp, N_elempack, cmd, opt);
498
// qkv_cross: width M, height (N / N_elempack) * B, packed by N_elempack.
qkv_cross.create(M, N / N_elempack * B, N_elemsize, N_elempack, opt.blob_vkallocator);
499
if (qkv_cross.empty())
502
std::vector<VkMat> bindings(3);
503
bindings[0] = qk_cross;
504
bindings[1] = v_affine;
505
bindings[2] = qkv_cross;
507
// Runtime constants: [0] M in packs, [1] N in packs. Remaining entries and the
// dispatcher declaration are in lines absent from this excerpt.
std::vector<vk_constant_type> constants(4);
508
constants[0].i = M / M_elempack;
509
constants[1].i = N / N_elempack;
514
dispatcher.w = N / N_elempack;
515
dispatcher.h = M / M_elempack;
518
// Select the variant for (M_elempack, N_elempack). Both are 1 or 4, so exactly
// one branch matches.
const Pipeline* pipeline = 0;
519
if (M_elempack == 1 && N_elempack == 1)
521
pipeline = pipeline_multiheadattention_qkv_cross;
523
if (M_elempack == 1 && N_elempack == 4)
525
pipeline = pipeline_multiheadattention_qkv_cross_pack1to4;
527
if (M_elempack == 4 && N_elempack == 1)
529
pipeline = pipeline_multiheadattention_qkv_cross_pack4to1;
531
if (M_elempack == 4 && N_elempack == 4)
533
pipeline = pipeline_multiheadattention_qkv_cross_pack4;
536
cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
538
// Repack qkv_cross when N*B allows wider packing than N alone.
if (NB_elempack > N_elempack)
541
vkdev->convert_packing(qkv_cross, tmp, NB_elempack, cmd, opt);
549
// Output projection writes the layer result into top_blobs[0].
o_gemm->forward(qkv_cross, top_blobs[0], cmd, opt);
554
// Image-storage (VkImageMat) forward pass of the multi-head attention layer.
//
// Input blobs:
//   - bottom_blobs[0] is the query.
//   - The key is the query itself when there is one blob, or two blobs with
//     attn_mask set; otherwise it is bottom_blobs[1].
//   - The value is the query in that same case; the key when there are two
//     blobs (no mask) or three blobs with a mask; otherwise bottom_blobs[2].
//   - When attn_mask is set, the last blob is the attention mask.
//
// Recorded stages: q/k projections -> qk_cross pipeline -> in-place softmax ->
// v projection -> qkv_cross pipeline -> output projection into top_blobs[0].
// Unlike the VkMat overload, no NVIDIA softmax workaround appears in the
// visible lines here.
//
// NOTE(review): this excerpt is interleaved with bare line-number residue and
// omits lines. Missing pieces include braces; the declarations of M/N/B,
// q_affine, k_affine, v_affine, qk_cross and tmp; the statements that
// presumably copy `tmp` back after each convert_packing; and presumably error
// returns after the empty() checks. SOURCE also ends at the output projection,
// so the rest of the function is not visible. TODO reconcile with the full file.
int MultiHeadAttention_vulkan::forward(const std::vector<VkImageMat>& bottom_blobs, std::vector<VkImageMat>& top_blobs, VkCompute& cmd, const Option& opt) const
556
const VkImageMat& q_blob = bottom_blobs[0];
557
const VkImageMat& k_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : bottom_blobs[1];
558
const VkImageMat& v_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : (bottom_blobs.size() == 2 || (bottom_blobs.size() == 3 && attn_mask)) ? k_blob : bottom_blobs[2];
559
// Empty VkImageMat when there is no mask.
VkImageMat attn_mask_blob = attn_mask ? bottom_blobs[bottom_blobs.size() - 1] : VkImageMat();
561
// NOTE(review): embed_dim_per_head, src_seqlen and dst_seqlen are not
// referenced by any visible line of this function.
const int embed_dim_per_head = embed_dim / num_heads;
562
const int src_seqlen = q_blob.h * q_blob.elempack;
563
const int dst_seqlen = k_blob.h * k_blob.elempack;
566
// Project query and key.
q_gemm->forward(q_blob, q_affine, cmd, opt);
569
k_gemm->forward(k_blob, k_affine, cmd, opt);
575
// qk stage. K is the projected query's unpacked height divided by num_heads.
// NOTE(review): M, N and B are used below but declared in absent lines.
int K = q_affine.h * q_affine.elempack / num_heads;
578
// int K_elempack = opt.use_shader_pack8 && K % 8 == 0 ? 8 : K % 4 == 0 ? 4 : 1;
579
// int M_elempack = opt.use_shader_pack8 && M % 8 == 0 ? 8 : M % 4 == 0 ? 4 : 1;
580
// int MB_elempack = opt.use_shader_pack8 && (M * B) % 8 == 0 ? 8 : (M * B) % 4 == 0 ? 4 : 1;
581
// pack8 selection above is disabled: every elempack here is 4 or 1.
int K_elempack = K % 4 == 0 ? 4 : 1;
582
int M_elempack = M % 4 == 0 ? 4 : 1;
583
int MB_elempack = (M * B) % 4 == 0 ? 4 : 1;
584
size_t M_elemsize = q_affine.elemsize / q_affine.elempack * M_elempack;
586
// fp16-packed without fp16-storage: packed elements use 2 bytes per lane and
// elempack 1 uses 4 bytes. The elempack == 8 branch is currently unreachable.
if (opt.use_fp16_packed && !opt.use_fp16_storage)
588
if (M_elempack == 8) M_elemsize = 8 * 2u;
589
if (M_elempack == 4) M_elemsize = 4 * 2u;
590
if (M_elempack == 1) M_elemsize = 4u;
593
// Narrow the packing of inputs that are packed wider than the selected shader
// variant expects.
if (K_elempack < q_affine.elempack)
596
vkdev->convert_packing(q_affine, tmp, K_elempack, cmd, opt);
599
if (K_elempack < k_affine.elempack)
602
vkdev->convert_packing(k_affine, tmp, K_elempack, cmd, opt);
605
if (M_elempack < attn_mask_blob.elempack)
608
vkdev->convert_packing(attn_mask_blob, tmp, M_elempack, cmd, opt);
609
attn_mask_blob = tmp;
612
// qk_cross: width N, height (M / M_elempack) * B, packed by M_elempack.
qk_cross.create(N, M / M_elempack * B, M_elemsize, M_elempack, opt.blob_vkallocator);
613
if (qk_cross.empty())
616
std::vector<VkImageMat> bindings(4);
617
bindings[0] = q_affine;
618
bindings[1] = k_affine;
619
bindings[2] = qk_cross;
620
bindings[3] = attn_mask_blob;
622
// Runtime constants: [0] M in packs, [2] K in packs, [4] mask dims.
// Entries [1] and [3] are set in lines absent from this excerpt.
std::vector<vk_constant_type> constants(5);
623
constants[0].i = M / M_elempack;
625
constants[2].i = K / K_elempack;
627
constants[4].i = attn_mask_blob.dims;
629
VkImageMat dispatcher;
631
dispatcher.h = M / M_elempack;
634
// Select the variant for (K_elempack, M_elempack). Both are 1 or 4, so exactly
// one branch matches.
const Pipeline* pipeline = 0;
635
if (K_elempack == 1 && M_elempack == 1)
637
pipeline = pipeline_multiheadattention_qk_cross;
639
if (K_elempack == 1 && M_elempack == 4)
641
pipeline = pipeline_multiheadattention_qk_cross_pack1to4;
643
if (K_elempack == 4 && M_elempack == 1)
645
pipeline = pipeline_multiheadattention_qk_cross_pack4to1;
647
if (K_elempack == 4 && M_elempack == 4)
649
pipeline = pipeline_multiheadattention_qk_cross_pack4;
652
cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
654
// If the combined M*B extent allows wider packing than M alone, repack
// qk_cross before the softmax.
if (MB_elempack > M_elempack)
657
vkdev->convert_packing(qk_cross, tmp, MB_elempack, cmd, opt);
665
qk_softmax->forward_inplace(qk_cross, cmd, opt);
668
// Project value.
v_gemm->forward(v_blob, v_affine, cmd, opt);
670
VkImageMat qkv_cross;
672
// qkv stage. M and N are the unpacked heights of qk_cross and v_affine divided
// by num_heads.
// NOTE(review): these redeclare M/N/M_elempack, presumably in a separate brace
// scope that this excerpt does not show.
int M = qk_cross.h * qk_cross.elempack / num_heads;
673
int N = v_affine.h * v_affine.elempack / num_heads;
677
// int M_elempack = opt.use_shader_pack8 && M % 8 == 0 ? 8 : M % 4 == 0 ? 4 : 1;
678
// int N_elempack = opt.use_shader_pack8 && N % 8 == 0 ? 8 : N % 4 == 0 ? 4 : 1;
679
// int NB_elempack = opt.use_shader_pack8 && (N * B) % 8 == 0 ? 8 : (N * B) % 4 == 0 ? 4 : 1;
680
int M_elempack = M % 4 == 0 ? 4 : 1;
681
int N_elempack = N % 4 == 0 ? 4 : 1;
682
int NB_elempack = (N * B) % 4 == 0 ? 4 : 1;
683
size_t N_elemsize = v_affine.elemsize / v_affine.elempack * N_elempack;
685
// Same fp16-packed element-size override as in the qk stage. The
// N_elempack == 8 branch is likewise unreachable.
if (opt.use_fp16_packed && !opt.use_fp16_storage)
687
if (N_elempack == 8) N_elemsize = 8 * 2u;
688
if (N_elempack == 4) N_elemsize = 4 * 2u;
689
if (N_elempack == 1) N_elemsize = 4u;
692
if (M_elempack < qk_cross.elempack)
695
vkdev->convert_packing(qk_cross, tmp, M_elempack, cmd, opt);
699
if (N_elempack < v_affine.elempack)
702
vkdev->convert_packing(v_affine, tmp, N_elempack, cmd, opt);
706
// qkv_cross: width M, height (N / N_elempack) * B, packed by N_elempack.
qkv_cross.create(M, N / N_elempack * B, N_elemsize, N_elempack, opt.blob_vkallocator);
707
if (qkv_cross.empty())
710
std::vector<VkImageMat> bindings(3);
711
bindings[0] = qk_cross;
712
bindings[1] = v_affine;
713
bindings[2] = qkv_cross;
715
// Runtime constants: [0] M in packs, [1] N in packs. Remaining entries are
// in lines absent from this excerpt.
std::vector<vk_constant_type> constants(4);
716
constants[0].i = M / M_elempack;
717
constants[1].i = N / N_elempack;
721
VkImageMat dispatcher;
722
dispatcher.w = N / N_elempack;
723
dispatcher.h = M / M_elempack;
726
// Select the variant for (M_elempack, N_elempack). Both are 1 or 4, so exactly
// one branch matches.
const Pipeline* pipeline = 0;
727
if (M_elempack == 1 && N_elempack == 1)
729
pipeline = pipeline_multiheadattention_qkv_cross;
731
if (M_elempack == 1 && N_elempack == 4)
733
pipeline = pipeline_multiheadattention_qkv_cross_pack1to4;
735
if (M_elempack == 4 && N_elempack == 1)
737
pipeline = pipeline_multiheadattention_qkv_cross_pack4to1;
739
if (M_elempack == 4 && N_elempack == 4)
741
pipeline = pipeline_multiheadattention_qkv_cross_pack4;
744
cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
746
// Repack qkv_cross when N*B allows wider packing than N alone.
if (NB_elempack > N_elempack)
749
vkdev->convert_packing(qkv_cross, tmp, NB_elempack, cmd, opt);
757
// Output projection writes the layer result into top_blobs[0].
o_gemm->forward(qkv_cross, top_blobs[0], cmd, opt);