ncnn
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
14
15#include "diag.h"16
17namespace ncnn {18
19Diag::Diag()20{
21one_blob_only = true;22support_inplace = false;23}
24
25int Diag::load_param(const ParamDict& pd)26{
27diagonal = pd.get(0, 0);28
29return 0;30}
31
32int Diag::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const33{
34int dims = bottom_blob.dims;35size_t elemsize = bottom_blob.elemsize;36
37if (dims == 1)38{39int w = bottom_blob.w;40int top_w = w + ((diagonal >= 0) ? diagonal : -diagonal);41
42top_blob.create(top_w, top_w, elemsize, opt.blob_allocator);43if (top_blob.empty())44return -100;45
46top_blob.fill(0.0f);47
48int bias_r = -std::min(diagonal, 0);49int bias_c = std::max(diagonal, 0);50
51for (int i = 0; i < w; i++)52{53top_blob.row(i + bias_r)[i + bias_c] = bottom_blob[i];54}55}56if (dims == 2)57{58int w = bottom_blob.w;59int h = bottom_blob.h;60
61int len = 0;62int minimum = std::min(w - h, 0);63int maximum = std::max(w - h, 0);64if (diagonal <= maximum && diagonal >= minimum)65len = std::min(w, h);66else if (diagonal > -h && diagonal < minimum)67len = diagonal + h;68else if (diagonal > maximum && diagonal < w)69len = -diagonal + w;70
71top_blob.create(len, elemsize, opt.blob_allocator);72if (top_blob.empty())73{74if (len == 0)75return 0;76return -100;77}78
79int bias_r = -std::min(diagonal, 0);80int bias_c = std::max(diagonal, 0);81
82for (int i = 0; i < len; i++)83{84top_blob[i] = bottom_blob.row(i + bias_r)[i + bias_c];85}86}87
88return 0;89}
90
91} // namespace ncnn92