ncnn

Форк
0
/
normalize_reduce_sum4_fp32_pack8.comp 
359 строк · 14.0 Кб
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; };
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Reduction mode, fixed when the pipeline is specialized. main() handles
// (1, 1) = spatial and channel, (1, 0) = spatial only, (0, 1) = channel only.
layout (constant_id = 0) const int across_spatial = 0;
layout (constant_id = 1) const int across_channel = 0;

// square_blob: values summed by this shader (read only; presumably already
// squared by an earlier pass -- confirm against the host code).
// sqsum_blob: the partial sums produced by this shader (write only).
// Each element is 8 fp32 values: one mat2x4 in the buffer path, two adjacent
// rgba32f texels along x in the image path.
#if NCNN_image_shader
layout (binding = 0) uniform highp sampler3D square_blob;
layout (binding = 1, rgba32f) writeonly uniform highp image3D sqsum_blob;
#else
layout (binding = 0) readonly buffer square_blob { mat2x4 square_blob_data[]; };
layout (binding = 1) writeonly buffer sqsum_blob { mat2x4 sqsum_blob_data[]; };
#endif

// Input shape (w, h, c) and output shape (outw, outh, outc).
// cstep / outcstep are the per-channel strides, in mat2x4 elements, used to
// address square_blob_data / sqsum_blob_data in the buffer path.
// The member order defines the push-constant layout the host must supply.
layout (push_constant) uniform parameter
{
    int w;
    int h;
    int c;
    int cstep;

    int outw;
    int outh;
    int outc;
    int outcstep;
} p;

#if NCNN_image_shader
// Fetch the 8-value element at logical position (x, y, z) of square_blob.
// The image layout stores one element as two adjacent texels along x:
// texel 2*x becomes column 0 and texel 2*x+1 becomes column 1 of the result.
mat2x4 square_fetch8(int x, int y, int z)
{
    return mat2x4(texelFetch(square_blob, ivec3(x * 2, y, z), 0), texelFetch(square_blob, ivec3(x * 2 + 1, y, z), 0));
}
#endif

// Number of consecutive elements (1..4) to sum for a group of four that starts
// `remain` elements before the end of the reduced axis: 1, 2 or 3 at the tail,
// 4 everywhere else.
int reduce4_count(int remain)
{
    return (remain >= 1 && remain <= 3) ? remain : 4;
}

// One partial sum-reduction step from square_blob into sqsum_blob.
// Each invocation writes one output element at (gx, gy, gz):
//   across_spatial && across_channel:
//     image path  - sums a 2x2x2 (x, y, channel) neighbourhood;
//     buffer path - sums a 2 (x) x 2 (channel) neighbourhood.
//   across_spatial only:
//     image path  - sums a 2x2 (x, y) tile of channel gz;
//     buffer path - sums 4 consecutive elements of channel gz.
//   across_channel only:
//     sums the same position over 4 consecutive channels.
// Neighbours past the last column / row / channel are skipped, so partial
// groups at the edges sum only the elements that exist.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    if (gx >= p.outw || gy >= p.outh || gz >= p.outc)
        return;

    // Zero-initialized so the final store never writes an undefined value:
    // no branch below handles across_spatial == 0 && across_channel == 0.
    mat2x4 sqsum = mat2x4(0.0);

    if (across_spatial == 1 && across_channel == 1)
    {
#if NCNN_image_shader
        int sz = gz * 2;
        int sy = gy * 2;
        int sx = gx * 2;

        // A second neighbour exists unless this is the last column / row / channel.
        bool has_x1 = sx != p.w - 1;
        bool has_y1 = sy != p.h - 1;
        bool has_z1 = sz != p.c - 1;

        // Terms are accumulated x-fastest, then y, then channel.
        sqsum = square_fetch8(sx, sy, sz);
        if (has_x1)
            sqsum += square_fetch8(sx + 1, sy, sz);
        if (has_y1)
            sqsum += square_fetch8(sx, sy + 1, sz);
        if (has_x1 && has_y1)
            sqsum += square_fetch8(sx + 1, sy + 1, sz);

        if (has_z1)
        {
            sqsum += square_fetch8(sx, sy, sz + 1);
            if (has_x1)
                sqsum += square_fetch8(sx + 1, sy, sz + 1);
            if (has_y1)
                sqsum += square_fetch8(sx, sy + 1, sz + 1);
            if (has_x1 && has_y1)
                sqsum += square_fetch8(sx + 1, sy + 1, sz + 1);
        }
#else
        // The buffer path does not use gy: each channel is addressed as a flat
        // run of p.w elements (presumably the host folds h into w -- confirm).
        int sz = gz * 2;
        int sx = gx * 2;

        int offset0 = sz * p.cstep + sx;
        int offset1 = offset0 + p.cstep;

        bool has_x1 = sx != p.w - 1;
        bool has_z1 = sz != p.c - 1;

        sqsum = square_blob_data[offset0];
        if (has_x1)
            sqsum += square_blob_data[offset0 + 1];

        if (has_z1)
        {
            sqsum += square_blob_data[offset1];
            if (has_x1)
                sqsum += square_blob_data[offset1 + 1];
        }
#endif
    }

    if (across_spatial == 1 && across_channel == 0)
    {
#if NCNN_image_shader
        int sy = gy * 2;
        int sx = gx * 2;

        bool has_x1 = sx != p.w - 1;
        bool has_y1 = sy != p.h - 1;

        sqsum = square_fetch8(sx, sy, gz);
        if (has_x1)
            sqsum += square_fetch8(sx + 1, sy, gz);
        if (has_y1)
            sqsum += square_fetch8(sx, sy + 1, gz);
        if (has_x1 && has_y1)
            sqsum += square_fetch8(sx + 1, sy + 1, gz);
#else
        // Flat addressing within channel gz, four elements per invocation.
        int sx = gx * 4;
        int offset = gz * p.cstep + sx;
        int count = reduce4_count(p.w - sx);

        sqsum = square_blob_data[offset];
        for (int i = 1; i < count; i++)
            sqsum += square_blob_data[offset + i];
#endif
    }

    if (across_spatial == 0 && across_channel == 1)
    {
        int sz = gz * 4;
        int count = reduce4_count(p.c - sz);

#if NCNN_image_shader
        sqsum = square_fetch8(gx, gy, sz);
        for (int i = 1; i < count; i++)
            sqsum += square_fetch8(gx, gy, sz + i);
#else
        int offset = sz * p.cstep + gx;

        sqsum = square_blob_data[offset];
        for (int i = 1; i < count; i++)
            sqsum += square_blob_data[offset + i * p.cstep];
#endif
    }

#if NCNN_image_shader
    imageStore(sqsum_blob, ivec3(gx * 2, gy, gz), sqsum[0]);
    imageStore(sqsum_blob, ivec3(gx * 2 + 1, gy, gz), sqsum[1]);
#else
    int gi = gz * p.outcstep + gx;

    sqsum_blob_data[gi] = sqsum;
#endif
}

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.