1
// Tencent is pleased to support the open source community by making ncnn available.
3
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
8
// https://opensource.org/licenses/BSD-3-Clause
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
18
// Required so that 16-bit types such as f16vec4 can be used in the struct below.
#extension GL_EXT_shader_16bit_storage: require
19
// Eight half-precision values stored as two f16vec4 halves (abcd, efgh).
// This is the element type of top_blob_data in the storage-buffer binding.
struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; };
21
// The explicit float16 arithmetic extension is requested only when
// NCNN_fp16_arithmetic is set.
// NOTE(review): the matching #endif is not visible in this excerpt.
#if NCNN_fp16_arithmetic
22
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
25
// Specialization constants that configure the padding operation.
// NOTE(review): the statements that branch on `type` are not visible in this
// excerpt, so the meaning of each value must be confirmed against the full file.
layout (constant_id = 0) const int type = 1;
26
// Scalar fill value; it is broadcast with afpvec4(value) into output lanes.
layout (constant_id = 1) const float value = 0;
27
// When equal to 1, the fill is loaded from per_channel_pad_blob instead of
// using `value` (see the `per_channel_pad == 1 ? ... : afpvec4(value)` selects).
layout (constant_id = 2) const int per_channel_pad = 0;
29
// First constant_id used by the shape constants; ids 0..2 are taken by
// type, value and per_channel_pad.
#define shape_constant_id_offset 3
30
// Shape of the blob that is read (bottom_blob). All default to 0.
// `dims` is compared against 2 to pick a code path; w, h and c bound the
// coordinates that may be loaded from bottom_blob.
layout (constant_id = shape_constant_id_offset + 0) const int dims = 0;
31
layout (constant_id = shape_constant_id_offset + 1) const int w = 0;
32
layout (constant_id = shape_constant_id_offset + 2) const int h = 0;
33
layout (constant_id = shape_constant_id_offset + 3) const int c = 0;
34
// Multiplied by the channel index when forming bottom_blob_data offsets.
layout (constant_id = shape_constant_id_offset + 4) const int cstep = 0;
36
// Shape of the blob that is written (top_blob). All default to 0.
// outw, outh and outc bound the global invocation id.
layout (constant_id = shape_constant_id_offset + 5) const int outdims = 0;
37
layout (constant_id = shape_constant_id_offset + 6) const int outw = 0;
38
layout (constant_id = shape_constant_id_offset + 7) const int outh = 0;
39
layout (constant_id = shape_constant_id_offset + 8) const int outc = 0;
40
// Multiplied by gz when forming top_blob_data offsets.
layout (constant_id = shape_constant_id_offset + 9) const int outcstep = 0;
43
// Image variant of the three resources: source (sampled), destination
// (write-only image) and per-channel fill (sampled).
// NOTE(review): these reuse the binding slots and names of the buffer
// declarations below, so the two sets are presumably selected by a
// preprocessor conditional that is not visible in this excerpt — confirm.
// `unfp` and `imfmtc4` are defined outside this excerpt.
layout (binding = 0) uniform unfp sampler3D bottom_blob;
44
layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob;
45
layout (binding = 2) uniform unfp sampler3D per_channel_pad_blob;
47
// Storage-buffer variant: the source is read as sfpvec4 elements.
layout (binding = 0) readonly buffer bottom_blob { sfpvec4 bottom_blob_data[]; };
48
// The destination is written as sfpvec8 elements (two 4-wide halves per element).
layout (binding = 1) writeonly buffer top_blob { sfpvec8 top_blob_data[]; };
49
// Per-channel fill values, read as sfpvec4 elements.
layout (binding = 2) readonly buffer per_channel_pad_blob { sfpvec4 per_channel_pad_blob_data[]; };
52
layout (push_constant) uniform parameter
73
int gx = int(gl_GlobalInvocationID.x);
74
int gy = int(gl_GlobalInvocationID.y);
75
int gz = int(gl_GlobalInvocationID.z);
77
if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc))
84
ivec2 x2 = gx * 2 - p.left / 4 + ivec2(0, 1);
89
if (x2.x >= 0 && x2.x < psc(w))
91
v[0] = image3d_ld4(bottom_blob, ivec3(x2.x, 0, 0));
95
v[0] = afpvec4(value);
97
if (x2.y >= 0 && x2.y < psc(w))
99
v[1] = image3d_ld4(bottom_blob, ivec3(x2.y, 0, 0));
103
v[1] = afpvec4(value);
106
image3d_st8(top_blob, ivec3(gx, 0, 0), v);
108
if (x2.x >= 0 && x2.x < psc(w))
110
v[0] = buffer_ld4(bottom_blob_data, x2.x);
114
v[0] = afpvec4(value);
116
if (x2.y >= 0 && x2.y < psc(w))
118
v[1] = buffer_ld4(bottom_blob_data, x2.y);
122
v[1] = afpvec4(value);
125
buffer_st8(top_blob_data, gx, v);
133
v[0] = afpvec4(image3d_ld4(bottom_blob, ivec3(0, 0, 0)).r);
135
else if (x2.x >= psc(w))
137
v[0] = afpvec4(image3d_ld4(bottom_blob, ivec3(psc(w) - 1, 0, 0)).a);
141
v[0] = image3d_ld4(bottom_blob, ivec3(x2.x, 0, 0));
145
v[1] = afpvec4(image3d_ld4(bottom_blob, ivec3(0, 0, 0)).r);
147
else if (x2.y >= psc(w))
149
v[1] = afpvec4(image3d_ld4(bottom_blob, ivec3(psc(w) - 1, 0, 0)).a);
153
v[1] = image3d_ld4(bottom_blob, ivec3(x2.y, 0, 0));
156
image3d_st8(top_blob, ivec3(gx, 0, 0), v);
160
v[0] = afpvec4(buffer_ld4(bottom_blob_data, 0).r);
162
else if (x2.x >= psc(w))
164
v[0] = afpvec4(buffer_ld4(bottom_blob_data, psc(w) - 1).a);
168
v[0] = buffer_ld4(bottom_blob_data, x2.x);
172
v[1] = afpvec4(buffer_ld4(bottom_blob_data, 0).r);
174
else if (x2.y >= psc(w))
176
v[1] = afpvec4(buffer_ld4(bottom_blob_data, psc(w) - 1).a);
180
v[1] = buffer_ld4(bottom_blob_data, x2.y);
183
buffer_st8(top_blob_data, gx, v);
191
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(-x2.x + 1, 0, 0));
192
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(-x2.x, 0, 0));
193
v[0] = afpvec4(v1.r, v0.a, v0.b, v0.g);
195
else if (x2.x >= psc(w))
197
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(psc(w) - x2.x + psc(w) - 2, 0, 0));
198
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(psc(w) - x2.x + psc(w) - 1, 0, 0));
199
v[0] = afpvec4(v1.b, v1.g, v1.r, v0.a);
203
v[0] = image3d_ld4(bottom_blob, ivec3(x2.x, 0, 0));
207
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(-x2.y + 1, 0, 0));
208
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(-x2.y, 0, 0));
209
v[1] = afpvec4(v1.r, v0.a, v0.b, v0.g);
211
else if (x2.y >= psc(w))
213
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(psc(w) - x2.y + psc(w) - 2, 0, 0));
214
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(psc(w) - x2.y + psc(w) - 1, 0, 0));
215
v[1] = afpvec4(v1.b, v1.g, v1.r, v0.a);
219
v[1] = image3d_ld4(bottom_blob, ivec3(x2.y, 0, 0));
222
image3d_st8(top_blob, ivec3(gx, 0, 0), v);
226
afpvec4 v0 = buffer_ld4(bottom_blob_data, -x2.x + 1);
227
afpvec4 v1 = buffer_ld4(bottom_blob_data, -x2.x);
228
v[0] = afpvec4(v1.r, v0.a, v0.b, v0.g);
230
else if (x2.x >= psc(w))
232
afpvec4 v0 = buffer_ld4(bottom_blob_data, psc(w) - x2.x + psc(w) - 2);
233
afpvec4 v1 = buffer_ld4(bottom_blob_data, psc(w) - x2.x + psc(w) - 1);
234
v[0] = afpvec4(v1.b, v1.g, v1.r, v0.a);
238
v[0] = buffer_ld4(bottom_blob_data, x2.x);
242
afpvec4 v0 = buffer_ld4(bottom_blob_data, -x2.y + 1);
243
afpvec4 v1 = buffer_ld4(bottom_blob_data, -x2.y);
244
v[1] = afpvec4(v1.r, v0.a, v0.b, v0.g);
246
else if (x2.y >= psc(w))
248
afpvec4 v0 = buffer_ld4(bottom_blob_data, psc(w) - x2.y + psc(w) - 2);
249
afpvec4 v1 = buffer_ld4(bottom_blob_data, psc(w) - x2.y + psc(w) - 1);
250
v[1] = afpvec4(v1.b, v1.g, v1.r, v0.a);
254
v[1] = buffer_ld4(bottom_blob_data, x2.y);
257
buffer_st8(top_blob_data, gx, v);
261
else if (psc(dims) == 2)
264
ivec2 y2 = gy * 2 - p.top / 4 + ivec2(0, 1);
269
if (x >= 0 && x < psc(w) && y2.x >= 0 && y2.x < psc(h))
271
v[0] = image3d_ld4(bottom_blob, ivec3(x, y2.x, 0));
275
v[0] = afpvec4(value);
277
if (x >= 0 && x < psc(w) && y2.y >= 0 && y2.y < psc(h))
279
v[1] = image3d_ld4(bottom_blob, ivec3(x, y2.y, 0));
283
v[1] = afpvec4(value);
286
image3d_st8(top_blob, ivec3(gx, gy, 0), v);
288
ivec2 v_offset = y2 * psc(w) + x;
290
if (x >= 0 && x < psc(w) && y2.x >= 0 && y2.x < psc(h))
292
v[0] = buffer_ld4(bottom_blob_data, v_offset.x);
296
v[0] = afpvec4(value);
298
if (x >= 0 && x < psc(w) && y2.y >= 0 && y2.y < psc(h))
300
v[1] = buffer_ld4(bottom_blob_data, v_offset.y);
304
v[1] = afpvec4(value);
307
const int gi = gy * psc(outw) + gx;
309
buffer_st8(top_blob_data, gi, v);
314
x = clamp(x, 0, psc(w) - 1);
319
v[0] = afpvec4(image3d_ld4(bottom_blob, ivec3(x, 0, 0)).r);
321
else if (y2.x >= psc(h))
323
v[0] = afpvec4(image3d_ld4(bottom_blob, ivec3(x, psc(h) - 1, 0)).a);
327
v[0] = image3d_ld4(bottom_blob, ivec3(x, y2.x, 0));
331
v[1] = afpvec4(image3d_ld4(bottom_blob, ivec3(x, 0, 0)).r);
333
else if (y2.y >= psc(h))
335
v[1] = afpvec4(image3d_ld4(bottom_blob, ivec3(x, psc(h) - 1, 0)).a);
339
v[1] = image3d_ld4(bottom_blob, ivec3(x, y2.y, 0));
342
image3d_st8(top_blob, ivec3(gx, gy, 0), v);
346
v[0] = afpvec4(buffer_ld4(bottom_blob_data, x).r);
348
else if (y2.x >= psc(h))
350
v[0] = afpvec4(buffer_ld4(bottom_blob_data, (psc(h) - 1) * psc(w) + x).a);
354
v[0] = buffer_ld4(bottom_blob_data, y2.x * psc(w) + x);
358
v[1] = afpvec4(buffer_ld4(bottom_blob_data, x).r);
360
else if (y2.y >= psc(h))
362
v[1] = afpvec4(buffer_ld4(bottom_blob_data, (psc(h) - 1) * psc(w) + x).a);
366
v[1] = buffer_ld4(bottom_blob_data, y2.y * psc(w) + x);
369
const int gi = gy * psc(outw) + gx;
371
buffer_st8(top_blob_data, gi, v);
377
// NOTE psc(X) yields zeros on nvidia, so p.w is read directly here instead of psc(w)
378
// TODO enable this workaround only for the affected nvidia drivers
379
x = (p.w - 1) - abs(x - (p.w - 1));
380
// x = (psc(w) - 1) - abs(x - (psc(w) - 1));
385
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(x, -y2.x + 1, 0));
386
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(x, -y2.x, 0));
387
v[0] = afpvec4(v1.r, v0.a, v0.b, v0.g);
389
else if (y2.x >= psc(h))
391
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(x, psc(h) - y2.x + psc(h) - 2, 0));
392
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(x, psc(h) - y2.x + psc(h) - 1, 0));
393
v[0] = afpvec4(v1.b, v1.g, v1.r, v0.a);
397
v[0] = image3d_ld4(bottom_blob, ivec3(x, y2.x, 0));
401
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(x, -y2.y + 1, 0));
402
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(x, -y2.y, 0));
403
v[1] = afpvec4(v1.r, v0.a, v0.b, v0.g);
405
else if (y2.y >= psc(h))
407
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(x, psc(h) - y2.y + psc(h) - 2, 0));
408
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(x, psc(h) - y2.y + psc(h) - 1, 0));
409
v[1] = afpvec4(v1.b, v1.g, v1.r, v0.a);
413
v[1] = image3d_ld4(bottom_blob, ivec3(x, y2.y, 0));
416
image3d_st8(top_blob, ivec3(gx, gy, 0), v);
420
afpvec4 v0 = buffer_ld4(bottom_blob_data, (-y2.x + 1) * psc(w) + x);
421
afpvec4 v1 = buffer_ld4(bottom_blob_data, (-y2.x) * psc(w) + x);
422
v[0] = afpvec4(v1.r, v0.a, v0.b, v0.g);
424
else if (y2.x >= psc(h))
426
afpvec4 v0 = buffer_ld4(bottom_blob_data, (psc(h) - y2.x + psc(h) - 2) * psc(w) + x);
427
afpvec4 v1 = buffer_ld4(bottom_blob_data, (psc(h) - y2.x + psc(h) - 1) * psc(w) + x);
428
v[0] = afpvec4(v1.b, v1.g, v1.r, v0.a);
432
v[0] = buffer_ld4(bottom_blob_data, y2.x * psc(w) + x);
436
afpvec4 v0 = buffer_ld4(bottom_blob_data, (-y2.y + 1) * psc(w) + x);
437
afpvec4 v1 = buffer_ld4(bottom_blob_data, (-y2.y) * psc(w) + x);
438
v[1] = afpvec4(v1.r, v0.a, v0.b, v0.g);
440
else if (y2.y >= psc(h))
442
afpvec4 v0 = buffer_ld4(bottom_blob_data, (psc(h) - y2.y + psc(h) - 2) * psc(w) + x);
443
afpvec4 v1 = buffer_ld4(bottom_blob_data, (psc(h) - y2.y + psc(h) - 1) * psc(w) + x);
444
v[1] = afpvec4(v1.b, v1.g, v1.r, v0.a);
448
v[1] = buffer_ld4(bottom_blob_data, y2.y * psc(w) + x);
451
const int gi = gy * psc(outw) + gx;
453
buffer_st8(top_blob_data, gi, v);
457
else // if (psc(dims) == 3)
461
ivec2 z2 = gz * 2 - p.front / 4 + ivec2(0, 1);
466
if (x >= 0 && x < psc(w) && y >= 0 && y < psc(h) && z2.x >= 0 && z2.x < psc(c))
468
v[0] = image3d_ld4(bottom_blob, ivec3(x, y, z2.x));
472
v[0] = per_channel_pad == 1 ? image3d_ld4(per_channel_pad_blob, ivec3(gz * 2, 0, 0)) : afpvec4(value);
474
if (x >= 0 && x < psc(w) && y >= 0 && y < psc(h) && z2.y >= 0 && z2.y < psc(c))
476
v[1] = image3d_ld4(bottom_blob, ivec3(x, y, z2.y));
480
v[1] = per_channel_pad == 1 ? image3d_ld4(per_channel_pad_blob, ivec3(gz * 2 + 1, 0, 0)) : afpvec4(value);
483
image3d_st8(top_blob, ivec3(gx, gy, gz), v);
485
ivec2 v_offset = z2 * psc(cstep) + y * psc(w) + x;
487
if (x >= 0 && x < psc(w) && y >= 0 && y < psc(h) && z2.x >= 0 && z2.x < psc(c))
489
v[0] = buffer_ld4(bottom_blob_data, v_offset.x);
493
v[0] = per_channel_pad == 1 ? buffer_ld4(per_channel_pad_blob_data, gz * 2) : afpvec4(value);
495
if (x >= 0 && x < psc(w) && y >= 0 && y < psc(h) && z2.y >= 0 && z2.y < psc(c))
497
v[1] = buffer_ld4(bottom_blob_data, v_offset.y);
501
v[1] = per_channel_pad == 1 ? buffer_ld4(per_channel_pad_blob_data, gz * 2 + 1) : afpvec4(value);
504
const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;
506
buffer_st8(top_blob_data, gi, v);
511
x = clamp(x, 0, psc(w) - 1);
512
y = clamp(y, 0, psc(h) - 1);
517
v[0] = afpvec4(image3d_ld4(bottom_blob, ivec3(x, y, 0)).r);
519
else if (z2.x >= psc(c))
521
v[0] = afpvec4(image3d_ld4(bottom_blob, ivec3(x, y, psc(c) - 1)).a);
525
v[0] = image3d_ld4(bottom_blob, ivec3(x, y, z2.x));
529
v[1] = afpvec4(image3d_ld4(bottom_blob, ivec3(x, y, 0)).r);
531
else if (z2.y >= psc(c))
533
v[1] = afpvec4(image3d_ld4(bottom_blob, ivec3(x, y, psc(c) - 1)).a);
537
v[1] = image3d_ld4(bottom_blob, ivec3(x, y, z2.y));
540
image3d_st8(top_blob, ivec3(gx, gy, gz), v);
544
v[0] = afpvec4(buffer_ld4(bottom_blob_data, y * psc(w) + x).r);
546
else if (z2.x >= psc(c))
548
v[0] = afpvec4(buffer_ld4(bottom_blob_data, (psc(c) - 1) * psc(cstep) + y * psc(w) + x).a);
552
v[0] = buffer_ld4(bottom_blob_data, z2.x * psc(cstep) + y * psc(w) + x);
556
v[1] = afpvec4(buffer_ld4(bottom_blob_data, y * psc(w) + x).r);
558
else if (z2.y >= psc(c))
560
v[1] = afpvec4(buffer_ld4(bottom_blob_data, (psc(c) - 1) * psc(cstep) + y * psc(w) + x).a);
564
v[1] = buffer_ld4(bottom_blob_data, z2.y * psc(cstep) + y * psc(w) + x);
567
const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;
569
buffer_st8(top_blob_data, gi, v);
576
// NOTE psc(X) yields zeros on nvidia, so p.w and p.h are read directly here instead of psc(w) and psc(h)
577
// TODO enable this workaround only for the affected nvidia drivers
578
x = (p.w - 1) - abs(x - (p.w - 1));
579
y = (p.h - 1) - abs(y - (p.h - 1));
580
// x = (psc(w) - 1) - abs(x - (psc(w) - 1));
581
// y = (psc(h) - 1) - abs(y - (psc(h) - 1));
586
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(x, y, -z2.x + 1));
587
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(x, y, -z2.x));
588
v[0] = afpvec4(v1.r, v0.a, v0.b, v0.g);
590
else if (z2.x >= psc(c))
592
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(x, y, psc(c) - z2.x + psc(c) - 2));
593
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(x, y, psc(c) - z2.x + psc(c) - 1));
594
v[0] = afpvec4(v1.b, v1.g, v1.r, v0.a);
598
v[0] = image3d_ld4(bottom_blob, ivec3(x, y, z2.x));
602
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(x, y, -z2.y + 1));
603
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(x, y, -z2.y));
604
v[1] = afpvec4(v1.r, v0.a, v0.b, v0.g);
606
else if (z2.y >= psc(c))
608
afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(x, y, psc(c) - z2.y + psc(c) - 2));
609
afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(x, y, psc(c) - z2.y + psc(c) - 1));
610
v[1] = afpvec4(v1.b, v1.g, v1.r, v0.a);
614
v[1] = image3d_ld4(bottom_blob, ivec3(x, y, z2.y));
617
image3d_st8(top_blob, ivec3(gx, gy, gz), v);
621
afpvec4 v0 = buffer_ld4(bottom_blob_data, (-z2.x + 1) * psc(cstep) + y * psc(w) + x);
622
afpvec4 v1 = buffer_ld4(bottom_blob_data, (-z2.x) * psc(cstep) + y * psc(w) + x);
623
v[0] = afpvec4(v1.r, v0.a, v0.b, v0.g);
625
else if (z2.x >= psc(c))
627
afpvec4 v0 = buffer_ld4(bottom_blob_data, (psc(c) - z2.x + psc(c) - 2) * psc(cstep) + y * psc(w) + x);
628
afpvec4 v1 = buffer_ld4(bottom_blob_data, (psc(c) - z2.x + psc(c) - 1) * psc(cstep) + y * psc(w) + x);
629
v[0] = afpvec4(v1.b, v1.g, v1.r, v0.a);
633
v[0] = buffer_ld4(bottom_blob_data, z2.x * psc(cstep) + y * psc(w) + x);
637
afpvec4 v0 = buffer_ld4(bottom_blob_data, (-z2.y + 1) * psc(cstep) + y * psc(w) + x);
638
afpvec4 v1 = buffer_ld4(bottom_blob_data, (-z2.y) * psc(cstep) + y * psc(w) + x);
639
v[1] = afpvec4(v1.r, v0.a, v0.b, v0.g);
641
else if (z2.y >= psc(c))
643
afpvec4 v0 = buffer_ld4(bottom_blob_data, (psc(c) - z2.y + psc(c) - 2) * psc(cstep) + y * psc(w) + x);
644
afpvec4 v1 = buffer_ld4(bottom_blob_data, (psc(c) - z2.y + psc(c) - 1) * psc(cstep) + y * psc(w) + x);
645
v[1] = afpvec4(v1.b, v1.g, v1.r, v0.a);
649
v[1] = buffer_ld4(bottom_blob_data, z2.y * psc(cstep) + y * psc(w) + x);
652
const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;
654
buffer_st8(top_blob_data, gi, v);