// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "testutil.h"
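
// Each test below builds a MultiHeadAttention layer from random weights, feeds it
// random q/k/v inputs, and checks the result via test_layer() from testutil.h, which
// runs the layer through its available code paths and compares them against the
// reference output within the given epsilon.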

static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int embed_dim, int num_heads, int attn_mask)
{
    const int qdim = q.w;
    const int kdim = k.w;
    const int vdim = v.w;

    ncnn::ParamDict pd;
    pd.set(0, embed_dim);
    pd.set(1, num_heads);
    pd.set(2, embed_dim * qdim);
    pd.set(3, kdim);
    pd.set(4, vdim);
    pd.set(5, attn_mask);
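    // ParamDict layout used here: 0 = embed_dim, 1 = num_heads, 2 = weight data size
    // (embed_dim * qdim), 3 = kdim, 4 = vdim, 5 = whether an attention mask input follows.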

    std::vector<ncnn::Mat> weights(8);
    weights[0] = RandomMat(embed_dim * qdim);
    weights[1] = RandomMat(embed_dim);
    weights[2] = RandomMat(embed_dim * kdim);
    weights[3] = RandomMat(embed_dim);
    weights[4] = RandomMat(embed_dim * vdim);
    weights[5] = RandomMat(embed_dim);
    weights[6] = RandomMat(qdim * embed_dim);
    weights[7] = RandomMat(qdim);
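    // Weight blobs follow the layer's expected loading order, judging by their sizes:
    // q/k/v projection weight + bias pairs, then the output projection weight + bias.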

    std::vector<ncnn::Mat> as(3);
    as[0] = q;
    as[1] = k;
    as[2] = v;

    if (attn_mask)
    {
        as.push_back(RandomMat(k.h, q.h));
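        // extra input: attention mask of shape (w = k.h, h = q.h), presumably
        // added to the attention scores before softmax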
    }

    float epsilon = 0.005;

    int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
    if (ret != 0)
    {
        fprintf(stderr, "test_multiheadattention failed q=(%d %d) k=(%d %d) v=(%d %d) embed_dim=%d num_heads=%d kdim=%d vdim=%d attn_mask=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, embed_dim, num_heads, kdim, vdim, attn_mask);
    }

    return ret;
}

static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& kv, int embed_dim, int num_heads)
{
    const int qdim = q.w;
    const int kvdim = kv.w;

    ncnn::ParamDict pd;
    pd.set(0, embed_dim);
    pd.set(1, num_heads);
    pd.set(2, embed_dim * qdim);
    pd.set(3, kvdim);
    pd.set(4, kvdim);

    std::vector<ncnn::Mat> weights(8);
    weights[0] = RandomMat(embed_dim * qdim);
    weights[1] = RandomMat(embed_dim);
    weights[2] = RandomMat(embed_dim * kvdim);
    weights[3] = RandomMat(embed_dim);
    weights[4] = RandomMat(embed_dim * kvdim);
    weights[5] = RandomMat(embed_dim);
    weights[6] = RandomMat(qdim * embed_dim);
    weights[7] = RandomMat(qdim);

    std::vector<ncnn::Mat> as(2);
    as[0] = q;
    as[1] = kv;

    float epsilon = 0.005;

    int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
    if (ret != 0)
    {
        fprintf(stderr, "test_multiheadattention_samekv failed q=(%d %d) kv=(%d %d) embed_dim=%d num_heads=%d kvdim=%d\n", q.w, q.h, kv.w, kv.h, embed_dim, num_heads, kvdim);
    }

    return ret;
}

static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int embed_dim, int num_heads)
{
    const int qdim = a.w;

    ncnn::ParamDict pd;
    pd.set(0, embed_dim);
    pd.set(1, num_heads);
    pd.set(2, embed_dim * qdim);
    pd.set(3, qdim);
    pd.set(4, qdim);
    pd.set(6, 0.7f / sqrtf(embed_dim / num_heads));
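    // param 6 presumably overrides the attention scale (commonly 1/sqrt(head_dim));
    // the 0.7x factor is non-default, presumably to exercise that path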

    std::vector<ncnn::Mat> weights(8);
    weights[0] = RandomMat(embed_dim * qdim);
    weights[1] = RandomMat(embed_dim);
    weights[2] = RandomMat(embed_dim * qdim);
    weights[3] = RandomMat(embed_dim);
    weights[4] = RandomMat(embed_dim * qdim);
    weights[5] = RandomMat(embed_dim);
    weights[6] = RandomMat(qdim * embed_dim);
    weights[7] = RandomMat(qdim);

    std::vector<ncnn::Mat> as(1);
    as[0] = a;

    float epsilon = 0.005;

    int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
    if (ret != 0)
    {
        fprintf(stderr, "test_multiheadattention_sameqkv failed a=(%d %d) embed_dim=%d num_heads=%d\n", a.w, a.h, embed_dim, num_heads);
    }

    return ret;
}

static int test_multiheadattention_0()
{
    return 0
           || test_multiheadattention(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 62, 2, 0)
           || test_multiheadattention(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 26, 2, 1)
           || test_multiheadattention(RandomMat(64, 128), RandomMat(64, 128), RandomMat(64, 128), 64, 4, 0)
           || test_multiheadattention(RandomMat(48, 127), RandomMat(64, 127), RandomMat(64, 127), 64, 16, 1)
           || test_multiheadattention(RandomMat(16, 128), RandomMat(44, 128), RandomMat(55, 128), 16, 2, 0)
           || test_multiheadattention(RandomMat(12, 128), RandomMat(44, 127), RandomMat(55, 127), 16, 4, 1)
           || test_multiheadattention(RandomMat(12, 17), RandomMat(28, 127), RandomMat(32, 127), 12, 3, 0)
           || test_multiheadattention(RandomMat(12, 17), RandomMat(28, 32), RandomMat(11, 32), 12, 3, 1);
}

static int test_multiheadattention_1()
{
    return 0
           || test_multiheadattention_samekv(RandomMat(64, 128), RandomMat(64, 128), 64, 4)
           || test_multiheadattention_samekv(RandomMat(48, 127), RandomMat(64, 127), 64, 16)
           || test_multiheadattention_samekv(RandomMat(16, 128), RandomMat(44, 128), 16, 2)
           || test_multiheadattention_samekv(RandomMat(12, 128), RandomMat(22, 127), 16, 4)
           || test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(28, 127), 12, 3)
           || test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(11, 32), 12, 3);
}

static int test_multiheadattention_2()
{
    return 0
           || test_multiheadattention_sameqkv(RandomMat(64, 128), 64, 4)
           || test_multiheadattention_sameqkv(RandomMat(48, 127), 64, 8);
}

int main()
{
    SRAND(7767517);
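    // fixed seed: RandomMat inputs and weights are reproducible from run to run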

    return 0
           || test_multiheadattention_0()
           || test_multiheadattention_1()
           || test_multiheadattention_2();
}