22
support_inplace = false;
25
int Reshape::load_param(const ParamDict& pd)
31
permute = pd.get(3, 0);
46
int Reshape::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
48
size_t elemsize = bottom_blob.elemsize;
49
int total = bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c;
51
int dims = bottom_blob.dims;
68
if (dims == 1 && bottom_blob.w == outw)
70
top_blob = bottom_blob;
86
if (dims == 2 && bottom_blob.h == outh)
88
top_blob = bottom_blob;
102
outw = total / outc / outh;
104
outh = total / outc / outw;
106
outc = total / outh / outw;
108
if (dims == 3 && bottom_blob.c == outc)
110
top_blob = bottom_blob;
119
outw = bottom_blob.w;
121
outh = bottom_blob.h;
123
outc = bottom_blob.c;
125
outd = bottom_blob.d;
128
outw = total / outc / outd / outh;
130
outh = total / outc / outd / outw;
132
outd = total / outc / outh / outw;
134
outc = total / outd / outh / outw;
136
if (dims == 4 && bottom_blob.c == outc)
138
top_blob = bottom_blob;
146
bool need_permute = permute == 1;
147
if (dims == 2 && ndim == 2 && bottom_blob.h == outh)
148
need_permute = false;
149
if (dims == 3 && ndim == 3 && bottom_blob.c == outc)
150
need_permute = false;
151
if (dims == 4 && ndim == 4 && bottom_blob.c == outc)
152
need_permute = false;
156
Mat bottom_blob_permuted = bottom_blob;
161
int _w = bottom_blob.w;
162
int _h = bottom_blob.h;
164
bottom_blob_permuted.create(_h, _w, elemsize, opt.workspace_allocator);
165
if (bottom_blob_permuted.empty())
168
const float* ptr = bottom_blob;
169
float* outptr = bottom_blob_permuted;
171
for (int i = 0; i < _w; i++)
173
for (int j = 0; j < _h; j++)
175
*outptr++ = ptr[j * _w + i];
182
int _w = bottom_blob.w;
183
int _h = bottom_blob.h;
184
int channels = bottom_blob.c;
186
bottom_blob_permuted.create(channels, _w, _h, elemsize, opt.workspace_allocator);
187
if (bottom_blob_permuted.empty())
190
#pragma omp parallel for num_threads(opt.num_threads)
191
for (int q = 0; q < _h; q++)
193
float* outptr = bottom_blob_permuted.channel(q);
195
for (int i = 0; i < _w; i++)
197
for (int j = 0; j < channels; j++)
199
*outptr++ = bottom_blob.channel(j).row(q)[i];
208
int _w = bottom_blob.w;
209
int _h = bottom_blob.h;
210
int _d = bottom_blob.d;
211
int channels = bottom_blob.c;
213
bottom_blob_permuted.create(channels, _w, _h, _d, elemsize, opt.workspace_allocator);
214
if (bottom_blob_permuted.empty())
217
#pragma omp parallel for num_threads(opt.num_threads)
218
for (int z = 0; z < _d; z++)
220
float* outptr = bottom_blob_permuted.channel(z);
222
for (int q = 0; q < _h; q++)
224
for (int i = 0; i < _w; i++)
226
for (int j = 0; j < channels; j++)
228
*outptr++ = bottom_blob.channel(j).depth(z).row(q)[i];
237
top_blob = bottom_blob_permuted.reshape(outw, opt.blob_allocator);
238
if (top_blob.empty())
245
Mat top_blob_permuted;
248
top_blob_permuted = bottom_blob_permuted.reshape(outh, outw, opt.workspace_allocator);
252
top_blob_permuted = bottom_blob_permuted.reshape(outc, outw, outh, opt.workspace_allocator);
256
top_blob_permuted = bottom_blob_permuted.reshape(outc, outw, outh, outd, opt.workspace_allocator);
258
if (top_blob_permuted.empty())
264
top_blob.create(outw, outh, elemsize, opt.blob_allocator);
265
if (top_blob.empty())
268
const float* ptr = top_blob_permuted;
269
float* outptr = top_blob;
271
for (int i = 0; i < outh; i++)
273
for (int j = 0; j < outw; j++)
275
*outptr++ = ptr[j * outh + i];
282
top_blob.create(outw, outh, outc, elemsize, opt.blob_allocator);
283
if (top_blob.empty())
286
#pragma omp parallel for num_threads(opt.num_threads)
287
for (int q = 0; q < outc; q++)
289
float* outptr = top_blob.channel(q);
291
for (int i = 0; i < outh; i++)
293
const float* ptr = top_blob_permuted.channel(i);
295
for (int j = 0; j < outw; j++)
297
*outptr++ = ptr[j * outc + q];
305
top_blob.create(outw, outh, outd, outc, elemsize, opt.blob_allocator);
306
if (top_blob.empty())
309
#pragma omp parallel for num_threads(opt.num_threads)
310
for (int q = 0; q < outc; q++)
312
float* outptr = top_blob.channel(q);
314
for (int k = 0; k < outd; k++)
316
const float* ptr = top_blob_permuted.channel(k);
318
for (int i = 0; i < outh; i++)
320
for (int j = 0; j < outw; j++)
322
*outptr++ = ptr[i * outw * outc + j * outc + q];
334
top_blob = bottom_blob.reshape(outw, opt.blob_allocator);
338
top_blob = bottom_blob.reshape(outw, outh, opt.blob_allocator);
342
top_blob = bottom_blob.reshape(outw, outh, outc, opt.blob_allocator);
346
top_blob = bottom_blob.reshape(outw, outh, outd, outc, opt.blob_allocator);
348
if (top_blob.empty())