// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
22
support_inplace = false;
25
// Reads this layer's settings from the parameter dictionary `pd`.
//
// Each value is fetched with pd.get(id, fallback); the second argument is
// presumably the value used when the id is absent -- confirm against ParamDict.
//   id 0  -> squeeze_w (fallback 0)
//   id 1  -> squeeze_h (fallback 0)
//   id 11 -> squeeze_d (fallback 0)
//   id 2  -> squeeze_c (fallback 0)
//   id 3  -> axes      (fallback: a default-constructed Mat)
//
// NOTE(review): the braces and the return statement of this function are not
// present in this listing, so the returned value cannot be documented here.
int Squeeze::load_param(const ParamDict& pd)
27
squeeze_w = pd.get(0, 0);
28
squeeze_h = pd.get(1, 0);
29
// The depth flag uses id 11, which is not contiguous with the w/h/c ids (0, 1, 2).
squeeze_d = pd.get(11, 0);
30
squeeze_c = pd.get(2, 0);
31
// Explicit axis list; forward() reads it as an int array of length axes.w.
axes = pd.get(3, Mat());
36
int Squeeze::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
38
int w = bottom_blob.w;
39
int h = bottom_blob.h;
40
int d = bottom_blob.d;
41
int channels = bottom_blob.c;
42
int dims = bottom_blob.dims;
44
bool _squeeze_w = false;
45
bool _squeeze_h = false;
46
bool _squeeze_d = false;
47
bool _squeeze_c = false;
51
_squeeze_w = w == 1 && squeeze_w;
52
_squeeze_h = h == 1 && squeeze_h;
53
_squeeze_d = d == 1 && squeeze_d;
54
_squeeze_c = channels == 1 && squeeze_c;
58
const int* axes_ptr = axes;
59
for (int i = 0; i < axes.w; i++)
61
int axis = axes_ptr[i];
65
if (dims == 1 && axis == 0)
69
if (dims == 2 && axis == 0)
73
if (dims == 2 && axis == 1)
77
if (dims == 3 && axis == 0)
79
_squeeze_c = channels == 1;
81
if (dims == 3 && axis == 1)
85
if (dims == 3 && axis == 2)
89
if (dims == 4 && axis == 0)
91
_squeeze_c = channels == 1;
93
if (dims == 4 && axis == 1)
97
if (dims == 4 && axis == 2)
101
if (dims == 4 && axis == 3)
108
top_blob = bottom_blob;
114
top_blob = bottom_blob.reshape(1, opt.blob_allocator);
120
if (_squeeze_w && _squeeze_h)
122
top_blob = bottom_blob.reshape(1, opt.blob_allocator);
126
top_blob = bottom_blob.reshape(h, opt.blob_allocator);
130
top_blob = bottom_blob.reshape(w, opt.blob_allocator);
136
if (_squeeze_w && _squeeze_h && _squeeze_c)
138
top_blob = bottom_blob.reshape(1, opt.blob_allocator);
140
else if (_squeeze_w && _squeeze_h)
142
top_blob = bottom_blob.reshape(channels, opt.blob_allocator);
144
else if (_squeeze_h && _squeeze_c)
146
top_blob = bottom_blob.reshape(w, opt.blob_allocator);
148
else if (_squeeze_w && _squeeze_c)
150
top_blob = bottom_blob.reshape(h, opt.blob_allocator);
154
top_blob = bottom_blob.reshape(h, channels, opt.blob_allocator);
158
top_blob = bottom_blob.reshape(w, channels, opt.blob_allocator);
162
top_blob = bottom_blob.reshape(w, h, opt.blob_allocator);
168
if (_squeeze_w && _squeeze_h && _squeeze_d && _squeeze_c)
170
top_blob = bottom_blob.reshape(1, opt.blob_allocator);
172
else if (_squeeze_w && _squeeze_h && _squeeze_d)
174
top_blob = bottom_blob.reshape(channels, opt.blob_allocator);
176
else if (_squeeze_h && _squeeze_d && _squeeze_c)
178
top_blob = bottom_blob.reshape(w, opt.blob_allocator);
180
else if (_squeeze_w && _squeeze_d && _squeeze_c)
182
top_blob = bottom_blob.reshape(h, opt.blob_allocator);
184
else if (_squeeze_w && _squeeze_h && _squeeze_c)
186
top_blob = bottom_blob.reshape(d, opt.blob_allocator);
188
else if (_squeeze_w && _squeeze_h)
190
top_blob = bottom_blob.reshape(d, channels, opt.blob_allocator);
192
else if (_squeeze_w && _squeeze_d)
194
top_blob = bottom_blob.reshape(h, channels, opt.blob_allocator);
196
else if (_squeeze_h && _squeeze_d)
198
top_blob = bottom_blob.reshape(w, channels, opt.blob_allocator);
200
else if (_squeeze_h && _squeeze_c)
202
top_blob = bottom_blob.reshape(w, d, opt.blob_allocator);
204
else if (_squeeze_w && _squeeze_c)
206
top_blob = bottom_blob.reshape(h, d, opt.blob_allocator);
208
else if (_squeeze_d && _squeeze_c)
210
top_blob = bottom_blob.reshape(w, h, opt.blob_allocator);
214
top_blob = bottom_blob.reshape(h, d, channels, opt.blob_allocator);
218
top_blob = bottom_blob.reshape(w, d, channels, opt.blob_allocator);
222
top_blob = bottom_blob.reshape(w, h, channels, opt.blob_allocator);
226
top_blob = bottom_blob.reshape(w, h, d, opt.blob_allocator);
230
if (top_blob.empty())