ncnn

Форк
0
/
test_einsum.cpp 
180 строк · 4.3 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "testutil.h"
16

17
static int test_einsum(const std::vector<ncnn::Mat>& a, const std::string& equation)
18
{
19
    ncnn::Mat equation_mat(equation.size());
20
    for (size_t i = 0; i < equation.size(); i++)
21
    {
22
        ((int*)equation_mat)[i] = equation[i];
23
    }
24

25
    ncnn::ParamDict pd;
26
    pd.set(0, equation_mat);
27

28
    std::vector<ncnn::Mat> weights(0);
29

30
    int ret = test_layer("Einsum", pd, weights, a);
31
    if (ret != 0)
32
    {
33
        fprintf(stderr, "test_einsum failed a[0].dims=%d a[0]=(%d %d %d) equation=%s\n", a[0].dims, a[0].w, a[0].h, a[0].c, equation.c_str());
34
    }
35

36
    return ret;
37
}
38

39
static int test_einsum_0()
40
{
41
    std::vector<ncnn::Mat> a(1);
42
    a[0] = RandomMat(32, 32);
43

44
    return test_einsum(a, "ii");
45
}
46

47
static int test_einsum_1()
48
{
49
    std::vector<ncnn::Mat> a(1);
50
    a[0] = RandomMat(27, 32);
51

52
    return test_einsum(a, "ij->i") || test_einsum(a, "ji->i");
53
}
54

55
static int test_einsum_2()
56
{
57
    std::vector<ncnn::Mat> a(1);
58
    a[0] = RandomMat(17, 14, 32);
59

60
    return 0
61
           || test_einsum(a, "ijk->i")
62
           || test_einsum(a, "jik->i")
63
           || test_einsum(a, "jki->i")
64
           || test_einsum(a, "ikj->ij")
65
           || test_einsum(a, "kij->ij")
66
           || test_einsum(a, "ijk->ij");
67
}
68

69
static int test_einsum_3()
70
{
71
    std::vector<ncnn::Mat> a(1);
72
    a[0] = RandomMat(17, 14, 9, 32);
73

74
    return 0
75
           || test_einsum(a, "jkli->i")
76
           || test_einsum(a, "jkil->i")
77
           || test_einsum(a, "jikl->i")
78
           || test_einsum(a, "ijkl->i")
79
           || test_einsum(a, "iklj->ij")
80
           || test_einsum(a, "klij->ij")
81
           || test_einsum(a, "kijl->ij")
82
           || test_einsum(a, "ijkl->ij")
83
           || test_einsum(a, "ijlk->ijk")
84
           || test_einsum(a, "lijk->ijk")
85
           || test_einsum(a, "ijkl->ijk");
86
}
87

88
static int test_einsum_4()
89
{
90
    std::vector<ncnn::Mat> a(2);
91
    a[0] = RandomMat(12, 28);
92
    a[1] = RandomMat(12);
93

94
    return test_einsum(a, "ij,j->i");
95
}
96

97
static int test_einsum_5()
98
{
99
    std::vector<ncnn::Mat> a(2);
100
    a[0] = RandomMat(14);
101
    a[1] = RandomMat(14, 7, 16);
102

103
    return test_einsum(a, "k,ijk->ij");
104
}
105

106
static int test_einsum_6()
107
{
108
    std::vector<ncnn::Mat> a(2);
109
    a[0] = RandomMat(27);
110
    a[1] = RandomMat(32);
111

112
    return test_einsum(a, "i,j->ij");
113
}
114

115
static int test_einsum_7()
116
{
117
    std::vector<ncnn::Mat> a(4);
118
    a[0] = RandomMat(7);
119
    a[1] = RandomMat(2);
120
    a[2] = RandomMat(11);
121
    a[3] = RandomMat(16);
122

123
    return test_einsum(a, "i,j,k,l->ijkl");
124
}
125

126
static int test_einsum_8()
127
{
128
    std::vector<ncnn::Mat> a(2);
129
    a[0] = RandomMat(5, 2, 3);
130
    a[1] = RandomMat(4, 5, 3);
131

132
    return test_einsum(a, "ijl,ilk->ijk");
133
}
134

135
static int test_einsum_9()
136
{
137
    std::vector<ncnn::Mat> a(2);
138
    a[0] = RandomMat(4, 5, 3);
139
    a[1] = RandomMat(5, 2, 3);
140

141
    return test_einsum(a, "ilk,ijl->ijk");
142
}
143

144
static int test_einsum_10()
145
{
146
    std::vector<ncnn::Mat> a(3);
147
    a[0] = RandomMat(15, 12);
148
    a[1] = RandomMat(24, 15, 13);
149
    a[2] = RandomMat(24, 12);
150

151
    return test_einsum(a, "ik,jkl,il->ij");
152
}
153

154
static int test_einsum_11()
155
{
156
    std::vector<ncnn::Mat> a(2);
157
    a[0] = RandomMat(7, 5, 3, 2);
158
    a[1] = RandomMat(5, 17, 3, 11);
159

160
    return test_einsum(a, "imnj,kmln->ijkl");
161
}
162

163
int main()
164
{
165
    SRAND(7767517);
166

167
    return 0
168
           || test_einsum_0()
169
           || test_einsum_1()
170
           || test_einsum_2()
171
           || test_einsum_3()
172
           || test_einsum_4()
173
           || test_einsum_5()
174
           || test_einsum_6()
175
           || test_einsum_7()
176
           || test_einsum_8()
177
           || test_einsum_9()
178
           || test_einsum_10()
179
           || test_einsum_11();
180
}
181

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.