1
// Tencent is pleased to support the open source community by making ncnn available.
3
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
8
// https://opensource.org/licenses/BSD-3-Clause
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
18
#extension GL_EXT_shader_16bit_storage: require
20
#if NCNN_fp16_arithmetic
21
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
24
layout (constant_id = 0) const int type = 1;
25
layout (constant_id = 1) const float value = 0;
26
layout (constant_id = 2) const int per_channel_pad = 0;
28
#define shape_constant_id_offset 3
29
layout (constant_id = shape_constant_id_offset + 0) const int dims = 0;
30
layout (constant_id = shape_constant_id_offset + 1) const int w = 0;
31
layout (constant_id = shape_constant_id_offset + 2) const int h = 0;
32
layout (constant_id = shape_constant_id_offset + 3) const int c = 0;
33
layout (constant_id = shape_constant_id_offset + 4) const int cstep = 0;
35
layout (constant_id = shape_constant_id_offset + 5) const int outdims = 0;
36
layout (constant_id = shape_constant_id_offset + 6) const int outw = 0;
37
layout (constant_id = shape_constant_id_offset + 7) const int outh = 0;
38
layout (constant_id = shape_constant_id_offset + 8) const int outc = 0;
39
layout (constant_id = shape_constant_id_offset + 9) const int outcstep = 0;
42
layout (binding = 0) uniform unfp sampler3D bottom_blob;
43
layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob;
44
layout (binding = 2) uniform unfp sampler3D per_channel_pad_blob;
46
layout (binding = 0) readonly buffer bottom_blob { sfpvec4 bottom_blob_data[]; };
47
layout (binding = 1) writeonly buffer top_blob { sfpvec4 top_blob_data[]; };
48
layout (binding = 2) readonly buffer per_channel_pad_blob { sfpvec4 per_channel_pad_blob_data[]; };
51
layout (push_constant) uniform parameter
72
int gx = int(gl_GlobalInvocationID.x);
73
int gy = int(gl_GlobalInvocationID.y);
74
int gz = int(gl_GlobalInvocationID.z);
76
if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc))
83
int x = gx - p.left / 4;
87
// nvidia driver crash when using load and store pair :(
88
// copy is the workaround --- nihui
90
if (x >= 0 && x < psc(w) * 2)
92
image3d_cp4(top_blob, ivec3(gx, 0, 0), bottom_blob, ivec3(x, 0, 0));
97
image3d_st4(top_blob, ivec3(gx, 0, 0), v);
100
if (x >= 0 && x < psc(w) * 2)
102
buffer_cp4(top_blob_data, gx, bottom_blob_data, x);
106
// nvidia driver is unhappy if we do not touch the v variable here :<
108
buffer_st4(top_blob_data, gx, v);
109
// buffer_st4(top_blob_data, gx, afpvec4(value));
118
v = afpvec4(image3d_ld4(bottom_blob, ivec3(0, 0, 0)).r);
120
else if (x >= psc(w) * 2)
122
v = afpvec4(image3d_ld4(bottom_blob, ivec3(psc(w) * 2 - 1, 0, 0)).a);
126
v = image3d_ld4(bottom_blob, ivec3((x / 2) * 2 + x % 2, 0, 0));
129
image3d_st4(top_blob, ivec3(gx, 0, 0), v);
133
v = afpvec4(buffer_ld4(bottom_blob_data, 0).r);
135
else if (x >= psc(w) * 2)
137
v = afpvec4(buffer_ld4(bottom_blob_data, psc(w) * 2 - 1).a);
141
v = buffer_ld4(bottom_blob_data, (x / 2) * 2 + x % 2);
144
buffer_st4(top_blob_data, gx, v);
152
ivec2 x01 = -x + ivec2(1, 0);
153
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3((x01.x / 2) * 2 + x01.x % 2, 0, 0));
154
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3((x01.y / 2) * 2 + x01.y % 2, 0, 0));
155
v = afpvec4(v1.r, v0.a, v0.b, v0.g);
157
else if (x >= psc(w) * 2)
159
ivec2 x01 = psc(w) * 2 - x + psc(w) * 2 - 1 - ivec2(1, 0);
160
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3((x01.x / 2) * 2 + x01.x % 2, 0, 0));
161
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3((x01.y / 2) * 2 + x01.y % 2, 0, 0));
162
v = afpvec4(v1.b, v1.g, v1.r, v0.a);
166
v = image3d_ld4(bottom_blob, ivec3((x / 2) * 2 + x % 2, 0, 0));
169
image3d_st4(top_blob, ivec3(gx, 0, 0), v);
173
ivec2 x01 = -x + ivec2(1, 0);
174
afpvec4 v0 = buffer_ld4(bottom_blob_data, (x01.x / 2) * 2 + x01.x % 2);
175
afpvec4 v1 = buffer_ld4(bottom_blob_data, (x01.y / 2) * 2 + x01.y % 2);
176
v = afpvec4(v1.r, v0.a, v0.b, v0.g);
178
else if (x >= psc(w) * 2)
180
ivec2 x01 = psc(w) * 2 - x + psc(w) * 2 - 1 - ivec2(1, 0);
181
afpvec4 v0 = buffer_ld4(bottom_blob_data, (x01.x / 2) * 2 + x01.x % 2);
182
afpvec4 v1 = buffer_ld4(bottom_blob_data, (x01.y / 2) * 2 + x01.y % 2);
183
v = afpvec4(v1.b, v1.g, v1.r, v0.a);
187
v = buffer_ld4(bottom_blob_data, (x / 2) * 2 + x % 2);
190
buffer_st4(top_blob_data, gx, v);
194
else if (psc(dims) == 2)
197
int y = gy - p.top / 4;
202
if (x >= 0 && x < psc(w) && y >= 0 && y < psc(h) * 2)
204
v = image3d_ld4(bottom_blob, ivec3(x * 2 + y % 2, y / 2, 0));
211
image3d_st4(top_blob, ivec3(gx, gy, 0), v);
213
if (x >= 0 && x < psc(w) && y >= 0 && y < psc(h) * 2)
215
v = buffer_ld4(bottom_blob_data, (y / 2) * psc(w) * 2 + x * 2 + y % 2);
222
const int gi = gy * psc(outw) + gx;
224
buffer_st4(top_blob_data, gi, v);
229
x = clamp(x, 0, psc(w) * 2 - 1);
234
v = afpvec4(image3d_ld4(bottom_blob, ivec3(x * 2, 0, 0)).r);
236
else if (y >= psc(h) * 2)
238
v = afpvec4(image3d_ld4(bottom_blob, ivec3(x * 2 + 1, psc(h) * 2 - 1, 0)).a);
242
v = image3d_ld4(bottom_blob, ivec3(x * 2 + y % 2, y / 2, 0));
245
image3d_st4(top_blob, ivec3(gx, gy, 0), v);
249
v = afpvec4(buffer_ld4(bottom_blob_data, x * 2).r);
251
else if (y >= psc(h) * 2)
253
v = afpvec4(buffer_ld4(bottom_blob_data, (psc(h) * 2 - 1) * psc(w) * 2 + x * 2 + 1).a);
257
v = buffer_ld4(bottom_blob_data, (y / 2) * psc(w) * 2 + x * 2 + y % 2);
260
const int gi = gy * psc(outw) + gx;
262
buffer_st4(top_blob_data, gi, v);
268
// NOTE psc(X) get zeros on nvidia
269
// TODO only enable this workaround for some nvidia driver
270
x = (p.w * 2 - 1) - abs(x - (p.w * 2 - 1));
271
// x = (psc(w) * 2 - 1) - abs(x - (psc(w) * 2 - 1));
276
ivec2 y01 = -y + ivec2(1, 0);
277
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(x * 2 + y01.x % 2, y01.x / 2, 0));
278
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(x * 2 + y01.y % 2, y01.y / 2, 0));
279
v = afpvec4(v1.r, v0.a, v0.b, v0.g);
281
else if (y >= psc(h) * 2)
283
ivec2 y01 = psc(h) * 2 - y + psc(h) * 2 - 1 - ivec2(1, 0);
284
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(x * 2 + y01.x % 2, y01.x / 2, 0));
285
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(x * 2 + y01.y % 2, y01.y / 2, 0));
286
v = afpvec4(v1.b, v1.g, v1.r, v0.a);
290
v = image3d_ld4(bottom_blob, ivec3(x * 2 + y % 2, y / 2, 0));
293
image3d_st4(top_blob, ivec3(gx, gy, 0), v);
297
ivec2 y01 = -y + ivec2(1, 0);
298
afpvec4 v0 = buffer_ld4(bottom_blob_data, (y01.x / 2) * psc(w) * 2 + x * 2 + y01.x % 2);
299
afpvec4 v1 = buffer_ld4(bottom_blob_data, (y01.y / 2) * psc(w) * 2 + x * 2 + y01.y % 2);
300
v = afpvec4(v1.r, v0.a, v0.b, v0.g);
302
else if (y >= psc(h) * 2)
304
ivec2 y01 = psc(h) * 2 - y + psc(h) * 2 - 1 - ivec2(1, 0);
305
afpvec4 v0 = buffer_ld4(bottom_blob_data, (y01.x / 2) * psc(w) * 2 + x * 2 + y01.x % 2);
306
afpvec4 v1 = buffer_ld4(bottom_blob_data, (y01.y / 2) * psc(w) * 2 + x * 2 + y01.y % 2);
307
v = afpvec4(v1.b, v1.g, v1.r, v0.a);
311
// interior: linear index into w*2-wide rows — row stride is psc(w) * 2,
// matching every other dims==2 buffer index in this shader (was missing the * 2)
v = buffer_ld4(bottom_blob_data, (y / 2) * psc(w) * 2 + x * 2 + y % 2);
314
const int gi = gy * psc(outw) + gx;
316
buffer_st4(top_blob_data, gi, v);
320
else // if (psc(dims) == 3)
324
int z = gz - p.front / 4;
329
if (x >= 0 && x < psc(w) && y >= 0 && y < psc(h) && z >= 0 && z < psc(c) * 2)
331
v = image3d_ld4(bottom_blob, ivec3(x * 2 + z % 2, y, z / 2));
335
v = per_channel_pad == 1 ? image3d_ld4(per_channel_pad_blob, ivec3(gz, 0, 0)) : afpvec4(value);
338
image3d_st4(top_blob, ivec3(gx, gy, gz), v);
340
if (x >= 0 && x < psc(w) && y >= 0 && y < psc(h) && z >= 0 && z < psc(c) * 2)
342
v = buffer_ld4(bottom_blob_data, ((z / 2) * psc(cstep) + y * psc(w)) * 2 + x * 2 + z % 2);
346
v = per_channel_pad == 1 ? buffer_ld4(per_channel_pad_blob_data, gz) : afpvec4(value);
349
const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;
351
buffer_st4(top_blob_data, gi, v);
356
x = clamp(x, 0, psc(w) * 2 - 1);
357
y = clamp(y, 0, psc(h) * 2 - 1);
362
v = afpvec4(image3d_ld4(bottom_blob, ivec3(x * 2, y, 0)).r);
364
else if (z >= psc(c) * 2)
366
v = afpvec4(image3d_ld4(bottom_blob, ivec3(x * 2 + 1, y, psc(c) * 2 - 1)).a);
370
v = image3d_ld4(bottom_blob, ivec3(x * 2 + z % 2, y, z / 2));
373
image3d_st4(top_blob, ivec3(gx, gy, gz), v);
377
v = afpvec4(buffer_ld4(bottom_blob_data, y * psc(w) * 2 + x * 2).r);
379
else if (z >= psc(c) * 2)
381
v = afpvec4(buffer_ld4(bottom_blob_data, ((psc(c) * 2 - 1) * psc(cstep) + y * psc(w)) * 2 + x * 2 + 1).a);
385
// interior: the low bit of z selects the lane within the doubled channel,
// mirroring the image-path twin which uses z % 2 (was y % 2 — wrong axis)
v = buffer_ld4(bottom_blob_data, ((z / 2) * psc(cstep) + y * psc(w)) * 2 + x * 2 + z % 2);
388
const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;
390
buffer_st4(top_blob_data, gi, v);
397
// NOTE psc(X) get zeros on nvidia
398
// TODO only enable this workaround for some nvidia driver
399
x = (p.w * 2 - 1) - abs(x - (p.w * 2 - 1));
400
y = (p.h * 2 - 1) - abs(y - (p.h * 2 - 1));
401
// x = (psc(w) * 2 - 1) - abs(x - (psc(w) * 2 - 1));
402
// y = (psc(h) * 2 - 1) - abs(y - (psc(h) * 2 - 1));
407
ivec2 z01 = -z + ivec2(1, 0);
408
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(x * 2 + z01.x % 2, y, z01.x / 2));
409
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(x * 2 + z01.y % 2, y, z01.y / 2));
410
v = afpvec4(v1.r, v0.a, v0.b, v0.g);
412
else if (z >= psc(c) * 2)
414
ivec2 z01 = psc(c) * 2 - z + psc(c) * 2 - 1 - ivec2(1, 0);
415
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(x * 2 + z01.x % 2, y, z01.x / 2));
416
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(x * 2 + z01.y % 2, y, z01.y / 2));
417
v = afpvec4(v1.b, v1.g, v1.r, v0.a);
421
v = image3d_ld4(bottom_blob, ivec3(x * 2 + z % 2, y, z / 2));
424
image3d_st4(top_blob, ivec3(gx, gy, gz), v);
428
// reflected source channels for the low-side pad: mirror of z, matching the
// image-path twin which computes -z + ivec2(1, 0) (was -y — wrong axis)
ivec2 z01 = -z + ivec2(1, 0);
429
afpvec4 v0 = buffer_ld4(bottom_blob_data, ((z01.x / 2) * psc(cstep) + y * psc(w)) * 2 + x * 2 + z01.x % 2);
430
afpvec4 v1 = buffer_ld4(bottom_blob_data, ((z01.y / 2) * psc(cstep) + y * psc(w)) * 2 + x * 2 + z01.y % 2);
431
v = afpvec4(v1.r, v0.a, v0.b, v0.g);
433
else if (z >= psc(c) * 2)
435
ivec2 z01 = psc(c) * 2 - z + psc(c) * 2 - 1 - ivec2(1, 0);
436
afpvec4 v0 = buffer_ld4(bottom_blob_data, ((z01.x / 2) * psc(cstep) + y * psc(w)) * 2 + x * 2 + z01.x % 2);
437
afpvec4 v1 = buffer_ld4(bottom_blob_data, ((z01.y / 2) * psc(cstep) + y * psc(w)) * 2 + x * 2 + z01.y % 2);
438
v = afpvec4(v1.b, v1.g, v1.r, v0.a);
442
// interior: lane within the doubled channel comes from z, as in the image-path
// twin and the dims==3 copy branch (was y % 2 — wrong axis)
v = buffer_ld4(bottom_blob_data, ((z / 2) * psc(cstep) + y * psc(w)) * 2 + x * 2 + z % 2);
445
const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;
447
buffer_st4(top_blob_data, gi, v);