ncnn

Форк
0
/
diag.cpp 
91 строка · 2.4 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "diag.h"
16

17
namespace ncnn {
18

19
Diag::Diag()
20
{
21
    one_blob_only = true;
22
    support_inplace = false;
23
}
24

25
int Diag::load_param(const ParamDict& pd)
26
{
27
    diagonal = pd.get(0, 0);
28

29
    return 0;
30
}
31

32
int Diag::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
33
{
34
    int dims = bottom_blob.dims;
35
    size_t elemsize = bottom_blob.elemsize;
36

37
    if (dims == 1)
38
    {
39
        int w = bottom_blob.w;
40
        int top_w = w + ((diagonal >= 0) ? diagonal : -diagonal);
41

42
        top_blob.create(top_w, top_w, elemsize, opt.blob_allocator);
43
        if (top_blob.empty())
44
            return -100;
45

46
        top_blob.fill(0.0f);
47

48
        int bias_r = -std::min(diagonal, 0);
49
        int bias_c = std::max(diagonal, 0);
50

51
        for (int i = 0; i < w; i++)
52
        {
53
            top_blob.row(i + bias_r)[i + bias_c] = bottom_blob[i];
54
        }
55
    }
56
    if (dims == 2)
57
    {
58
        int w = bottom_blob.w;
59
        int h = bottom_blob.h;
60

61
        int len = 0;
62
        int minimum = std::min(w - h, 0);
63
        int maximum = std::max(w - h, 0);
64
        if (diagonal <= maximum && diagonal >= minimum)
65
            len = std::min(w, h);
66
        else if (diagonal > -h && diagonal < minimum)
67
            len = diagonal + h;
68
        else if (diagonal > maximum && diagonal < w)
69
            len = -diagonal + w;
70

71
        top_blob.create(len, elemsize, opt.blob_allocator);
72
        if (top_blob.empty())
73
        {
74
            if (len == 0)
75
                return 0;
76
            return -100;
77
        }
78

79
        int bias_r = -std::min(diagonal, 0);
80
        int bias_c = std::max(diagonal, 0);
81

82
        for (int i = 0; i < len; i++)
83
        {
84
            top_blob[i] = bottom_blob.row(i + bias_r)[i + bias_c];
85
        }
86
    }
87

88
    return 0;
89
}
90

91
} // namespace ncnn
92

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.