ncnn

Форк
0
/
normalize_reduce_sum4_fp32.comp 
357 строк · 10.7 Кб
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Reduction mode, fixed when the pipeline is created (specialization
// constants). main() computes a sum only for the combinations
// (across_spatial, across_channel) = (1, 1), (1, 0) and (0, 1).
layout (constant_id = 0) const int across_spatial = 0;
layout (constant_id = 1) const int across_channel = 0;

// square_blob is the input that gets summed; sqsum_blob receives one
// partial sum per invocation. Image mode uses 3D textures addressed as
// (x, y, channel); buffer mode uses flat fp32 arrays.
#if NCNN_image_shader
layout (binding = 0) uniform highp sampler3D square_blob;
layout (binding = 1, r32f) writeonly uniform highp image3D sqsum_blob;
#else
layout (binding = 0) readonly buffer square_blob { float square_blob_data[]; };
layout (binding = 1) writeonly buffer sqsum_blob { float sqsum_blob_data[]; };
#endif

// Per-dispatch shape information pushed by the host.
layout (push_constant) uniform parameter
{
    // Extent of the input (square_blob). In buffer mode, cstep is the
    // element stride between consecutive channels (main() addresses
    // channel z at z * cstep).
    int w;
    int h;
    int c;
    int cstep;

    // Extent of the output (sqsum_blob). main() returns early for
    // invocations outside outw x outh x outc; in buffer mode output channel
    // z starts at z * outcstep.
    int outw;
    int outh;
    int outc;
    int outcstep;
} p;

// Computes one partial sum per invocation for the Normalize layer.
// The values in square_blob are summed as-is; the name suggests they are
// already squared by an earlier pass (NOTE(review): confirm on the host side).
//
// The tile summed by invocation (gx, gy, gz) depends on the mode:
//
//   across_spatial && across_channel
//       image:  2 x 2 x 2 texels starting at (2*gx, 2*gy, 2*gz)
//       buffer: elements 2*gx .. 2*gx+1 of channels 2*gz .. 2*gz+1
//   across_spatial only
//       image:  2 x 2 texels starting at (2*gx, 2*gy) in channel gz
//       buffer: elements 4*gx .. 4*gx+3 of channel gz
//   across_channel only
//       channels 4*gz .. 4*gz+3 at (gx, gy) (image) / element gx (buffer)
//
// Every tile is clipped against w / h / c, so the last tile along an axis
// may be partial. Terms are accumulated in the same order as the original
// hand-unrolled v0 + v1 + ... expressions: x fastest, then y, then channel.
//
// Buffer mode has no separate y coordinate: the spatial plane is indexed as
// one row of length w and gy is not used for addressing.
// NOTE(review): this presumably relies on the host flattening h into w
// (outh == 1) for buffer dispatches — confirm.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    if (gx >= p.outw || gy >= p.outh || gz >= p.outc)
        return;

    // Start from 0 so the stored value is defined even when no reduction
    // mode is enabled (the original left it uninitialised).
    float sqsum = 0.0;

    if (across_spatial == 1 && across_channel == 1)
    {
        int sx = gx * 2;
        int sz = gz * 2;
#if NCNN_image_shader
        int sy = gy * 2;

        for (int dz = 0; dz < 2 && sz + dz < p.c; dz++)
        {
            for (int dy = 0; dy < 2 && sy + dy < p.h; dy++)
            {
                for (int dx = 0; dx < 2 && sx + dx < p.w; dx++)
                {
                    sqsum += texelFetch(square_blob, ivec3(sx + dx, sy + dy, sz + dz), 0).r;
                }
            }
        }
#else
        for (int dz = 0; dz < 2 && sz + dz < p.c; dz++)
        {
            int row = (sz + dz) * p.cstep + sx;

            for (int dx = 0; dx < 2 && sx + dx < p.w; dx++)
            {
                sqsum += square_blob_data[row + dx];
            }
        }
#endif
    }
    else if (across_spatial == 1 && across_channel == 0)
    {
#if NCNN_image_shader
        int sx = gx * 2;
        int sy = gy * 2;

        for (int dy = 0; dy < 2 && sy + dy < p.h; dy++)
        {
            for (int dx = 0; dx < 2 && sx + dx < p.w; dx++)
            {
                sqsum += texelFetch(square_blob, ivec3(sx + dx, sy + dy, gz), 0).r;
            }
        }
#else
        int sx = gx * 4;
        int row = gz * p.cstep + sx;

        for (int dx = 0; dx < 4 && sx + dx < p.w; dx++)
        {
            sqsum += square_blob_data[row + dx];
        }
#endif
    }
    else if (across_spatial == 0 && across_channel == 1)
    {
        int sz = gz * 4;

        for (int dz = 0; dz < 4 && sz + dz < p.c; dz++)
        {
#if NCNN_image_shader
            sqsum += texelFetch(square_blob, ivec3(gx, gy, sz + dz), 0).r;
#else
            sqsum += square_blob_data[(sz + dz) * p.cstep + gx];
#endif
        }
    }

#if NCNN_image_shader
    imageStore(sqsum_blob, ivec3(gx, gy, gz), vec4(sqsum));
#else
    int gi = gz * p.outcstep + gx;

    sqsum_blob_data[gi] = sqsum;
#endif
}

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.