ncnn
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

15#include "pixelshuffle.h"16
17namespace ncnn {18
19PixelShuffle::PixelShuffle()20{
21one_blob_only = true;22support_inplace = false;23}
24
25int PixelShuffle::load_param(const ParamDict& pd)26{
27upscale_factor = pd.get(0, 1);28mode = pd.get(1, 0);29
30return 0;31}
32
33int PixelShuffle::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const34{
35int w = bottom_blob.w;36int h = bottom_blob.h;37int channels = bottom_blob.c;38size_t elemsize = bottom_blob.elemsize;39
40int outw = w * upscale_factor;41int outh = h * upscale_factor;42int outc = channels / (upscale_factor * upscale_factor);43
44top_blob.create(outw, outh, outc, elemsize, opt.blob_allocator);45if (top_blob.empty())46return -100;47
48#pragma omp parallel for num_threads(opt.num_threads)49for (int p = 0; p < outc; p++)50{51Mat m = top_blob.channel(p);52
53for (int sh = 0; sh < upscale_factor; sh++)54{55for (int sw = 0; sw < upscale_factor; sw++)56{57int q;58if (mode == 0)59q = p * upscale_factor * upscale_factor + sh * upscale_factor + sw;60else // if (mode == 1)61q = (sh * upscale_factor + sw) * outc + p;62
63const float* sptr = bottom_blob.channel(q);64
65for (int i = 0; i < h; i++)66{67float* outptr = m.row(i * upscale_factor + sh) + sw;68for (int j = 0; j < w; j++)69{70outptr[0] = sptr[0];71
72sptr++;73outptr += upscale_factor;74}75}76}77}78}79
80return 0;81}
82
83} // namespace ncnn84