1
// Tencent is pleased to support the open source community by making ncnn available.
3
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
8
// https://opensource.org/licenses/BSD-3-Clause
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
18
// Requires the GL_EXT_shader_16bit_storage extension.
// NOTE(review): any preprocessor guard around this directive is not visible in this excerpt.
#extension GL_EXT_shader_16bit_storage: require
20
// When NCNN_fp16_arithmetic is set, additionally require explicit float16 arithmetic types.
// NOTE(review): the matching #endif is not visible in this excerpt.
#if NCNN_fp16_arithmetic
21
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
24
// Tile depth along K for the shared-memory path: it sizes the middle dimension of
// tmp_a / tmp_b, is the step of the tiled k loop, and bounds the inner k4 loop.
#define LOCAL_MEMORY_UNROLL_INCH 8
26
// Specialization constants (constant_id 0..14) with their default values.
// alpha / beta are not referenced in the visible part of this file;
// presumably the GEMM scaling factors -- TODO confirm against the missing lines.
layout (constant_id = 0) const float alpha = 1.f;
27
layout (constant_id = 1) const float beta = 1.f;
28
// transA / transB: the branches that test them are not visible in this excerpt;
// presumably they choose between the two A / B indexing variants used in the K loops.
layout (constant_id = 2) const int transA = 0;
29
layout (constant_id = 3) const int transB = 0;
30
// constantA / constantB are not referenced in the visible code.
layout (constant_id = 4) const int constantA = 0;
31
layout (constant_id = 5) const int constantB = 0;
32
// constantC == 1 makes the shader take the broadcast type from
// constant_broadcast_type_C instead of the push constant p.broadcast_type_C.
layout (constant_id = 6) const int constantC = 0;
33
// M, N, K are read through the psc() macro (defined outside this excerpt;
// presumably it picks the specialization constant or the push-constant value).
layout (constant_id = 7) const int M = 0;
34
layout (constant_id = 8) const int N = 0;
35
layout (constant_id = 9) const int K = 0;
36
layout (constant_id = 10) const int constant_broadcast_type_C = 0;
37
// output_N1M / output_elempack / output_elemtype are not referenced in the visible code.
layout (constant_id = 11) const int output_N1M = 0;
38
layout (constant_id = 12) const int output_elempack = 0;
39
layout (constant_id = 13) const int output_elemtype = 0;
40
// output_transpose == 1 selects the store variant that swaps the roles of gx and gy.
layout (constant_id = 14) const int output_transpose = 0;
45
// Image-based resources: binding 0 is the write-only output image, bindings 1..3 are
// samplers for A, B and C.
// NOTE(review): these reuse binding indices 0..3 with the buffer declarations below, so the
// two sets must be mutually exclusive; the selecting #if / #else / #endif lines are not
// visible here (presumably NCNN_image_shader).
layout (binding = 0, imfmtc1) writeonly uniform unfp image3D top_blob_3d;
46
layout (binding = 1) uniform unfp sampler3D A_blob_3d;
47
layout (binding = 2) uniform unfp sampler3D B_blob_3d;
48
layout (binding = 3) uniform unfp sampler3D C_blob_3d;
50
// Buffer-based resources: binding 0 is the write-only output, bindings 1..3 are the
// read-only A, B and C inputs, each exposed as an unsized array of sfp.
layout (binding = 0) writeonly buffer top_blob { sfp top_blob_data[]; };
51
layout (binding = 1) readonly buffer A_blob { sfp A_blob_data[]; };
52
layout (binding = 2) readonly buffer B_blob { sfp B_blob_data[]; };
53
layout (binding = 3) readonly buffer C_blob { sfp C_blob_data[]; };
56
// Push-constant block. Its member list and instance name are not visible in this excerpt;
// the visible code reads p.broadcast_type_C, p.A_hstep, p.B_hstep and p.outhstep.
layout (push_constant) uniform parameter
70
// Shared tiles, declared only when NCNN_shader_local_memory is set.
// NOTE(review): the matching #endif is not visible in this excerpt.
#if NCNN_shader_local_memory
71
// tmp_a is indexed [ly][k-in-tile][0..1] and tmp_b [lx][k-in-tile][0..1]; the last index
// holds the two adjacent A / B elements that one invocation consumes.
shared lfp tmp_a[8][LOCAL_MEMORY_UNROLL_INCH][2];
72
shared lfp tmp_b[8][LOCAL_MEMORY_UNROLL_INCH][2];
77
// Each invocation handles a 2x2 output tile: gx / gy are the coordinates of its first
// element (global id scaled by 2); gz is the unscaled z id.
int gx = int(gl_GlobalInvocationID.x) * 2;
78
int gy = int(gl_GlobalInvocationID.y) * 2;
79
int gz = int(gl_GlobalInvocationID.z);
81
// Without shared memory the out-of-range test happens up front; with shared memory the same
// test appears later, just before the stores.
// NOTE(review): the body of this `if` and the #endif are not visible (presumably `return;`).
#if !NCNN_shader_local_memory
82
if (gx >= psc(N) || gy >= psc(M) || gz >= 1)
91
// Broadcast mode of C: taken from the specialization constant when constantC == 1,
// otherwise from the push constant.
const int broadcast_type_C = constantC == 1 ? constant_broadcast_type_C : p.broadcast_type_C;
94
// Image path: seed the accumulators sum0..sum3 from C_blob_3d.
// NOTE(review): the declarations of sum0..sum3 and any assignments to the accumulators not
// listed in a given branch are on lines missing from this excerpt.
// Type 0: one value read from texel (0, 0, 0).
if (broadcast_type_C == 0)
96
sum0 = image3d_ld1(C_blob_3d, ivec3(0, 0, 0));
101
// Type 1: values indexed by gy / gy + 1 along x.
if (broadcast_type_C == 1)
103
sum0 = image3d_ld1(C_blob_3d, ivec3(gy, 0, 0));
105
sum2 = image3d_ld1(C_blob_3d, ivec3(gy + 1, 0, 0));
108
// Type 2: values indexed by gy / gy + 1 along y.
if (broadcast_type_C == 2)
110
sum0 = image3d_ld1(C_blob_3d, ivec3(0, gy, 0));
112
sum2 = image3d_ld1(C_blob_3d, ivec3(0, gy + 1, 0));
115
// Type 3: a distinct value for each element of the 2x2 tile, at (gx.., gy..).
if (broadcast_type_C == 3)
117
sum0 = image3d_ld1(C_blob_3d, ivec3(gx, gy, 0));
118
sum1 = image3d_ld1(C_blob_3d, ivec3(gx + 1, gy, 0));
119
sum2 = image3d_ld1(C_blob_3d, ivec3(gx, gy + 1, 0));
120
sum3 = image3d_ld1(C_blob_3d, ivec3(gx + 1, gy + 1, 0));
122
// Type 4: values indexed by gx / gx + 1 along x.
if (broadcast_type_C == 4)
124
sum0 = image3d_ld1(C_blob_3d, ivec3(gx, 0, 0));
125
sum1 = image3d_ld1(C_blob_3d, ivec3(gx + 1, 0, 0));
130
// Buffer path: seed the accumulators from C_blob_data according to broadcast_type_C.
// NOTE(review): assignments to accumulators not listed in a given branch are on lines
// missing from this excerpt.
// Type 0: one value read from element 0.
if (broadcast_type_C == 0)
132
sum0 = buffer_ld1(C_blob_data, 0);
137
// Types 1 and 2 share one layout in the buffer path: elements gy and gy + 1.
if (broadcast_type_C == 1 || broadcast_type_C == 2)
139
sum0 = buffer_ld1(C_blob_data, gy);
141
sum2 = buffer_ld1(C_blob_data, gy + 1);
144
// Type 3: C is addressed with a row stride of psc(N); ci is the tile's first element,
// and the other three are at +1, +psc(N) and +psc(N)+1.
if (broadcast_type_C == 3)
146
const int ci = gy * psc(N) + gx;
147
sum0 = buffer_ld1(C_blob_data, ci);
148
sum1 = buffer_ld1(C_blob_data, ci + 1);
149
sum2 = buffer_ld1(C_blob_data, ci + psc(N));
150
sum3 = buffer_ld1(C_blob_data, ci + psc(N) + 1);
152
// Type 4: elements gx and gx + 1.
if (broadcast_type_C == 4)
154
sum0 = buffer_ld1(C_blob_data, gx);
155
sum1 = buffer_ld1(C_blob_data, gx + 1);
166
// Shared-memory path (buffer storage only): A and B are staged through tmp_a / tmp_b in
// tiles of LOCAL_MEMORY_UNROLL_INCH steps along K.
// NOTE(review): braces, branch conditions, the declaration of k, any barrier() calls and the
// accumulator updates are on lines missing from this excerpt.
#if !NCNN_image_shader && NCNN_shader_local_memory
167
const int NN = psc(K);
169
// Position of this invocation inside its workgroup; used to index the shared tiles.
const int lx = int(gl_LocalInvocationID.x);
170
const int ly = int(gl_LocalInvocationID.y);
173
// Full tiles: iterate while a whole LOCAL_MEMORY_UNROLL_INCH-wide slice of K remains.
for (; k + (LOCAL_MEMORY_UNROLL_INCH - 1) < NN; k += LOCAL_MEMORY_UNROLL_INCH)
178
// A load, variant 1: K index (k + lx) is multiplied by A_hstep, gy / gy + 1 are adjacent.
// Presumably the transA == 1 branch -- the condition is not visible.
const int ai = (k + lx) * p.A_hstep + gy;
179
tmp_a[ly][lx][0] = sfp2lfp(A_blob_data[ai]);
180
tmp_a[ly][lx][1] = sfp2lfp(A_blob_data[ai + 1]);
184
// A load, variant 2: gy is multiplied by A_hstep, so the second row is one A_hstep further.
const int ai = gy * p.A_hstep + (k + lx);
185
tmp_a[ly][lx][0] = sfp2lfp(A_blob_data[ai]);
186
tmp_a[ly][lx][1] = sfp2lfp(A_blob_data[ai + p.A_hstep]);
191
// B load, variant 1: gx is multiplied by B_hstep; note tmp_b is indexed [lx][ly], i.e. ly
// selects the K offset within the tile. Presumably the transB == 1 branch.
const int bi = gx * p.B_hstep + (k + ly);
192
tmp_b[lx][ly][0] = sfp2lfp(B_blob_data[bi]);
193
tmp_b[lx][ly][1] = sfp2lfp(B_blob_data[bi + p.B_hstep]);
197
// B load, variant 2: K index (k + ly) is multiplied by B_hstep, gx / gx + 1 are adjacent.
const int bi = (k + ly) * p.B_hstep + gx;
198
tmp_b[lx][ly][0] = sfp2lfp(B_blob_data[bi]);
199
tmp_b[lx][ly][1] = sfp2lfp(B_blob_data[bi + 1]);
205
// Consume the staged tile: for each K step read this invocation's two A and two B values.
for (int k4 = 0; k4 < LOCAL_MEMORY_UNROLL_INCH; k4++)
207
afp a0 = lfp2afp(tmp_a[ly][k4][0]);
208
afp a1 = lfp2afp(tmp_a[ly][k4][1]);
210
afp b0 = lfp2afp(tmp_b[lx][k4][0]);
211
afp b1 = lfp2afp(tmp_b[lx][k4][1]);
224
// Remainder: the K steps left over after the last full tile.
// NOTE(review): any guard restricting the loads below to lx / ly < remain is not visible
// here -- confirm it exists in the full source.
const int remain = NN - k;
230
// Same two A-load variants as in the full-tile loop.
const int ai = (k + lx) * p.A_hstep + gy;
231
tmp_a[ly][lx][0] = sfp2lfp(A_blob_data[ai]);
232
tmp_a[ly][lx][1] = sfp2lfp(A_blob_data[ai + 1]);
236
const int ai = gy * p.A_hstep + (k + lx);
237
tmp_a[ly][lx][0] = sfp2lfp(A_blob_data[ai]);
238
tmp_a[ly][lx][1] = sfp2lfp(A_blob_data[ai + p.A_hstep]);
246
// Same two B-load variants as in the full-tile loop.
const int bi = gx * p.B_hstep + (k + ly);
247
tmp_b[lx][ly][0] = sfp2lfp(B_blob_data[bi]);
248
tmp_b[lx][ly][1] = sfp2lfp(B_blob_data[bi + p.B_hstep]);
252
const int bi = (k + ly) * p.B_hstep + gx;
253
tmp_b[lx][ly][0] = sfp2lfp(B_blob_data[bi]);
254
tmp_b[lx][ly][1] = sfp2lfp(B_blob_data[bi + 1]);
260
// Consume only the `remain` valid K steps of the partial tile.
for (int k4 = 0; k4 < remain; k4++)
262
afp a0 = lfp2afp(tmp_a[ly][k4][0]);
263
afp a1 = lfp2afp(tmp_a[ly][k4][1]);
265
afp b0 = lfp2afp(tmp_b[lx][k4][0]);
266
afp b1 = lfp2afp(tmp_b[lx][k4][1]);
275
// Non-tiled path: walk K one step at a time, loading two A values (a0, a1) and two B values
// (b0, b1) per step directly from the image or the buffer.
// NOTE(review): the conditions selecting among the load variants below, the declarations of
// a0 / a1 / b0 / b1 and the accumulator updates are on lines missing from this excerpt.
for (int k = 0; k < psc(K); k++)
286
// Image loads of A, four variants differing in which texel axes carry gy and k.
// Variant: gy on x, k on z.
a0 = image3d_ld1(A_blob_3d, ivec3(gy, 0, k));
287
a1 = image3d_ld1(A_blob_3d, ivec3(gy + 1, 0, k));
291
// Variant: gy on x, k on y.
a0 = image3d_ld1(A_blob_3d, ivec3(gy, k, 0));
292
a1 = image3d_ld1(A_blob_3d, ivec3(gy + 1, k, 0));
299
// Variant: k on x, gy on z.
a0 = image3d_ld1(A_blob_3d, ivec3(k, 0, gy));
300
a1 = image3d_ld1(A_blob_3d, ivec3(k, 0, gy + 1));
304
// Variant: k on x, gy on y.
a0 = image3d_ld1(A_blob_3d, ivec3(k, gy, 0));
305
a1 = image3d_ld1(A_blob_3d, ivec3(k, gy + 1, 0));
313
// Image loads of B, four variants differing in which texel axes carry gx and k.
// Variant: k on x, gx on z.
b0 = image3d_ld1(B_blob_3d, ivec3(k, 0, gx));
314
b1 = image3d_ld1(B_blob_3d, ivec3(k, 0, gx + 1));
318
// Variant: k on x, gx on y.
b0 = image3d_ld1(B_blob_3d, ivec3(k, gx, 0));
319
b1 = image3d_ld1(B_blob_3d, ivec3(k, gx + 1, 0));
326
// Variant: gx on x, k on z.
b0 = image3d_ld1(B_blob_3d, ivec3(gx, 0, k));
327
b1 = image3d_ld1(B_blob_3d, ivec3(gx + 1, 0, k));
331
// Variant: gx on x, k on y.
b0 = image3d_ld1(B_blob_3d, ivec3(gx, k, 0));
332
b1 = image3d_ld1(B_blob_3d, ivec3(gx + 1, k, 0));
338
// Buffer loads of A. Variant 1: k is multiplied by A_hstep, gy / gy + 1 are adjacent.
const int ai = k * p.A_hstep + gy;
339
a0 = buffer_ld1(A_blob_data, ai);
340
a1 = buffer_ld1(A_blob_data, ai + 1);
344
// Variant 2: gy is multiplied by A_hstep, so the second value is one A_hstep further.
const int ai = gy * p.A_hstep + k;
345
a0 = buffer_ld1(A_blob_data, ai);
346
a1 = buffer_ld1(A_blob_data, ai + p.A_hstep);
351
// Buffer loads of B. Variant 1: gx is multiplied by B_hstep.
const int bi = gx * p.B_hstep + k;
352
b0 = buffer_ld1(B_blob_data, bi);
353
b1 = buffer_ld1(B_blob_data, bi + p.B_hstep);
357
// Variant 2: k is multiplied by B_hstep, gx / gx + 1 are adjacent.
const int bi = k * p.B_hstep + gx;
358
b0 = buffer_ld1(B_blob_data, bi);
359
b1 = buffer_ld1(B_blob_data, bi + 1);
370
// With shared memory enabled the out-of-range test is performed here, after the K loops,
// instead of at the top of the shader.
// NOTE(review): the body of this `if` and the #endif are not visible (presumably `return;`).
#if NCNN_shader_local_memory
371
if (gx >= psc(N) || gy >= psc(M) || gz >= 1)
381
// Image stores of the 2x2 tile. sum0 is always written; the visible guards protect the
// second element along M (transposed layout) or along N (normal layout).
// NOTE(review): the stores of sum1 at gx + 1 (transposed) and sum2 at gy + 1 (normal) show no
// guard on their own line; the enclosing condition is presumably on a missing line -- confirm.
// NOTE(review): the conditions choosing between the (.., 0, ..) and (.., .., 0) coordinate
// forms are also not visible.
if (output_transpose == 1)
385
// Transposed output: gy is the x coordinate, gx selects z (or y in the second form).
image3d_st1(top_blob_3d, ivec3(gy, 0, gx), sum0);
386
if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, 0, gx), sum2);
389
image3d_st1(top_blob_3d, ivec3(gy, 0, gx + 1), sum1);
390
if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, 0, gx + 1), sum3);
395
image3d_st1(top_blob_3d, ivec3(gy, gx, 0), sum0);
396
if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, gx, 0), sum2);
399
image3d_st1(top_blob_3d, ivec3(gy, gx + 1, 0), sum1);
400
if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, gx + 1, 0), sum3);
408
// Non-transposed output: gx is the x coordinate, gy selects z (or y in the second form).
image3d_st1(top_blob_3d, ivec3(gx, 0, gy), sum0);
409
if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, 0, gy), sum1);
412
image3d_st1(top_blob_3d, ivec3(gx, 0, gy + 1), sum2);
413
if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, 0, gy + 1), sum3);
418
image3d_st1(top_blob_3d, ivec3(gx, gy, 0), sum0);
419
if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, gy, 0), sum1);
422
image3d_st1(top_blob_3d, ivec3(gx, gy + 1, 0), sum2);
423
if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, gy + 1, 0), sum3);
428
// Buffer stores of the 2x2 tile, addressed with row stride p.outhstep.
if (output_transpose == 1)
430
// Transposed output: gx is multiplied by outhstep and gy is the offset within the row, so
// sum2 (next gy) is at +1 and sum1 (next gx) is at +outhstep.
const int gi = gx * p.outhstep + gy;
432
buffer_st1(top_blob_data, gi, sum0);
433
if (gy + 1 < psc(M)) buffer_st1(top_blob_data, gi + 1, sum2);
436
buffer_st1(top_blob_data, gi + p.outhstep, sum1);
437
if (gy + 1 < psc(M)) buffer_st1(top_blob_data, gi + p.outhstep + 1, sum3);
442
// Non-transposed output: gy is multiplied by outhstep and gx is the offset within the row,
// so sum1 (next gx) is at +1 and sum2 (next gy) is at +outhstep.
const int gi = gy * p.outhstep + gx;
444
buffer_st1(top_blob_data, gi, sum0);
445
if (gx + 1 < psc(N)) buffer_st1(top_blob_data, gi + 1, sum1);
448
buffer_st1(top_blob_data, gi + p.outhstep, sum2);
449
if (gx + 1 < psc(N)) buffer_st1(top_blob_data, gi + p.outhstep + 1, sum3);