ncnn

Форк
0
/
groupnorm.cpp 
234 строки · 7.0 Кб
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

15
#include "groupnorm.h"
16

17
namespace ncnn {
18

19
// GroupNorm consumes exactly one input blob and writes its result back
// into that blob (in-place normalization).
GroupNorm::GroupNorm()
{
    support_inplace = true;
    one_blob_only = true;
}
24

25
int GroupNorm::load_param(const ParamDict& pd)
26
{
27
    group = pd.get(0, 1);
28
    channels = pd.get(1, 0);
29
    eps = pd.get(2, 0.001f);
30
    affine = pd.get(3, 1);
31

32
    return 0;
33
}
34

35
int GroupNorm::load_model(const ModelBin& mb)
36
{
37
    if (affine == 0)
38
        return 0;
39

40
    gamma_data = mb.load(channels, 1);
41
    if (gamma_data.empty())
42
        return -100;
43

44
    beta_data = mb.load(channels, 1);
45
    if (beta_data.empty())
46
        return -100;
47

48
    return 0;
49
}
50

51
// Normalize one group of channels in place.
// base points at the first value of the group's first channel; the q-th
// channel of the group starts at base + q * stride and holds `size`
// consecutive floats. gamma_ptr / beta_ptr provide one scale / shift per
// channel of the group, or are null when no affine transform is applied.
static void groupnorm_group(float* base, size_t stride, int channels_per_group, int size, const float* gamma_ptr, const float* beta_ptr, float eps)
{
    // mean over every value in the group
    float sum = 0.f;
    for (int q = 0; q < channels_per_group; q++)
    {
        const float* ptr = base + q * stride;
        for (int i = 0; i < size; i++)
        {
            sum += ptr[i];
        }
    }
    const float mean = sum / (channels_per_group * size);

    // biased variance (divide by N, not N-1)
    float sqsum = 0.f;
    for (int q = 0; q < channels_per_group; q++)
    {
        const float* ptr = base + q * stride;
        for (int i = 0; i < size; i++)
        {
            float tmp = ptr[i] - mean;
            sqsum += tmp * tmp;
        }
    }
    const float var = sqsum / (channels_per_group * size);

    // fold mean/var (and optionally gamma/beta) into per-channel a/b so the
    // inner loop is a single multiply-add: x = x * a + b
    for (int q = 0; q < channels_per_group; q++)
    {
        float a;
        float b;
        if (gamma_ptr)
        {
            a = gamma_ptr[q] / sqrtf(var + eps);
            b = -mean * a + beta_ptr[q];
        }
        else
        {
            a = 1.f / sqrtf(var + eps);
            b = -mean * a;
        }

        float* ptr = base + q * stride;
        for (int i = 0; i < size; i++)
        {
            ptr[i] = ptr[i] * a + b;
        }
    }
}

// Group normalization over bottom_top_blob, in place: each of the `group`
// groups of channels_per_group channels is normalized to zero mean / unit
// variance over all of its values, then optionally scaled and shifted per
// channel by gamma/beta. Returns 0 on success.
//
// Compared with the previous revision, the three dims branches share one
// helper, and gamma/beta sub-ranges are no longer built from empty Mats
// when affine == 0 (that performed pointer arithmetic on null data).
int GroupNorm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
    const int dims = bottom_top_blob.dims;
    const int channels_per_group = channels / group;

    if (dims == 1)
    {
        // 1D blob: every channel is a single scalar, stored contiguously
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int g = 0; g < group; g++)
        {
            float* base = bottom_top_blob.range(g * channels_per_group, channels_per_group);
            const float* gamma_ptr = affine ? (const float*)gamma_data + g * channels_per_group : 0;
            const float* beta_ptr = affine ? (const float*)beta_data + g * channels_per_group : 0;
            groupnorm_group(base, 1, channels_per_group, 1, gamma_ptr, beta_ptr, eps);
        }
    }

    if (dims == 2)
    {
        const int w = bottom_top_blob.w;

        // 2D blob: one row per channel; rows of a 2D Mat are contiguous,
        // so consecutive channels are exactly w floats apart
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int g = 0; g < group; g++)
        {
            float* base = bottom_top_blob.row(g * channels_per_group);
            const float* gamma_ptr = affine ? (const float*)gamma_data + g * channels_per_group : 0;
            const float* beta_ptr = affine ? (const float*)beta_data + g * channels_per_group : 0;
            groupnorm_group(base, w, channels_per_group, w, gamma_ptr, beta_ptr, eps);
        }
    }

    if (dims == 3 || dims == 4)
    {
        const int size = bottom_top_blob.w * bottom_top_blob.h * bottom_top_blob.d;

        // 3D/4D blob: consecutive channels are cstep elements apart
        // (cstep may exceed w*h*d because of per-channel alignment padding)
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int g = 0; g < group; g++)
        {
            float* base = bottom_top_blob.channel(g * channels_per_group);
            const float* gamma_ptr = affine ? (const float*)gamma_data + g * channels_per_group : 0;
            const float* beta_ptr = affine ? (const float*)beta_data + g * channels_per_group : 0;
            groupnorm_group(base, bottom_top_blob.cstep, channels_per_group, size, gamma_ptr, beta_ptr, eps);
        }
    }

    return 0;
}
233

234
} // namespace ncnn
235

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.