1
// Tencent is pleased to support the open source community by making ncnn available.
3
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
8
// https://opensource.org/licenses/BSD-3-Clause
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
18
#extension GL_EXT_shader_16bit_storage: require
19
// Eight half-precision values held as two f16vec4 members: abcd, then efgh.
struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; };
21
#if NCNN_fp16_arithmetic
22
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
25
// Specialization constant 0 (default 0). The shader body compares it against
// 0/1, together with across_channel, to choose which reduction branch runs.
layout (constant_id = 0) const int across_spatial = 0;
26
// Specialization constant 1 (default 0). The shader body compares it against
// 0/1, together with across_spatial, to choose which reduction branch runs.
layout (constant_id = 1) const int across_channel = 0;
29
// Input as a 3D texture at binding 0. The shader body reads it with
// texelFetch, combining the two texels at x*2 and x*2+1 into one mat2x4.
// NOTE(review): a buffer with the same name and binding is declared as well;
// presumably a preprocessor conditional not visible here selects one - confirm.
layout (binding = 0) uniform highp sampler3D square_blob;
30
// Output as a write-only rgba32f 3D image at binding 1. The shader body stores
// each mat2x4 result as two texels, column 0 at x*2 and column 1 at x*2+1.
// NOTE(review): a buffer with the same name and binding is declared as well;
// presumably a preprocessor conditional not visible here selects one - confirm.
layout (binding = 1, rgba32f) writeonly uniform highp image3D sqsum_blob;
32
// Input as a read-only storage buffer at binding 0, one mat2x4 per element.
// The shader body indexes it with offsets built from p.cstep.
// NOTE(review): a sampler3D with the same name and binding is declared as well;
// presumably a preprocessor conditional not visible here selects one - confirm.
layout (binding = 0) readonly buffer square_blob { mat2x4 square_blob_data[]; };
33
// Output as a write-only storage buffer at binding 1, one mat2x4 per element.
// The shader body writes the result at index gz * p.outcstep + gx.
// NOTE(review): an image3D with the same name and binding is declared as well;
// presumably a preprocessor conditional not visible here selects one - confirm.
layout (binding = 1) writeonly buffer sqsum_blob { mat2x4 sqsum_blob_data[]; };
36
layout (push_constant) uniform parameter
51
int gx = int(gl_GlobalInvocationID.x);
52
int gy = int(gl_GlobalInvocationID.y);
53
int gz = int(gl_GlobalInvocationID.z);
55
if (gx >= p.outw || gy >= p.outh || gz >= p.outc)
60
if (across_spatial == 1 && across_channel == 1)
73
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, sz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, sz), 0));
79
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, sz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, sz), 0));
80
mat2x4 v1 = mat2x4(texelFetch(square_blob, ivec3((sx + 1) * 2, sy, sz), 0), texelFetch(square_blob, ivec3((sx + 1) * 2 + 1, sy, sz), 0));
89
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, sz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, sz), 0));
90
mat2x4 v2 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy + 1, sz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy + 1, sz), 0));
96
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, sz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, sz), 0));
97
mat2x4 v1 = mat2x4(texelFetch(square_blob, ivec3((sx + 1) * 2, sy, sz), 0), texelFetch(square_blob, ivec3((sx + 1) * 2 + 1, sy, sz), 0));
98
mat2x4 v2 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy + 1, sz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy + 1, sz), 0));
99
mat2x4 v3 = mat2x4(texelFetch(square_blob, ivec3((sx + 1) * 2, sy + 1, sz), 0), texelFetch(square_blob, ivec3((sx + 1) * 2 + 1, sy + 1, sz), 0));
101
sqsum = v0 + v1 + v2 + v3;
111
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, sz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, sz), 0));
112
mat2x4 v4 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, sz + 1), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, sz + 1), 0));
118
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, sz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, sz), 0));
119
mat2x4 v1 = mat2x4(texelFetch(square_blob, ivec3((sx + 1) * 2, sy, sz), 0), texelFetch(square_blob, ivec3((sx + 1) * 2 + 1, sy, sz), 0));
120
mat2x4 v4 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, sz + 1), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, sz + 1), 0));
121
mat2x4 v5 = mat2x4(texelFetch(square_blob, ivec3((sx + 1) * 2, sy, sz + 1), 0), texelFetch(square_blob, ivec3((sx + 1) * 2 + 1, sy, sz + 1), 0));
123
sqsum = v0 + v1 + v4 + v5;
130
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, sz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, sz), 0));
131
mat2x4 v2 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy + 1, sz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy + 1, sz), 0));
132
mat2x4 v4 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, sz + 1), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, sz + 1), 0));
133
mat2x4 v6 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy + 1, sz + 1), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy + 1, sz + 1), 0));
135
sqsum = v0 + v2 + v4 + v6;
139
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, sz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, sz), 0));
140
mat2x4 v1 = mat2x4(texelFetch(square_blob, ivec3((sx + 1) * 2, sy, sz), 0), texelFetch(square_blob, ivec3((sx + 1) * 2 + 1, sy, sz), 0));
141
mat2x4 v2 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy + 1, sz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy + 1, sz), 0));
142
mat2x4 v3 = mat2x4(texelFetch(square_blob, ivec3((sx + 1) * 2, sy + 1, sz), 0), texelFetch(square_blob, ivec3((sx + 1) * 2 + 1, sy + 1, sz), 0));
143
mat2x4 v4 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, sz + 1), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, sz + 1), 0));
144
mat2x4 v5 = mat2x4(texelFetch(square_blob, ivec3((sx + 1) * 2, sy, sz + 1), 0), texelFetch(square_blob, ivec3((sx + 1) * 2 + 1, sy, sz + 1), 0));
145
mat2x4 v6 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy + 1, sz + 1), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy + 1, sz + 1), 0));
146
mat2x4 v7 = mat2x4(texelFetch(square_blob, ivec3((sx + 1) * 2, sy + 1, sz + 1), 0), texelFetch(square_blob, ivec3((sx + 1) * 2 + 1, sy + 1, sz + 1), 0));
148
sqsum = v0 + v1 + v2 + v3 + v4 + v5 + v6 + v7;
156
ivec2 v_offset0 = sz * p.cstep + sx + ivec2(0, 1);
157
ivec2 v_offset1 = v_offset0 + p.cstep;
163
mat2x4 v0 = square_blob_data[v_offset0.r];
169
mat2x4 v0 = square_blob_data[v_offset0.r];
170
mat2x4 v1 = square_blob_data[v_offset0.g];
179
mat2x4 v0 = square_blob_data[v_offset0.r];
180
mat2x4 v2 = square_blob_data[v_offset1.r];
186
mat2x4 v0 = square_blob_data[v_offset0.r];
187
mat2x4 v1 = square_blob_data[v_offset0.g];
188
mat2x4 v2 = square_blob_data[v_offset1.r];
189
mat2x4 v3 = square_blob_data[v_offset1.g];
191
sqsum = v0 + v1 + v2 + v3;
197
if (across_spatial == 1 && across_channel == 0)
207
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, gz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, gz), 0));
213
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, gz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, gz), 0));
214
mat2x4 v1 = mat2x4(texelFetch(square_blob, ivec3((sx + 1) * 2, sy, gz), 0), texelFetch(square_blob, ivec3((sx + 1) * 2 + 1, sy, gz), 0));
223
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, gz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, gz), 0));
224
mat2x4 v2 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy + 1, gz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy + 1, gz), 0));
230
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy, gz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy, gz), 0));
231
mat2x4 v1 = mat2x4(texelFetch(square_blob, ivec3((sx + 1) * 2, sy, gz), 0), texelFetch(square_blob, ivec3((sx + 1) * 2 + 1, sy, gz), 0));
232
mat2x4 v2 = mat2x4(texelFetch(square_blob, ivec3(sx * 2, sy + 1, gz), 0), texelFetch(square_blob, ivec3(sx * 2 + 1, sy + 1, gz), 0));
233
mat2x4 v3 = mat2x4(texelFetch(square_blob, ivec3((sx + 1) * 2, sy + 1, gz), 0), texelFetch(square_blob, ivec3((sx + 1) * 2 + 1, sy + 1, gz), 0));
235
sqsum = v0 + v1 + v2 + v3;
242
ivec4 v_offset = sz * p.cstep + sx + ivec4(0, 1, 2, 3);
246
mat2x4 v0 = square_blob_data[v_offset.r];
250
else if (sx == p.w - 2)
252
mat2x4 v0 = square_blob_data[v_offset.r];
253
mat2x4 v1 = square_blob_data[v_offset.g];
257
else if (sx == p.w - 3)
259
mat2x4 v0 = square_blob_data[v_offset.r];
260
mat2x4 v1 = square_blob_data[v_offset.g];
261
mat2x4 v2 = square_blob_data[v_offset.b];
263
sqsum = v0 + v1 + v2;
267
mat2x4 v0 = square_blob_data[v_offset.r];
268
mat2x4 v1 = square_blob_data[v_offset.g];
269
mat2x4 v2 = square_blob_data[v_offset.b];
270
mat2x4 v3 = square_blob_data[v_offset.a];
272
sqsum = v0 + v1 + v2 + v3;
277
if (across_spatial == 0 && across_channel == 1)
284
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(gx * 2, gy, sz), 0), texelFetch(square_blob, ivec3(gx * 2 + 1, gy, sz), 0));
288
else if (sz == p.c - 2)
290
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(gx * 2, gy, sz), 0), texelFetch(square_blob, ivec3(gx * 2 + 1, gy, sz), 0));
291
mat2x4 v1 = mat2x4(texelFetch(square_blob, ivec3(gx * 2, gy, sz + 1), 0), texelFetch(square_blob, ivec3(gx * 2 + 1, gy, sz + 1), 0));
295
else if (sz == p.c - 3)
297
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(gx * 2, gy, sz), 0), texelFetch(square_blob, ivec3(gx * 2 + 1, gy, sz), 0));
298
mat2x4 v1 = mat2x4(texelFetch(square_blob, ivec3(gx * 2, gy, sz + 1), 0), texelFetch(square_blob, ivec3(gx * 2 + 1, gy, sz + 1), 0));
299
mat2x4 v2 = mat2x4(texelFetch(square_blob, ivec3(gx * 2, gy, sz + 2), 0), texelFetch(square_blob, ivec3(gx * 2 + 1, gy, sz + 2), 0));
301
sqsum = v0 + v1 + v2;
305
mat2x4 v0 = mat2x4(texelFetch(square_blob, ivec3(gx * 2, gy, sz), 0), texelFetch(square_blob, ivec3(gx * 2 + 1, gy, sz), 0));
306
mat2x4 v1 = mat2x4(texelFetch(square_blob, ivec3(gx * 2, gy, sz + 1), 0), texelFetch(square_blob, ivec3(gx * 2 + 1, gy, sz + 1), 0));
307
mat2x4 v2 = mat2x4(texelFetch(square_blob, ivec3(gx * 2, gy, sz + 2), 0), texelFetch(square_blob, ivec3(gx * 2 + 1, gy, sz + 2), 0));
308
mat2x4 v3 = mat2x4(texelFetch(square_blob, ivec3(gx * 2, gy, sz + 3), 0), texelFetch(square_blob, ivec3(gx * 2 + 1, gy, sz + 3), 0));
310
sqsum = v0 + v1 + v2 + v3;
316
ivec4 v_offset = (sz + ivec4(0, 1, 2, 3)) * p.cstep + sx;
320
mat2x4 v0 = square_blob_data[v_offset.r];
324
else if (sz == p.c - 2)
326
mat2x4 v0 = square_blob_data[v_offset.r];
327
mat2x4 v1 = square_blob_data[v_offset.g];
331
else if (sz == p.c - 3)
333
mat2x4 v0 = square_blob_data[v_offset.r];
334
mat2x4 v1 = square_blob_data[v_offset.g];
335
mat2x4 v2 = square_blob_data[v_offset.b];
337
sqsum = v0 + v1 + v2;
341
mat2x4 v0 = square_blob_data[v_offset.r];
342
mat2x4 v1 = square_blob_data[v_offset.g];
343
mat2x4 v2 = square_blob_data[v_offset.b];
344
mat2x4 v3 = square_blob_data[v_offset.a];
346
sqsum = v0 + v1 + v2 + v3;
352
imageStore(sqsum_blob, ivec3(gx * 2, gy, gz), sqsum[0]);
353
imageStore(sqsum_blob, ivec3(gx * 2 + 1, gy, gz), sqsum[1]);
355
int gi = gz * p.outcstep + gx;
357
sqsum_blob_data[gi] = sqsum;