22
one_blob_only = false;
23
support_inplace = false;
// Einsum::load_param
// Reads param 0 as an int Mat whose elements are the characters of the einsum
// equation (e.g. "ij,jk->ik") and rebuilds the equation as a string. It then
// splits the string at "->" into comma-separated input tokens (lhs_tokens) and
// one output token (rhs_token). Any subscript letter outside the allowed range
// is reported through NCNN_LOGE.
// NOTE(review): this is a sparse listing. Each numeric line appears to be the
// original line number of the statement that follows it. Braces, returns and
// some statements are not shown, so the control flow between these lines is
// partly inferred. Confirm against the full source.
26
int Einsum::load_param(const ParamDict& pd)
28
Mat equation_mat = pd.get(0, Mat());
30
const int equation_len = equation_mat.w;
// Size the string so it holds one char per element of equation_mat.
34
equation.resize(equation_len);
// NOTE(review): c_str() returns a pointer to const. Writing through it after
// the cast is undefined behavior. &equation[0] (or non-const data() in C++17)
// is the well-defined way to get a writable buffer.
35
char* equation_ptr = (char*)equation.c_str();
// Narrow each int element of equation_mat into one character of the equation.
37
const int* p = equation_mat;
38
for (int i = 0; i < equation_len; i++)
40
equation_ptr[i] = p[i];
// Find the "->" that separates input subscripts from output subscripts.
53
char* arrow = strstr(equation_ptr, "->");
// Presumably reached when strstr finds no "->". The guarding condition is not
// shown here.
56
NCNN_LOGE("invalid equation %s", equation_ptr);
// lhs begins at the start of the buffer; rhs begins right after "->".
// NOTE(review): strtok(lhs, ",") below only stops at the arrow if the buffer
// is terminated at `arrow` on a line not shown here. Confirm.
63
char* lhs = equation_ptr;
64
char* rhs = arrow + 2;
// Split lhs on ',' into one token per input. strtok modifies the buffer in
// place and keeps hidden static state, so this parsing is not reentrant.
// The loop condition around the push_back is not shown.
67
char* t = strtok(lhs, ",");
70
lhs_tokens.push_back(std::string(t));
71
t = strtok(NULL, ",");
// Everything after "->" is the single output token.
75
rhs_token = std::string(rhs);
// Output subscripts must be letters 'i'..'l' (four labels).
79
for (size_t i = 0; i < rhs_token.size(); i++)
81
if (rhs_token[i] < 'i' || rhs_token[i] > 'l')
83
NCNN_LOGE("invalid rhs_token %s", rhs_token.c_str());
// Input subscripts must be letters 'i'..'x'. These are the 16 labels that the
// dim_sizes table in Einsum::forward has slots for.
88
for (size_t i = 0; i < lhs_tokens.size(); i++)
90
const std::string& lhs_token = lhs_tokens[i];
91
for (size_t j = 0; j < lhs_token.size(); j++)
93
if (lhs_token[j] < 'i' || lhs_token[j] > 'x')
95
NCNN_LOGE("invalid lhs_token %s", lhs_token.c_str());
// get_indexed_value
// Returns the element of m addressed by the subscript letters in token. Each
// letter's current loop index is looked up as indexes[letter - 'i']. Letters
// map to m's axes outermost-first:
//   1 letter: x;  2: y, x;  3: c, y, x;  4: c, z, y, x.
// indexes is only read in the lines visible here.
// NOTE(review): the per-dims branch conditions and the 1-D / 2-D return
// statements are not visible in this listing. Presumably the branch is
// selected by `dims`. Confirm.
105
static float get_indexed_value(const Mat& m, const std::string& token, std::vector<int>& indexes)
107
const int dims = m.dims;
// 1-D input: the single letter indexes the width axis.
111
int x = indexes[token[0] - 'i'];
// 2-D input: first letter selects the row, second the column.
117
int y = indexes[token[0] - 'i'];
118
int x = indexes[token[1] - 'i'];
// 3-D input: channel, row, column.
124
int c = indexes[token[0] - 'i'];
125
int y = indexes[token[1] - 'i'];
126
int x = indexes[token[2] - 'i'];
127
return m.channel(c).row(y)[x];
// 4-D input: channel, depth, row, column.
132
int c = indexes[token[0] - 'i'];
133
int z = indexes[token[1] - 'i'];
134
int y = indexes[token[2] - 'i'];
135
int x = indexes[token[3] - 'i'];
136
return m.channel(c).depth(z).row(y)[x];
// sum_dim
// Appears to recursively enumerate every index combination of dimensions
// d .. dim_sizes.size()-1. For each combination it takes the product of the
// matching element of every bottom blob (each addressed through its own
// token), and it returns the sum of those products.
// NOTE(review): several lines are not shown in this listing: the
// initialisation of v and sum, the store of the loop counter into
// indexes[d], and the returns. Presumably v starts at 1, sum starts at 0, and
// indexes[d] = i is set inside the loop. Confirm.
143
static float sum_dim(const std::vector<int>& dim_sizes, int d, const std::vector<Mat>& bottom_blobs, const std::vector<std::string>& tokens, std::vector<int>& indexes)
// Base case: every dimension has an index, so multiply one element per input.
145
if (d == (int)dim_sizes.size())
148
for (size_t b = 0; b < bottom_blobs.size(); b++)
150
v *= get_indexed_value(bottom_blobs[b], tokens[b], indexes);
// Recursive case: iterate over dimension d and accumulate the inner sums.
158
for (int i = 0; i < dim_sizes[d]; i++)
162
sum += sum_dim(dim_sizes, d + 1, bottom_blobs, tokens, indexes);
// Einsum::forward
// Evaluates the parsed equation on bottom_blobs and writes the result to
// top_blobs[0].
// - Special-cased trace path: sums the diagonal of bottom_blobs[0].
// - General path: collects the extent of every subscript letter from the
//   input shapes, then fills each output element with sum_dim over the
//   remaining letters.
// NOTE(review): sparse listing. Not shown: braces, returns, the out_dims
// branch conditions, the writes into `indexes`, and the 1-D output store.
168
int Einsum::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
// The output element size is taken from the first input.
173
size_t elemsize = bottom_blobs[0].elemsize;
// Trace path: the body builds a 1-element output holding
// sum_i bottom_blob.row(i)[i].
// NOTE(review): this condition looks inverted. load_param puts the text before
// "->" into lhs_tokens and the text after it into rhs_token. A trace written
// "ii->" therefore presumably yields lhs_tokens == {"ii"} and an empty
// rhs_token, but this test requires the opposite. With an empty rhs_token,
// out_dims below would be 0. Likely intended:
//   rhs_token.empty() && lhs_tokens.size() == 1 && lhs_tokens[0] == "ii"
// Confirm against upstream.
175
if (lhs_tokens.empty() && rhs_token == "ii")
// The trace is a scalar, so allocate a single element.
182
Mat& top_blob = top_blobs[0];
183
top_blob.create(1, elemsize, opt.blob_allocator);
184
if (top_blob.empty())
187
const Mat& bottom_blob = bottom_blobs[0];
// Sum the main diagonal. This assumes a 2-D input with w >= h, which is not
// checked here.
191
for (int i = 0; i < bottom_blob.h; i++)
193
sum += bottom_blob.row(i)[i];
// General path. dim_sizes[k] is the extent of letter 'i' + k. The 16 slots
// cover the letters 'i'..'x' that load_param accepts. Letters that never
// appear keep extent 1, so they contribute a single iteration.
202
std::vector<int> dim_sizes(16, 1);
203
int dim_sizes_count = 0;
// Walk each input's axes outermost-first and record the extent for the
// letter at the same position in that input's token.
// NOTE(review): this loop assumes lhs_tokens has an entry per bottom blob and
// each token has at least in_dims letters; neither is checked here. If a
// letter appears in several inputs with different extents, the last one
// silently wins.
205
for (size_t b = 0; b < bottom_blobs.size(); b++)
207
const std::string& lhs_token = lhs_tokens[b];
208
const Mat& bottom_blob = bottom_blobs[b];
209
const int in_dims = bottom_blob.dims;
211
for (int s = 0; s < in_dims; s++)
// Map axis position s to the blob's extent:
//   1-D (w), 2-D (h, w), 3-D (c, h, w), 4-D (c, d, h, w).
214
if (in_dims == 1) dim_size = bottom_blob.w;
215
if (in_dims == 2 && s == 0) dim_size = bottom_blob.h;
216
if (in_dims == 2 && s == 1) dim_size = bottom_blob.w;
217
if (in_dims == 3 && s == 0) dim_size = bottom_blob.c;
218
if (in_dims == 3 && s == 1) dim_size = bottom_blob.h;
219
if (in_dims == 3 && s == 2) dim_size = bottom_blob.w;
220
if (in_dims == 4 && s == 0) dim_size = bottom_blob.c;
221
if (in_dims == 4 && s == 1) dim_size = bottom_blob.d;
222
if (in_dims == 4 && s == 2) dim_size = bottom_blob.h;
223
if (in_dims == 4 && s == 3) dim_size = bottom_blob.w;
225
int dim_sizes_index = lhs_token[s] - 'i';
226
dim_sizes[dim_sizes_index] = dim_size;
// Track the highest letter seen so the table can be trimmed to it.
227
dim_sizes_count = std::max(dim_sizes_count, dim_sizes_index + 1);
231
dim_sizes.resize(dim_sizes_count);
// Number of output axes; this presumably selects one of the branches below.
233
const int out_dims = (int)rhs_token.size();
// Current index of every letter. It is passed by reference to sum_dim, which
// reads it through get_indexed_value.
235
std::vector<int> indexes(dim_sizes_count);
// NOTE(review): problems shared by every output branch below.
// - The shape comes from dim_sizes[0..out_dims-1], and sum_dim starts at
//   depth out_dims. So the output letters are assumed to be 'i', 'j', 'k',
//   'l' in that order, not the letters actually in rhs_token.
// - load_param only range-checks rhs_token, so e.g. "ij->ji" would
//   apparently not be transposed.
// - If dim_sizes_count < out_dims, dim_sizes[...] here is out of range after
//   the resize above.
// 1-D output: w = extent of 'i'.
239
Mat& top_blob = top_blobs[0];
240
top_blob.create(dim_sizes[0], elemsize, opt.blob_allocator);
241
if (top_blob.empty())
244
for (int i = 0; i < top_blob.w; i++)
248
float sum = sum_dim(dim_sizes, 1, bottom_blobs, lhs_tokens, indexes);
// 2-D output: h = extent of 'i' (rows), w = extent of 'j' (columns).
256
Mat& top_blob = top_blobs[0];
257
top_blob.create(dim_sizes[1], dim_sizes[0], elemsize, opt.blob_allocator);
258
if (top_blob.empty())
261
for (int i = 0; i < top_blob.h; i++)
265
for (int j = 0; j < top_blob.w; j++)
269
float sum = sum_dim(dim_sizes, 2, bottom_blobs, lhs_tokens, indexes);
271
top_blob.row(i)[j] = sum;
// 3-D output: c = extent of 'i', h = extent of 'j', w = extent of 'k'.
278
Mat& top_blob = top_blobs[0];
279
top_blob.create(dim_sizes[2], dim_sizes[1], dim_sizes[0], elemsize, opt.blob_allocator);
280
if (top_blob.empty())
283
for (int i = 0; i < top_blob.c; i++)
287
for (int j = 0; j < top_blob.h; j++)
291
for (int k = 0; k < top_blob.w; k++)
295
float sum = sum_dim(dim_sizes, 3, bottom_blobs, lhs_tokens, indexes);
297
top_blob.channel(i).row(j)[k] = sum;
// 4-D output: c = extent of 'i', d = extent of 'j', h = extent of 'k',
// w = extent of 'l'.
305
Mat& top_blob = top_blobs[0];
306
top_blob.create(dim_sizes[3], dim_sizes[2], dim_sizes[1], dim_sizes[0], elemsize, opt.blob_allocator);
307
if (top_blob.empty())
310
for (int i = 0; i < top_blob.c; i++)
314
for (int j = 0; j < top_blob.d; j++)
318
for (int k = 0; k < top_blob.h; k++)
322
for (int l = 0; l < top_blob.w; l++)
326
float sum = sum_dim(dim_sizes, 4, bottom_blobs, lhs_tokens, indexes);
328
top_blob.channel(i).depth(j).row(k)[l] = sum;