ncnn

Форк
0
/
scale.cpp 
174 строки · 4.3 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "scale.h"
16

17
namespace ncnn {
18

19
Scale::Scale()
20
{
21
    one_blob_only = true;
22
    support_inplace = true;
23
}
24

25
int Scale::load_param(const ParamDict& pd)
26
{
27
    scale_data_size = pd.get(0, 0);
28
    bias_term = pd.get(1, 0);
29

30
    if (scale_data_size == -233)
31
        one_blob_only = false;
32

33
    return 0;
34
}
35

36
int Scale::load_model(const ModelBin& mb)
37
{
38
    if (scale_data_size == -233)
39
        return 0;
40

41
    scale_data = mb.load(scale_data_size, 1);
42
    if (scale_data.empty())
43
        return -100;
44

45
    if (bias_term)
46
    {
47
        bias_data = mb.load(scale_data_size, 1);
48
        if (bias_data.empty())
49
            return -100;
50
    }
51

52
    return 0;
53
}
54

55
int Scale::forward_inplace(std::vector<Mat>& bottom_top_blobs, const Option& opt) const
56
{
57
    Mat& bottom_top_blob = bottom_top_blobs[0];
58
    const Mat& scale_blob = bottom_top_blobs[1];
59

60
    int dims = bottom_top_blob.dims;
61

62
    if (dims == 1)
63
    {
64
        int w = bottom_top_blob.w;
65

66
        float* ptr = bottom_top_blob;
67

68
        if (bias_term)
69
        {
70
            #pragma omp parallel for num_threads(opt.num_threads)
71
            for (int i = 0; i < w; i++)
72
            {
73
                ptr[i] = ptr[i] * scale_blob[i] + bias_data[i];
74
            }
75
        }
76
        else
77
        {
78
            #pragma omp parallel for num_threads(opt.num_threads)
79
            for (int i = 0; i < w; i++)
80
            {
81
                ptr[i] *= scale_blob[i];
82
            }
83
        }
84
    }
85

86
    if (dims == 2)
87
    {
88
        int w = bottom_top_blob.w;
89
        int h = bottom_top_blob.h;
90

91
        if (bias_term)
92
        {
93
            #pragma omp parallel for num_threads(opt.num_threads)
94
            for (int i = 0; i < h; i++)
95
            {
96
                float* ptr = bottom_top_blob.row(i);
97
                float s = scale_blob[i];
98
                float bias = bias_data[i];
99

100
                for (int j = 0; j < w; j++)
101
                {
102
                    ptr[j] = ptr[j] * s + bias;
103
                }
104
            }
105
        }
106
        else
107
        {
108
            #pragma omp parallel for num_threads(opt.num_threads)
109
            for (int i = 0; i < h; i++)
110
            {
111
                float* ptr = bottom_top_blob.row(i);
112
                float s = scale_blob[i];
113

114
                for (int j = 0; j < w; j++)
115
                {
116
                    ptr[j] *= s;
117
                }
118
            }
119
        }
120
    }
121

122
    if (dims == 3)
123
    {
124
        int w = bottom_top_blob.w;
125
        int h = bottom_top_blob.h;
126
        int channels = bottom_top_blob.c;
127
        int size = w * h;
128

129
        if (bias_term)
130
        {
131
            #pragma omp parallel for num_threads(opt.num_threads)
132
            for (int q = 0; q < channels; q++)
133
            {
134
                float* ptr = bottom_top_blob.channel(q);
135

136
                float s = scale_blob[q];
137
                float bias = bias_data[q];
138

139
                for (int i = 0; i < size; i++)
140
                {
141
                    ptr[i] = ptr[i] * s + bias;
142
                }
143
            }
144
        }
145
        else
146
        {
147
            #pragma omp parallel for num_threads(opt.num_threads)
148
            for (int q = 0; q < channels; q++)
149
            {
150
                float* ptr = bottom_top_blob.channel(q);
151

152
                float s = scale_blob[q];
153

154
                for (int i = 0; i < size; i++)
155
                {
156
                    ptr[i] *= s;
157
                }
158
            }
159
        }
160
    }
161

162
    return 0;
163
}
164

165
int Scale::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
166
{
167
    std::vector<Mat> bottom_top_blobs(2);
168
    bottom_top_blobs[0] = bottom_top_blob;
169
    bottom_top_blobs[1] = scale_data;
170

171
    return forward_inplace(bottom_top_blobs, opt);
172
}
173

174
} // namespace ncnn
175

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.