ncnn

Форк
0
/
einsum.cpp 
338 строк · 8.7 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "einsum.h"
16
#include <string.h>
17

18
namespace ncnn {
19

20
Einsum::Einsum()
{
    // Einsum consumes one input blob per lhs token, so it is not a
    // single-input layer, and the output shape generally differs from
    // any input shape, so in-place computation is not supported.
    one_blob_only = false;
    support_inplace = false;
}
25

26
int Einsum::load_param(const ParamDict& pd)
27
{
28
    Mat equation_mat = pd.get(0, Mat());
29

30
    const int equation_len = equation_mat.w;
31

32
    // restore to lexical equation string
33
    std::string equation;
34
    equation.resize(equation_len);
35
    char* equation_ptr = (char*)equation.c_str();
36
    {
37
        const int* p = equation_mat;
38
        for (int i = 0; i < equation_len; i++)
39
        {
40
            equation_ptr[i] = p[i];
41
        }
42
    }
43

44
    if (equation == "ii")
45
    {
46
        // trace
47
        rhs_token = "ii";
48

49
        return 0;
50
    }
51

52
    // split into tokens
53
    char* arrow = strstr(equation_ptr, "->");
54
    if (!arrow)
55
    {
56
        NCNN_LOGE("invalid equation %s", equation_ptr);
57
        return -1;
58
    }
59

60
    arrow[0] = '\0';
61
    arrow[1] = '\0';
62

63
    char* lhs = equation_ptr;
64
    char* rhs = arrow + 2;
65

66
    {
67
        char* t = strtok(lhs, ",");
68
        while (t)
69
        {
70
            lhs_tokens.push_back(std::string(t));
71
            t = strtok(NULL, ",");
72
        }
73
    }
74

75
    rhs_token = std::string(rhs);
76

77
    // check token always in ijkl
78
    {
79
        for (size_t i = 0; i < rhs_token.size(); i++)
80
        {
81
            if (rhs_token[i] < 'i' || rhs_token[i] > 'l')
82
            {
83
                NCNN_LOGE("invalid rhs_token %s", rhs_token.c_str());
84
                return -1;
85
            }
86
        }
87

88
        for (size_t i = 0; i < lhs_tokens.size(); i++)
89
        {
90
            const std::string& lhs_token = lhs_tokens[i];
91
            for (size_t j = 0; j < lhs_token.size(); j++)
92
            {
93
                if (lhs_token[j] < 'i' || lhs_token[j] > 'x')
94
                {
95
                    NCNN_LOGE("invalid lhs_token %s", lhs_token.c_str());
96
                    return -1;
97
                }
98
            }
99
        }
100
    }
101

102
    return 0;
103
}
104

105
static float get_indexed_value(const Mat& m, const std::string& token, std::vector<int>& indexes)
106
{
107
    const int dims = m.dims;
108

109
    if (dims == 1)
110
    {
111
        int x = indexes[token[0] - 'i'];
112
        return m[x];
113
    }
114

115
    if (dims == 2)
116
    {
117
        int y = indexes[token[0] - 'i'];
118
        int x = indexes[token[1] - 'i'];
119
        return m.row(y)[x];
120
    }
121

122
    if (dims == 3)
123
    {
124
        int c = indexes[token[0] - 'i'];
125
        int y = indexes[token[1] - 'i'];
126
        int x = indexes[token[2] - 'i'];
127
        return m.channel(c).row(y)[x];
128
    }
129

130
    if (dims == 4)
131
    {
132
        int c = indexes[token[0] - 'i'];
133
        int z = indexes[token[1] - 'i'];
134
        int y = indexes[token[2] - 'i'];
135
        int x = indexes[token[3] - 'i'];
136
        return m.channel(c).depth(z).row(y)[x];
137
    }
138

139
    // should never reach here
140
    return 0;
141
}
142

143
static float sum_dim(const std::vector<int>& dim_sizes, int d, const std::vector<Mat>& bottom_blobs, const std::vector<std::string>& tokens, std::vector<int>& indexes)
144
{
145
    if (d == (int)dim_sizes.size())
146
    {
147
        float v = 1.f;
148
        for (size_t b = 0; b < bottom_blobs.size(); b++)
149
        {
150
            v *= get_indexed_value(bottom_blobs[b], tokens[b], indexes);
151
        }
152

153
        return v;
154
    }
155

156
    float sum = 0.f;
157

158
    for (int i = 0; i < dim_sizes[d]; i++)
159
    {
160
        indexes[d] = i;
161

162
        sum += sum_dim(dim_sizes, d + 1, bottom_blobs, tokens, indexes);
163
    }
164

165
    return sum;
166
}
167

168
// Evaluate the einsum equation loaded by load_param.
// Strategy: one nested loop per output label (rhs_token), and for each
// output coordinate sum over the remaining lhs-only labels via sum_dim.
int Einsum::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    // assert bottom_blobs.size() == lhs_tokens.size()
    // assert top_blobs.size() == 1

    size_t elemsize = bottom_blobs[0].elemsize;

    if (lhs_tokens.empty() && rhs_token == "ii")
    {
        // assert bottom_blobs.size() == 1
        // assert bottom_blob.dims == 2
        // assert bottom_blob.w == bottom_blob.h

        // trace: sum of the main diagonal of a square matrix,
        // produced as a 1-element blob
        Mat& top_blob = top_blobs[0];
        top_blob.create(1, elemsize, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        const Mat& bottom_blob = bottom_blobs[0];

        float sum = 0.f;

        for (int i = 0; i < bottom_blob.h; i++)
        {
            sum += bottom_blob.row(i)[i];
        }

        top_blob[0] = sum;

        return 0;
    }

    // resolve dimension sizes
    // dim_sizes[label - 'i'] = extent of that label, default 1
    std::vector<int> dim_sizes(16, 1); // map ijklmnopqrstuvwx -> dim_size
    int dim_sizes_count = 0;           // 1 + highest label index seen

    for (size_t b = 0; b < bottom_blobs.size(); b++)
    {
        const std::string& lhs_token = lhs_tokens[b];
        const Mat& bottom_blob = bottom_blobs[b];
        const int in_dims = bottom_blob.dims;

        for (int s = 0; s < in_dims; s++)
        {
            // pick the s-th extent of this blob, ordered outermost first
            // (c, d, h, w), to pair with the s-th label of its token
            int dim_size = 1;
            if (in_dims == 1) dim_size = bottom_blob.w;
            if (in_dims == 2 && s == 0) dim_size = bottom_blob.h;
            if (in_dims == 2 && s == 1) dim_size = bottom_blob.w;
            if (in_dims == 3 && s == 0) dim_size = bottom_blob.c;
            if (in_dims == 3 && s == 1) dim_size = bottom_blob.h;
            if (in_dims == 3 && s == 2) dim_size = bottom_blob.w;
            if (in_dims == 4 && s == 0) dim_size = bottom_blob.c;
            if (in_dims == 4 && s == 1) dim_size = bottom_blob.d;
            if (in_dims == 4 && s == 2) dim_size = bottom_blob.h;
            if (in_dims == 4 && s == 3) dim_size = bottom_blob.w;

            // NOTE(review): a label shared by several blobs is simply
            // overwritten here — mismatched extents are not detected
            int dim_sizes_index = lhs_token[s] - 'i';
            dim_sizes[dim_sizes_index] = dim_size;
            dim_sizes_count = std::max(dim_sizes_count, dim_sizes_index + 1);
        }
    }

    dim_sizes.resize(dim_sizes_count);

    const int out_dims = (int)rhs_token.size();

    // scratch coordinates, one slot per label; outer loops below fix the
    // output labels, sum_dim fills the rest
    std::vector<int> indexes(dim_sizes_count);

    // NOTE(review): out_dims == 0 (scalar output, empty rhs) falls through
    // all branches and creates no top_blob — presumably never produced by
    // the model converters; confirm before relying on it

    if (out_dims == 1)
    {
        Mat& top_blob = top_blobs[0];
        top_blob.create(dim_sizes[0], elemsize, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        // rhs labels map to output dims outermost-first, so label 'i'
        // (dim_sizes[0]) is the fastest/only axis here
        for (int i = 0; i < top_blob.w; i++)
        {
            indexes[0] = i;

            // sum over labels 1.. (those not in the output)
            float sum = sum_dim(dim_sizes, 1, bottom_blobs, lhs_tokens, indexes);

            top_blob[i] = sum;
        }
    }

    if (out_dims == 2)
    {
        Mat& top_blob = top_blobs[0];
        // Mat::create takes (w, h): label 'j' -> w, label 'i' -> h
        top_blob.create(dim_sizes[1], dim_sizes[0], elemsize, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        for (int i = 0; i < top_blob.h; i++)
        {
            indexes[0] = i;

            for (int j = 0; j < top_blob.w; j++)
            {
                indexes[1] = j;

                float sum = sum_dim(dim_sizes, 2, bottom_blobs, lhs_tokens, indexes);

                top_blob.row(i)[j] = sum;
            }
        }
    }

    if (out_dims == 3)
    {
        Mat& top_blob = top_blobs[0];
        // (w, h, c) = labels ('k', 'j', 'i')
        top_blob.create(dim_sizes[2], dim_sizes[1], dim_sizes[0], elemsize, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        for (int i = 0; i < top_blob.c; i++)
        {
            indexes[0] = i;

            for (int j = 0; j < top_blob.h; j++)
            {
                indexes[1] = j;

                for (int k = 0; k < top_blob.w; k++)
                {
                    indexes[2] = k;

                    float sum = sum_dim(dim_sizes, 3, bottom_blobs, lhs_tokens, indexes);

                    top_blob.channel(i).row(j)[k] = sum;
                }
            }
        }
    }

    if (out_dims == 4)
    {
        Mat& top_blob = top_blobs[0];
        // (w, h, d, c) = labels ('l', 'k', 'j', 'i')
        top_blob.create(dim_sizes[3], dim_sizes[2], dim_sizes[1], dim_sizes[0], elemsize, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        for (int i = 0; i < top_blob.c; i++)
        {
            indexes[0] = i;

            for (int j = 0; j < top_blob.d; j++)
            {
                indexes[1] = j;

                for (int k = 0; k < top_blob.h; k++)
                {
                    indexes[2] = k;

                    for (int l = 0; l < top_blob.w; l++)
                    {
                        indexes[3] = l;

                        float sum = sum_dim(dim_sizes, 4, bottom_blobs, lhs_tokens, indexes);

                        top_blob.channel(i).depth(j).row(k)[l] = sum;
                    }
                }
            }
        }
    }

    return 0;
}
337

338
} // namespace ncnn
339

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.