ncnn

Форк
0
/
reshape.cpp 
354 строки · 9.4 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "reshape.h"
16

17
namespace ncnn {
18

19
Reshape::Reshape()
{
    // This layer does not provide an in-place forward; it always fills a
    // separate top blob (which may still share data with the input).
    support_inplace = false;

    // exactly one input blob and one output blob
    one_blob_only = true;
}
24

25
int Reshape::load_param(const ParamDict& pd)
26
{
27
    w = pd.get(0, -233);
28
    h = pd.get(1, -233);
29
    d = pd.get(11, -233);
30
    c = pd.get(2, -233);
31
    permute = pd.get(3, 0);
32

33
    ndim = 4;
34
    if (d == -233)
35
        ndim = 3;
36
    if (c == -233)
37
        ndim = 2;
38
    if (h == -233)
39
        ndim = 1;
40
    if (w == -233)
41
        ndim = 0;
42

43
    return 0;
44
}
45

46
int Reshape::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
47
{
48
    size_t elemsize = bottom_blob.elemsize;
49
    int total = bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c;
50

51
    int dims = bottom_blob.dims;
52

53
    // resolve out shape
54

55
    int outw = w;
56
    int outh = h;
57
    int outd = d;
58
    int outc = c;
59

60
    if (ndim == 1)
61
    {
62
        if (outw == 0)
63
            outw = bottom_blob.w;
64

65
        if (outw == -1)
66
            outw = total;
67

68
        if (dims == 1 && bottom_blob.w == outw)
69
        {
70
            top_blob = bottom_blob;
71
            return 0;
72
        }
73
    }
74
    if (ndim == 2)
75
    {
76
        if (outw == 0)
77
            outw = bottom_blob.w;
78
        if (outh == 0)
79
            outh = bottom_blob.h;
80

81
        if (outw == -1)
82
            outw = total / outh;
83
        if (outh == -1)
84
            outh = total / outw;
85

86
        if (dims == 2 && bottom_blob.h == outh)
87
        {
88
            top_blob = bottom_blob;
89
            return 0;
90
        }
91
    }
92
    if (ndim == 3)
93
    {
94
        if (outw == 0)
95
            outw = bottom_blob.w;
96
        if (outh == 0)
97
            outh = bottom_blob.h;
98
        if (outc == 0)
99
            outc = bottom_blob.c;
100

101
        if (outw == -1)
102
            outw = total / outc / outh;
103
        if (outh == -1)
104
            outh = total / outc / outw;
105
        if (outc == -1)
106
            outc = total / outh / outw;
107

108
        if (dims == 3 && bottom_blob.c == outc)
109
        {
110
            top_blob = bottom_blob;
111
            top_blob.w = outw;
112
            top_blob.h = outh;
113
            return 0;
114
        }
115
    }
116
    if (ndim == 4)
117
    {
118
        if (outw == 0)
119
            outw = bottom_blob.w;
120
        if (outh == 0)
121
            outh = bottom_blob.h;
122
        if (outc == 0)
123
            outc = bottom_blob.c;
124
        if (outd == 0)
125
            outd = bottom_blob.d;
126

127
        if (outw == -1)
128
            outw = total / outc / outd / outh;
129
        if (outh == -1)
130
            outh = total / outc / outd / outw;
131
        if (outd == -1)
132
            outd = total / outc / outh / outw;
133
        if (outc == -1)
134
            outc = total / outd / outh / outw;
135

136
        if (dims == 4 && bottom_blob.c == outc)
137
        {
138
            top_blob = bottom_blob;
139
            top_blob.w = outw;
140
            top_blob.h = outh;
141
            top_blob.d = outd;
142
            return 0;
143
        }
144
    }
145

146
    bool need_permute = permute == 1;
147
    if (dims == 2 && ndim == 2 && bottom_blob.h == outh)
148
        need_permute = false;
149
    if (dims == 3 && ndim == 3 && bottom_blob.c == outc)
150
        need_permute = false;
151
    if (dims == 4 && ndim == 4 && bottom_blob.c == outc)
152
        need_permute = false;
153

154
    if (need_permute)
155
    {
156
        Mat bottom_blob_permuted = bottom_blob;
157

158
        if (dims == 2)
159
        {
160
            // hw -> wh
161
            int _w = bottom_blob.w;
162
            int _h = bottom_blob.h;
163

164
            bottom_blob_permuted.create(_h, _w, elemsize, opt.workspace_allocator);
165
            if (bottom_blob_permuted.empty())
166
                return -100;
167

168
            const float* ptr = bottom_blob;
169
            float* outptr = bottom_blob_permuted;
170

171
            for (int i = 0; i < _w; i++)
172
            {
173
                for (int j = 0; j < _h; j++)
174
                {
175
                    *outptr++ = ptr[j * _w + i];
176
                }
177
            }
178
        }
179
        if (dims == 3)
180
        {
181
            // chw -> hwc
182
            int _w = bottom_blob.w;
183
            int _h = bottom_blob.h;
184
            int channels = bottom_blob.c;
185

186
            bottom_blob_permuted.create(channels, _w, _h, elemsize, opt.workspace_allocator);
187
            if (bottom_blob_permuted.empty())
188
                return -100;
189

190
            #pragma omp parallel for num_threads(opt.num_threads)
191
            for (int q = 0; q < _h; q++)
192
            {
193
                float* outptr = bottom_blob_permuted.channel(q);
194

195
                for (int i = 0; i < _w; i++)
196
                {
197
                    for (int j = 0; j < channels; j++)
198
                    {
199
                        *outptr++ = bottom_blob.channel(j).row(q)[i];
200
                    }
201
                }
202
            }
203
        }
204

205
        if (dims == 4)
206
        {
207
            // cdhw -> dhwc
208
            int _w = bottom_blob.w;
209
            int _h = bottom_blob.h;
210
            int _d = bottom_blob.d;
211
            int channels = bottom_blob.c;
212

213
            bottom_blob_permuted.create(channels, _w, _h, _d, elemsize, opt.workspace_allocator);
214
            if (bottom_blob_permuted.empty())
215
                return -100;
216

217
            #pragma omp parallel for num_threads(opt.num_threads)
218
            for (int z = 0; z < _d; z++)
219
            {
220
                float* outptr = bottom_blob_permuted.channel(z);
221

222
                for (int q = 0; q < _h; q++)
223
                {
224
                    for (int i = 0; i < _w; i++)
225
                    {
226
                        for (int j = 0; j < channels; j++)
227
                        {
228
                            *outptr++ = bottom_blob.channel(j).depth(z).row(q)[i];
229
                        }
230
                    }
231
                }
232
            }
233
        }
234

235
        if (ndim == 1)
236
        {
237
            top_blob = bottom_blob_permuted.reshape(outw, opt.blob_allocator);
238
            if (top_blob.empty())
239
                return -100;
240

241
            return 0;
242
        }
243

244
        // permute on ndhwc/nhwc/nhc
245
        Mat top_blob_permuted;
246
        if (ndim == 2)
247
        {
248
            top_blob_permuted = bottom_blob_permuted.reshape(outh, outw, opt.workspace_allocator);
249
        }
250
        if (ndim == 3)
251
        {
252
            top_blob_permuted = bottom_blob_permuted.reshape(outc, outw, outh, opt.workspace_allocator);
253
        }
254
        if (ndim == 4)
255
        {
256
            top_blob_permuted = bottom_blob_permuted.reshape(outc, outw, outh, outd, opt.workspace_allocator);
257
        }
258
        if (top_blob_permuted.empty())
259
            return -100;
260

261
        if (ndim == 2)
262
        {
263
            // wh -> hw
264
            top_blob.create(outw, outh, elemsize, opt.blob_allocator);
265
            if (top_blob.empty())
266
                return -100;
267

268
            const float* ptr = top_blob_permuted;
269
            float* outptr = top_blob;
270

271
            for (int i = 0; i < outh; i++)
272
            {
273
                for (int j = 0; j < outw; j++)
274
                {
275
                    *outptr++ = ptr[j * outh + i];
276
                }
277
            }
278
        }
279
        if (ndim == 3)
280
        {
281
            // hwc -> chw
282
            top_blob.create(outw, outh, outc, elemsize, opt.blob_allocator);
283
            if (top_blob.empty())
284
                return -100;
285

286
            #pragma omp parallel for num_threads(opt.num_threads)
287
            for (int q = 0; q < outc; q++)
288
            {
289
                float* outptr = top_blob.channel(q);
290

291
                for (int i = 0; i < outh; i++)
292
                {
293
                    const float* ptr = top_blob_permuted.channel(i);
294

295
                    for (int j = 0; j < outw; j++)
296
                    {
297
                        *outptr++ = ptr[j * outc + q];
298
                    }
299
                }
300
            }
301
        }
302
        if (ndim == 4)
303
        {
304
            // dhwc -> cdhw
305
            top_blob.create(outw, outh, outd, outc, elemsize, opt.blob_allocator);
306
            if (top_blob.empty())
307
                return -100;
308

309
            #pragma omp parallel for num_threads(opt.num_threads)
310
            for (int q = 0; q < outc; q++)
311
            {
312
                float* outptr = top_blob.channel(q);
313

314
                for (int k = 0; k < outd; k++)
315
                {
316
                    const float* ptr = top_blob_permuted.channel(k);
317

318
                    for (int i = 0; i < outh; i++)
319
                    {
320
                        for (int j = 0; j < outw; j++)
321
                        {
322
                            *outptr++ = ptr[i * outw * outc + j * outc + q];
323
                        }
324
                    }
325
                }
326
            }
327
        }
328

329
        return 0;
330
    }
331

332
    if (ndim == 1)
333
    {
334
        top_blob = bottom_blob.reshape(outw, opt.blob_allocator);
335
    }
336
    if (ndim == 2)
337
    {
338
        top_blob = bottom_blob.reshape(outw, outh, opt.blob_allocator);
339
    }
340
    if (ndim == 3)
341
    {
342
        top_blob = bottom_blob.reshape(outw, outh, outc, opt.blob_allocator);
343
    }
344
    if (ndim == 4)
345
    {
346
        top_blob = bottom_blob.reshape(outw, outh, outd, outc, opt.blob_allocator);
347
    }
348
    if (top_blob.empty())
349
        return -100;
350

351
    return 0;
352
}
353

354
} // namespace ncnn
355

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.