1
// Tencent is pleased to support the open source community by making ncnn available.
3
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
8
// https://opensource.org/licenses/BSD-3-Clause
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
17
// Test helper: runs the "MultiHeadAttention" layer through test_layer() with
// separate q, k and v input blobs.
//
// q, k, v      input blobs
// embed_dim    used for param id 2 and for sizing every weight blob below
// num_heads    only reported in the failure message in the lines visible here
// attn_mask    only reported in the failure message in the lines visible here
//
// NOTE(review): this excerpt is gapped. The bare integers between statements
// look like line-number residue, and the braces, the qdim/kdim/vdim
// declarations, any guards and the return statement are not visible here.
static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int embed_dim, int num_heads, int attn_mask)
26
// Param id 2 gets the same element count that weights[0] is created with.
// qdim is declared on a line not shown here (presumably q.w; TODO confirm).
pd.set(2, embed_dim * qdim);
31
// Eight random weight blobs:
//   [0], [2], [4]  embed_dim * qdim / kdim / vdim elements
//   [1], [3], [5]  embed_dim elements each
//   [6]            qdim * embed_dim elements
//   [7]            qdim elements
// Presumably the q/k/v projection weights each followed by its bias, then the
// output projection weight and bias; verify against the layer's weight order.
std::vector<ncnn::Mat> weights(8);
32
weights[0] = RandomMat(embed_dim * qdim);
33
weights[1] = RandomMat(embed_dim);
34
weights[2] = RandomMat(embed_dim * kdim);
35
weights[3] = RandomMat(embed_dim);
36
weights[4] = RandomMat(embed_dim * vdim);
37
weights[5] = RandomMat(embed_dim);
38
weights[6] = RandomMat(qdim * embed_dim);
39
weights[7] = RandomMat(qdim);
41
// Input blob list, pre-sized to three slots; the statements that fill the
// slots are not visible in this excerpt.
std::vector<ncnn::Mat> as(3);
48
// Appends a fourth input created as RandomMat(k.h, q.h).
// NOTE(review): presumably the attention mask, added only when attn_mask is
// non-zero; the guarding condition is not visible here, so confirm.
as.push_back(RandomMat(k.h, q.h));
51
// Tolerance value handed to test_layer() as its last argument.
float epsilon = 0.005;
53
int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
56
// Diagnostic written to stderr with the input dimensions and parameters
// (presumably only when ret is non-zero; the condition is not visible here).
fprintf(stderr, "test_multiheadattention failed q=(%d %d) k=(%d %d) v=(%d %d) embed_dim=%d num_heads=%d kdim=%d vdim=%d attn_mask=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, embed_dim, num_heads, kdim, vdim, attn_mask);
62
// Test helper: runs the "MultiHeadAttention" layer through test_layer() with
// two input blobs (the input list is sized to two). Judging by the name and
// the single kv parameter, key and value share one blob.
//
// q            query blob
// kv           blob whose width (kv.w) sizes both weights[2] and weights[4]
// embed_dim    used for param id 2 and for sizing every weight blob below
// num_heads    only reported in the failure message in the lines visible here
//
// NOTE(review): this excerpt is gapped. The bare integers between statements
// look like line-number residue, and the braces, the qdim declaration, most
// pd.set() calls, the input assignment and the return are not visible here.
static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& kv, int embed_dim, int num_heads)
65
// Width of the shared kv blob.
const int kvdim = kv.w;
70
// Param id 2 gets the same element count that weights[0] is created with.
// qdim is declared on a line not shown here (presumably q.w; TODO confirm).
pd.set(2, embed_dim * qdim);
74
// Eight random weight blobs. Same layout as the separate-q/k/v helper, except
// that [2] and [4] both use kvdim.
std::vector<ncnn::Mat> weights(8);
75
weights[0] = RandomMat(embed_dim * qdim);
76
weights[1] = RandomMat(embed_dim);
77
weights[2] = RandomMat(embed_dim * kvdim);
78
weights[3] = RandomMat(embed_dim);
79
weights[4] = RandomMat(embed_dim * kvdim);
80
weights[5] = RandomMat(embed_dim);
81
weights[6] = RandomMat(qdim * embed_dim);
82
weights[7] = RandomMat(qdim);
84
// Input blob list, pre-sized to two slots; the statements that fill the slots
// are not visible in this excerpt.
std::vector<ncnn::Mat> as(2);
88
// Tolerance value handed to test_layer() as its last argument.
float epsilon = 0.005;
90
int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
93
// Diagnostic written to stderr with the input dimensions and parameters
// (presumably only when ret is non-zero; the condition is not visible here).
fprintf(stderr, "test_multiheadattention_samekv failed q=(%d %d) kv=(%d %d) embed_dim=%d num_heads=%d kvdim=%d\n", q.w, q.h, kv.w, kv.h, embed_dim, num_heads, kvdim);
99
// Test helper: runs the "MultiHeadAttention" layer through test_layer() with a
// single input blob (the input list is sized to one). Judging by the name,
// that one blob serves as query, key and value.
//
// a            the input blob; its width (a.w) sizes the weight blobs
// embed_dim    param id 0; also sizes every weight blob below
// num_heads    param id 1
//
// NOTE(review): this excerpt is gapped. The bare integers between statements
// look like line-number residue, and the braces, the pd declaration, some
// pd.set() calls, the input assignment and the return are not visible here.
static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int embed_dim, int num_heads)
101
// Width of the single input; used wherever the other helpers use qdim, kdim
// or vdim.
const int qdim = a.w;
104
pd.set(0, embed_dim);
105
pd.set(1, num_heads);
106
// Same element count that weights[0] is created with.
pd.set(2, embed_dim * qdim);
109
// embed_dim / num_heads is an int / int division, evaluated before sqrtf().
// It is exact for the call sites visible in this file (64/4 and 64/8).
// NOTE(review): param id 6 is presumably the attention scale, here set to
// 0.7 / sqrt(per-head dimension); confirm against the layer's parameter docs.
pd.set(6, 0.7f / sqrtf(embed_dim / num_heads));
111
// Eight random weight blobs. Same layout as the other helpers; [0], [2] and
// [4] all use qdim because there is only one input width.
std::vector<ncnn::Mat> weights(8);
112
weights[0] = RandomMat(embed_dim * qdim);
113
weights[1] = RandomMat(embed_dim);
114
weights[2] = RandomMat(embed_dim * qdim);
115
weights[3] = RandomMat(embed_dim);
116
weights[4] = RandomMat(embed_dim * qdim);
117
weights[5] = RandomMat(embed_dim);
118
weights[6] = RandomMat(qdim * embed_dim);
119
weights[7] = RandomMat(qdim);
121
// Input blob list with a single slot; the statement that fills it is not
// visible in this excerpt.
std::vector<ncnn::Mat> as(1);
124
// Tolerance value handed to test_layer() as its last argument.
float epsilon = 0.005;
126
int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
129
// Diagnostic written to stderr with the input dimensions and parameters
// (presumably only when ret is non-zero; the condition is not visible here).
fprintf(stderr, "test_multiheadattention_sameqkv failed a=(%d %d) embed_dim=%d num_heads=%d\n", a.w, a.h, embed_dim, num_heads);
135
// Case list for test_multiheadattention() (separate q, k and v blobs).
// Each call passes three RandomMat inputs followed by embed_dim, num_heads and
// attn_mask. In the visible cases attn_mask alternates 0, 1, 0, 1, ..., and
// the k and v blobs always share their second RandomMat argument.
// The calls are joined with ||, so evaluation stops at the first call that
// returns non-zero.
// NOTE(review): the start of the expression (before the first ||) and the
// braces are not visible in this excerpt.
static int test_multiheadattention_0()
138
|| test_multiheadattention(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 62, 2, 0)
139
|| test_multiheadattention(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 26, 2, 1)
140
|| test_multiheadattention(RandomMat(64, 128), RandomMat(64, 128), RandomMat(64, 128), 64, 4, 0)
141
|| test_multiheadattention(RandomMat(48, 127), RandomMat(64, 127), RandomMat(64, 127), 64, 16, 1)
142
|| test_multiheadattention(RandomMat(16, 128), RandomMat(44, 128), RandomMat(55, 128), 16, 2, 0)
143
|| test_multiheadattention(RandomMat(12, 128), RandomMat(44, 127), RandomMat(55, 127), 16, 4, 1)
144
|| test_multiheadattention(RandomMat(12, 17), RandomMat(28, 127), RandomMat(32, 127), 12, 3, 0)
145
|| test_multiheadattention(RandomMat(12, 17), RandomMat(28, 32), RandomMat(11, 32), 12, 3, 1);
148
// Case list for test_multiheadattention_samekv() (a query blob plus one shared
// kv blob). Each call passes two RandomMat inputs followed by embed_dim and
// num_heads.
// The calls are joined with ||, so evaluation stops at the first call that
// returns non-zero.
// NOTE(review): the start of the expression (before the first ||) and the
// braces are not visible in this excerpt.
static int test_multiheadattention_1()
151
|| test_multiheadattention_samekv(RandomMat(64, 128), RandomMat(64, 128), 64, 4)
152
|| test_multiheadattention_samekv(RandomMat(48, 127), RandomMat(64, 127), 64, 16)
153
|| test_multiheadattention_samekv(RandomMat(16, 128), RandomMat(44, 128), 16, 2)
154
|| test_multiheadattention_samekv(RandomMat(12, 128), RandomMat(22, 127), 16, 4)
155
|| test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(28, 127), 12, 3)
156
|| test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(11, 32), 12, 3);
159
// Case list for test_multiheadattention_sameqkv() (one blob used for all
// inputs). Each call passes one RandomMat input followed by embed_dim and
// num_heads; in both visible cases embed_dim is a multiple of num_heads.
// The calls are joined with ||, so evaluation stops at the first call that
// returns non-zero.
// NOTE(review): the start of the expression (before the first ||) and the
// braces are not visible in this excerpt.
static int test_multiheadattention_2()
162
|| test_multiheadattention_sameqkv(RandomMat(64, 128), 64, 4)
163
|| test_multiheadattention_sameqkv(RandomMat(48, 127), 64, 8);
171
|| test_multiheadattention_0()
172
|| test_multiheadattention_1()
173
|| test_multiheadattention_2();