// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#include "multiheadattention.h"

#include <float.h>
#include <math.h>

#include <algorithm>
// Layer construction: multi-head attention consumes several input blobs
// (q, and optionally k, v, attn_mask) and produces a fresh output blob,
// so it is neither single-blob nor in-place capable.
// NOTE(review): body reconstructed from the mangled source — flags follow
// the ncnn Layer convention; confirm against upstream.
MultiHeadAttention::MultiHeadAttention()
{
    one_blob_only = false;
    support_inplace = false;
}
int MultiHeadAttention::load_param(const ParamDict& pd)
27
embed_dim = pd.get(0, 0);
28
num_heads = pd.get(1, 1);
29
weight_data_size = pd.get(2, 0);
30
kdim = pd.get(3, embed_dim);
31
vdim = pd.get(4, embed_dim);
32
attn_mask = pd.get(5, 0);
33
scale = pd.get(6, 1.f / sqrtf(embed_dim / num_heads));
38
int MultiHeadAttention::load_model(const ModelBin& mb)
40
const int qdim = weight_data_size / embed_dim;
42
q_weight_data = mb.load(embed_dim * qdim, 0);
43
if (q_weight_data.empty())
46
q_bias_data = mb.load(embed_dim, 1);
47
if (q_bias_data.empty())
50
k_weight_data = mb.load(embed_dim * kdim, 0);
51
if (k_weight_data.empty())
54
k_bias_data = mb.load(embed_dim, 1);
55
if (k_bias_data.empty())
58
v_weight_data = mb.load(embed_dim * vdim, 0);
59
if (v_weight_data.empty())
62
v_bias_data = mb.load(embed_dim, 1);
63
if (v_bias_data.empty())
66
out_weight_data = mb.load(qdim * embed_dim, 0);
67
if (out_weight_data.empty())
70
out_bias_data = mb.load(qdim, 1);
71
if (out_bias_data.empty())
// refers to https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
80
const Mat& q_blob = bottom_blobs[0];
81
const Mat& k_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : bottom_blobs[1];
82
const Mat& v_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : (bottom_blobs.size() == 2 || (bottom_blobs.size() == 3 && attn_mask)) ? k_blob : bottom_blobs[2];
83
const Mat& attn_mask_blob = attn_mask ? bottom_blobs[bottom_blobs.size() - 1] : Mat();
85
const int src_seqlen = q_blob.h;
86
const int dst_seqlen = k_blob.h;
87
const int embed_dim_per_head = embed_dim / num_heads;
88
const int qdim = weight_data_size / embed_dim;
90
// assert k_blob.h == v_blob.h
92
Mat& top_blob = top_blobs[0];
93
top_blob.create(qdim, src_seqlen, 4u, opt.blob_allocator);
97
Mat xq(embed_dim_per_head, src_seqlen, num_heads, 4u, opt.workspace_allocator);
100
Mat xk(embed_dim_per_head, dst_seqlen, num_heads, 4u, opt.workspace_allocator);
103
Mat xv(dst_seqlen, embed_dim_per_head, num_heads, 4u, opt.workspace_allocator);
107
Mat xqk(dst_seqlen, src_seqlen, num_heads, 4u, opt.workspace_allocator);
111
Mat xqkv(embed_dim_per_head, num_heads, src_seqlen, 4u, opt.workspace_allocator);
115
#pragma omp parallel for num_threads(opt.num_threads)
116
for (int q = 0; q < num_heads; q++)
118
// xq = affine(q) * scale
120
Mat outm = xq.channel(q);
122
for (int i = 0; i < src_seqlen; i++)
124
float* outptr = outm.row(i);
126
for (int j = 0; j < embed_dim_per_head; j++)
128
const float* ptr = q_blob.row(i);
129
const float* kptr = (const float*)q_weight_data + qdim * (q * embed_dim_per_head + j);
131
float sum = q_bias_data[q * embed_dim_per_head + j];
132
for (int k = 0; k < qdim; k++)
134
sum += *ptr++ * *kptr++;
137
outptr[j] = sum * scale;
144
Mat outm = xk.channel(q);
146
for (int i = 0; i < dst_seqlen; i++)
148
float* outptr = outm.row(i);
150
for (int j = 0; j < embed_dim_per_head; j++)
152
const float* ptr = k_blob.row(i);
153
const float* kptr = (const float*)k_weight_data + kdim * (q * embed_dim_per_head + j);
155
float sum = k_bias_data[q * embed_dim_per_head + j];
156
for (int k = 0; k < kdim; k++)
158
sum += *ptr++ * *kptr++;
168
Mat outm = xv.channel(q);
170
for (int i = 0; i < embed_dim_per_head; i++)
172
for (int j = 0; j < dst_seqlen; j++)
174
const float* ptr = v_blob.row(j);
175
const float* kptr = (const float*)v_weight_data + vdim * (q * embed_dim_per_head + i);
177
float sum = v_bias_data[q * embed_dim_per_head + i];
178
for (int k = 0; k < vdim; k++)
180
sum += *ptr++ * *kptr++;
183
float* outptr = outm.row(i);
191
// xq (embed_dim_per_head, src_seqlen)
192
// xk (embed_dim_per_head, dst_seqlen)
194
const Mat xqm = xq.channel(q);
195
const Mat xkm = xk.channel(q);
197
Mat outm = xqk.channel(q);
199
for (int i = 0; i < src_seqlen; i++)
201
float* outptr = outm.row(i);
203
for (int j = 0; j < dst_seqlen; j++)
205
const float* qptr = xqm.row(i);
206
const float* kptr = xkm.row(j);
209
for (int k = 0; k < embed_dim_per_head; k++)
211
sum += *qptr++ * *kptr++;
222
const Mat& maskm = attn_mask_blob.dims == 3 ? attn_mask_blob.channel(q) : attn_mask_blob;
223
Mat outm = xqk.channel(q);
225
for (int i = 0; i < src_seqlen; i++)
227
const float* mptr = maskm.row(i);
228
float* outptr = outm.row(i);
230
for (int j = 0; j < dst_seqlen; j++)
232
outptr[j] += mptr[j];
239
Mat outm = xqk.channel(q);
241
for (int i = 0; i < src_seqlen; i++)
243
float* ptr = outm.row(i);
245
float max = -FLT_MAX;
246
for (int j = 0; j < dst_seqlen; j++)
248
max = std::max(max, ptr[j]);
252
for (int j = 0; j < dst_seqlen; j++)
254
ptr[j] = (float)(expf(ptr[j] - max));
258
for (int j = 0; j < dst_seqlen; j++)
266
// xqk (dst_seqlen, src_seqlen)
267
// xv (dst_seqlen, embed_dim_per_head)
268
// out (embed_dim_per_head, num_heads, src_seqlen)
270
const Mat xqkm = xqk.channel(q);
271
const Mat xvm = xv.channel(q);
273
for (int i = 0; i < src_seqlen; i++)
275
float* outptr = xqkv.channel(i).row(q);
277
for (int j = 0; j < embed_dim_per_head; j++)
279
const float* qkptr = xqkm.row(i);
280
const float* vptr = xvm.row(j);
283
for (int k = 0; k < dst_seqlen; k++)
285
sum += *qkptr++ * *vptr++;
294
// out = affine(xqkv)
295
// xqkv (embed_dim, src_seqlen)
296
#pragma omp parallel for num_threads(opt.num_threads)
297
for (int i = 0; i < src_seqlen; i++)
299
float* outptr = top_blob.row(i);
301
for (int j = 0; j < qdim; j++)
303
const float* ptr = xqkv.channel(i);
304
const float* kptr = (const float*)out_weight_data + embed_dim * j;
306
float sum = out_bias_data[j];
307
for (int k = 0; k < embed_dim; k++)
309
sum += *ptr++ * *kptr++;