// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "multiheadattention.h"

#include <algorithm> // std::max
#include <float.h>   // FLT_MAX
#include <math.h>    // expf, sqrtf

namespace ncnn {

MultiHeadAttention::MultiHeadAttention()
{
}
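
// param table (as read in load_param below):
//   0: embed_dim          total embedding dimension across all heads
//   1: num_heads          number of attention heads
//   2: weight_data_size   embed_dim * qdim, used to recover the input query dim
//   3: kdim               key input dim (defaults to embed_dim)
//   4: vdim               value input dim (defaults to embed_dim)
//   5: attn_mask          non-zero if an attention mask blob is passed
//   6: scale              softmax scaling, defaults to 1/sqrt(embed_dim/num_heads)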
int MultiHeadAttention::load_param(const ParamDict& pd)
{
    embed_dim = pd.get(0, 0);
    num_heads = pd.get(1, 1);
    weight_data_size = pd.get(2, 0);
    kdim = pd.get(3, embed_dim);
    vdim = pd.get(4, embed_dim);
    attn_mask = pd.get(5, 0);
    scale = pd.get(6, 1.f / sqrtf(embed_dim / num_heads));

    return 0;
}
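
// projection weights are stored row-major as (out_features x in_features),
// in the style of torch.nn.Linear: q is (embed_dim x qdim), k is (embed_dim x kdim),
// v is (embed_dim x vdim), and the output projection is (qdim x embed_dim)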
int MultiHeadAttention::load_model(const ModelBin& mb)
{
    const int qdim = weight_data_size / embed_dim;

    q_weight_data = mb.load(embed_dim * qdim, 0);
    if (q_weight_data.empty())
        return -100;

    q_bias_data = mb.load(embed_dim, 1);
    if (q_bias_data.empty())
        return -100;

    k_weight_data = mb.load(embed_dim * kdim, 0);
    if (k_weight_data.empty())
        return -100;

    k_bias_data = mb.load(embed_dim, 1);
    if (k_bias_data.empty())
        return -100;

    v_weight_data = mb.load(embed_dim * vdim, 0);
    if (v_weight_data.empty())
        return -100;

    v_bias_data = mb.load(embed_dim, 1);
    if (v_bias_data.empty())
        return -100;

    out_weight_data = mb.load(qdim * embed_dim, 0);
    if (out_weight_data.empty())
        return -100;

    out_bias_data = mb.load(qdim, 1);
    if (out_bias_data.empty())
        return -100;

    return 0;
}

// refers to https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
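    // bottom blob combinations: [q] (self-attention, q = k = v), [q, kv]
    // (shared key/value), or [q, k, v]; when attn_mask is set, the mask is
    // always the last blob and q/k/v are taken from the blobs before it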
    const Mat& q_blob = bottom_blobs[0];
    const Mat& k_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : bottom_blobs[1];
    const Mat& v_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : (bottom_blobs.size() == 2 || (bottom_blobs.size() == 3 && attn_mask)) ? k_blob : bottom_blobs[2];
    const Mat& attn_mask_blob = attn_mask ? bottom_blobs[bottom_blobs.size() - 1] : Mat();

    const int src_seqlen = q_blob.h;
    const int dst_seqlen = k_blob.h;
    const int embed_dim_per_head = embed_dim / num_heads;
    const int qdim = weight_data_size / embed_dim;

    // assert k_blob.h == v_blob.h

    Mat& top_blob = top_blobs[0];
    top_blob.create(qdim, src_seqlen, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;
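
    // per-head workspaces; Mat(w, h, c) is row-major with w contiguous:
    // xq and xk hold one (seqlen x embed_dim_per_head) matrix per head,
    // while xv is stored transposed so the attention-weighted sum below
    // can walk both operands contiguously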
    Mat xq(embed_dim_per_head, src_seqlen, num_heads, 4u, opt.workspace_allocator);
    if (xq.empty())
        return -100;
    Mat xk(embed_dim_per_head, dst_seqlen, num_heads, 4u, opt.workspace_allocator);
    if (xk.empty())
        return -100;
    Mat xv(dst_seqlen, embed_dim_per_head, num_heads, 4u, opt.workspace_allocator);
    if (xv.empty())
        return -100;

    Mat xqk(dst_seqlen, src_seqlen, num_heads, 4u, opt.workspace_allocator);
    if (xqk.empty())
        return -100;

    Mat xqkv(embed_dim_per_head, num_heads, src_seqlen, 4u, opt.workspace_allocator);
    if (xqkv.empty())
        return -100;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < num_heads; q++)
    {
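        // heads are independent, so the loop parallelizes over them; each head
        // reads its own slice of the packed projection weights, and the softmax
        // scale is folded into the q projection to save a separate pass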
        // xq = affine(q) * scale
        {
            Mat outm = xq.channel(q);

            for (int i = 0; i < src_seqlen; i++)
            {
                float* outptr = outm.row(i);

                for (int j = 0; j < embed_dim_per_head; j++)
                {
                    const float* ptr = q_blob.row(i);
                    const float* kptr = (const float*)q_weight_data + qdim * (q * embed_dim_per_head + j);

                    float sum = q_bias_data[q * embed_dim_per_head + j];
                    for (int k = 0; k < qdim; k++)
                    {
                        sum += *ptr++ * *kptr++;
                    }

                    outptr[j] = sum * scale;
                }
            }
        }

        // xk = affine(k)
        {
            Mat outm = xk.channel(q);

            for (int i = 0; i < dst_seqlen; i++)
            {
                float* outptr = outm.row(i);

                for (int j = 0; j < embed_dim_per_head; j++)
                {
                    const float* ptr = k_blob.row(i);
                    const float* kptr = (const float*)k_weight_data + kdim * (q * embed_dim_per_head + j);

                    float sum = k_bias_data[q * embed_dim_per_head + j];
                    for (int k = 0; k < kdim; k++)
                    {
                        sum += *ptr++ * *kptr++;
                    }

                    outptr[j] = sum;
                }
            }
        }

        // xv = affine(v), stored transposed
        {
            Mat outm = xv.channel(q);

            for (int i = 0; i < embed_dim_per_head; i++)
            {
                for (int j = 0; j < dst_seqlen; j++)
                {
                    const float* ptr = v_blob.row(j);
                    const float* kptr = (const float*)v_weight_data + vdim * (q * embed_dim_per_head + i);

                    float sum = v_bias_data[q * embed_dim_per_head + i];
                    for (int k = 0; k < vdim; k++)
                    {
                        sum += *ptr++ * *kptr++;
                    }

                    float* outptr = outm.row(i);

                    outptr[j] = sum;
                }
            }
        }

        // xqk = xq * xk^T
        // xq  (embed_dim_per_head, src_seqlen)
        // xk  (embed_dim_per_head, dst_seqlen)
        {
            const Mat xqm = xq.channel(q);
            const Mat xkm = xk.channel(q);

            Mat outm = xqk.channel(q);

            for (int i = 0; i < src_seqlen; i++)
            {
                float* outptr = outm.row(i);

                for (int j = 0; j < dst_seqlen; j++)
                {
                    const float* qptr = xqm.row(i);
                    const float* kptr = xkm.row(j);

                    float sum = 0.f;
                    for (int k = 0; k < embed_dim_per_head; k++)
                    {
                        sum += *qptr++ * *kptr++;
                    }

                    outptr[j] = sum;
                }
            }
        }
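
        // a 2D mask is broadcast across heads; a 3D mask supplies one
        // (src_seqlen x dst_seqlen) slice per head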
        // xqk = xqk + mask
        if (attn_mask)
        {
            const Mat& maskm = attn_mask_blob.dims == 3 ? attn_mask_blob.channel(q) : attn_mask_blob;
            Mat outm = xqk.channel(q);

            for (int i = 0; i < src_seqlen; i++)
            {
                const float* mptr = maskm.row(i);
                float* outptr = outm.row(i);

                for (int j = 0; j < dst_seqlen; j++)
                {
                    outptr[j] += mptr[j];
                }
            }
        }
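
        // softmax over each row of xqk, subtracting the row max first
        // so expf never overflows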
        // softmax(xqk)
        {
            Mat outm = xqk.channel(q);

            for (int i = 0; i < src_seqlen; i++)
            {
                float* ptr = outm.row(i);

                float max = -FLT_MAX;
                for (int j = 0; j < dst_seqlen; j++)
                {
                    max = std::max(max, ptr[j]);
                }

                float sum = 0.f;
                for (int j = 0; j < dst_seqlen; j++)
                {
                    ptr[j] = expf(ptr[j] - max);
                    sum += ptr[j];
                }

                for (int j = 0; j < dst_seqlen; j++)
                {
                    ptr[j] /= sum;
                }
            }
        }
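
        // the transposed xv layout lets both operands stream row-wise here;
        // results land in xqkv with heads laid out consecutively per token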
        // xqkv = xqk * xv
        // xqk (dst_seqlen, src_seqlen)
        // xv  (dst_seqlen, embed_dim_per_head)
        // out (embed_dim_per_head, num_heads, src_seqlen)
        {
            const Mat xqkm = xqk.channel(q);
            const Mat xvm = xv.channel(q);

            for (int i = 0; i < src_seqlen; i++)
            {
                float* outptr = xqkv.channel(i).row(q);

                for (int j = 0; j < embed_dim_per_head; j++)
                {
                    const float* qkptr = xqkm.row(i);
                    const float* vptr = xvm.row(j);

                    float sum = 0.f;
                    for (int k = 0; k < dst_seqlen; k++)
                    {
                        sum += *qkptr++ * *vptr++;
                    }

                    outptr[j] = sum;
                }
            }
        }
    }
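
    // xqkv.channel(i) holds all heads for token i back to back, so reading
    // embed_dim floats concatenates the heads before projecting back to qdim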
    // out = affine(xqkv)
    // xqkv  (embed_dim, src_seqlen)
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int i = 0; i < src_seqlen; i++)
    {
        float* outptr = top_blob.row(i);

        for (int j = 0; j < qdim; j++)
        {
            const float* ptr = xqkv.channel(i);
            const float* kptr = (const float*)out_weight_data + embed_dim * j;

            float sum = out_bias_data[j];
            for (int k = 0; k < embed_dim; k++)
            {
                sum += *ptr++ * *kptr++;
            }

            outptr[j] = sum;
        }
    }

    return 0;
}

} // namespace ncnn