ncnn

Форк
0
/
gemm.comp 
453 строки · 13.3 Кб
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Length of one K chunk on the shared-memory path of main().
// The staging code uses gl_LocalInvocationID.x (A tile) and
// gl_LocalInvocationID.y (B tile) as the K offset inside a chunk.
// The shared tiles below also have a fixed leading dimension of 8.
// NOTE(review): together these presumably require an 8 x 8 workgroup;
// confirm against the host-side dispatch before changing this value.
#define LOCAL_MEMORY_UNROLL_INCH 8

// Specialization constants.
//   alpha / beta   : scale factors. main() seeds its sums with beta * C,
//                    accumulates A * B, then multiplies by alpha.
//   transA / transB: 1 selects the transposed memory layout of A / B.
//   constantC / constant_broadcast_type_C:
//                    when constantC == 1 the C broadcast type comes from
//                    constant_broadcast_type_C instead of the push constant.
//   M / N / K      : GEMM sizes, always read through psc(). psc() comes from
//                    ncnn's shader prelude (not visible here) and presumably
//                    falls back to the push-constant member when the
//                    specialization constant is 0. Verify in the prelude.
//   output_N1M     : selects the 3-D output addressing (image path only).
//   output_transpose: 1 stores the result transposed.
//   constantA, constantB, output_elempack and output_elemtype are not read
//   by main() in this shader.
layout (constant_id = 0) const float alpha = 1.f;
layout (constant_id = 1) const float beta = 1.f;
layout (constant_id = 2) const int transA = 0;
layout (constant_id = 3) const int transB = 0;
layout (constant_id = 4) const int constantA = 0;
layout (constant_id = 5) const int constantB = 0;
layout (constant_id = 6) const int constantC = 0;
layout (constant_id = 7) const int M = 0;
layout (constant_id = 8) const int N = 0;
layout (constant_id = 9) const int K = 0;
layout (constant_id = 10) const int constant_broadcast_type_C = 0;
layout (constant_id = 11) const int output_N1M = 0;
layout (constant_id = 12) const int output_elempack = 0;
layout (constant_id = 13) const int output_elemtype = 0;
layout (constant_id = 14) const int output_transpose = 0;

// TODO psc more

// Resource bindings: 0 = output, 1 = A, 2 = B, 3 = C.
// The image variant reads through samplers and writes a storage image.
// The buffer variant uses flat element arrays addressed with the *hstep
// push constants.
#if NCNN_image_shader
layout (binding = 0, imfmtc1) writeonly uniform unfp image3D top_blob_3d;
layout (binding = 1) uniform unfp sampler3D A_blob_3d;
layout (binding = 2) uniform unfp sampler3D B_blob_3d;
layout (binding = 3) uniform unfp sampler3D C_blob_3d;
#else
layout (binding = 0) writeonly buffer top_blob { sfp top_blob_data[]; };
layout (binding = 1) readonly buffer A_blob { sfp A_blob_data[]; };
layout (binding = 2) readonly buffer B_blob { sfp B_blob_data[]; };
layout (binding = 3) readonly buffer C_blob { sfp C_blob_data[]; };
#endif

// Runtime parameters.
//   M, N, K        : runtime GEMM sizes, paired with the specialization
//                    constants of the same name.
//   broadcast_type_C: used when constantC != 1.
//   A_dims / B_dims: choose 3-D vs 2-D image addressing on the image path.
//   A_hstep, B_hstep, outhstep:
//                    row strides (in elements) on the buffer path.
//   outdims        : not read by main() in this shader.
layout (push_constant) uniform parameter
{
    int M;
    int N;
    int K;
    int broadcast_type_C;
    int A_dims;
    int A_hstep;
    int B_dims;
    int B_hstep;
    int outdims;
    int outhstep;
} p;

#if NCNN_shader_local_memory
// Shared K-chunk tiles for the shared-memory path of main().
//   tmp_a[ly][k][r]: A value for the r-th output row (r = 0, 1) of the row
//                    pair owned by local row ly, at K offset k of the chunk.
//   tmp_b[lx][k][c]: B value for the c-th output column (c = 0, 1) of the
//                    column pair owned by local column lx, at K offset k.
shared lfp tmp_a[8][LOCAL_MEMORY_UNROLL_INCH][2];
shared lfp tmp_b[8][LOCAL_MEMORY_UNROLL_INCH][2];
#endif

#if !NCNN_image_shader && NCNN_shader_local_memory
// Stage A(row gy0, col kk) and A(row gy1, col kk) into tmp_a[ly][lx][0..1].
void gemm_stage_a(int kk, int lx, int ly, int gy0, int gy1)
{
    if (transA == 1)
    {
        // Transposed layout: logical A(r, kk) lives at kk * A_hstep + r.
        const int ai = kk * p.A_hstep;
        tmp_a[ly][lx][0] = sfp2lfp(A_blob_data[ai + gy0]);
        tmp_a[ly][lx][1] = sfp2lfp(A_blob_data[ai + gy1]);
    }
    else
    {
        tmp_a[ly][lx][0] = sfp2lfp(A_blob_data[gy0 * p.A_hstep + kk]);
        tmp_a[ly][lx][1] = sfp2lfp(A_blob_data[gy1 * p.A_hstep + kk]);
    }
}

// Stage B(row kk, col gx0) and B(row kk, col gx1) into tmp_b[lx][ly][0..1].
void gemm_stage_b(int kk, int lx, int ly, int gx0, int gx1)
{
    if (transB == 1)
    {
        // Transposed layout: logical B(kk, c) lives at c * B_hstep + kk.
        tmp_b[lx][ly][0] = sfp2lfp(B_blob_data[gx0 * p.B_hstep + kk]);
        tmp_b[lx][ly][1] = sfp2lfp(B_blob_data[gx1 * p.B_hstep + kk]);
    }
    else
    {
        const int bi = kk * p.B_hstep;
        tmp_b[lx][ly][0] = sfp2lfp(B_blob_data[bi + gx0]);
        tmp_b[lx][ly][1] = sfp2lfp(B_blob_data[bi + gx1]);
    }
}
#endif

// Compute one 2 x 2 tile of the GEMM output per invocation.
//
// Tile ownership: invocation (x, y, z) owns output columns gx, gx + 1 and
// rows gy, gy + 1, where gx = 2 * x and gy = 2 * y.
//
// Computed value: sums start as beta * C (broadcast by broadcast_type_C),
// A * B is accumulated over psc(K), and the result is scaled by alpha:
//     out = alpha * (A * B + beta * C)
// NOTE(review): when alpha != 1 the C term is scaled by alpha * beta. This
// differs from the ONNX definition alpha * A * B + beta * C; confirm it
// matches ncnn's CPU Gemm reference before changing.
//
// Stores: only invocations with gx < N, gy < M and gz == 0 store. The
// gx + 1 / gy + 1 neighbours are stored only when in range.
//
// Loads: every load uses coordinates clamped into [0, N-1] / [0, M-1].
// This keeps two kinds of invocation from reading past the end of A, B or C:
//   - edge tiles, when N or M is odd;
//   - on the shared-memory path, invocations entirely outside the output.
// A clamped coordinate only differs from the real one when the real one is
// out of range, and any sum depending on such a coordinate is never stored.
// The stored results are therefore unchanged.
void main()
{
    int gx = int(gl_GlobalInvocationID.x) * 2;
    int gy = int(gl_GlobalInvocationID.y) * 2;
    int gz = int(gl_GlobalInvocationID.z);

#if !NCNN_shader_local_memory
    if (gx >= psc(N) || gy >= psc(M) || gz >= 1)
        return;
#endif

    // In-range coordinates used for every load (see the header comment).
    const int gx0 = min(gx, psc(N) - 1);
    const int gx1 = min(gx + 1, psc(N) - 1);
    const int gy0 = min(gy, psc(M) - 1);
    const int gy1 = min(gy + 1, psc(M) - 1);

    afp sum0 = afp(0.f);
    afp sum1 = afp(0.f);
    afp sum2 = afp(0.f);
    afp sum3 = afp(0.f);

    const int broadcast_type_C = constantC == 1 ? constant_broadcast_type_C : p.broadcast_type_C;

    // Seed the accumulators with C according to broadcast_type_C:
    //   0    : single value C[0]
    //   1, 2 : one value per output row
    //          (image path: laid out along x for 1, along y for 2)
    //   3    : full M x N matrix
    //   4    : one value per output column
    // Any other value leaves the accumulators at zero.
#if NCNN_image_shader
    if (broadcast_type_C == 0)
    {
        sum0 = image3d_ld1(C_blob_3d, ivec3(0, 0, 0));
        sum1 = sum0;
        sum2 = sum0;
        sum3 = sum0;
    }
    if (broadcast_type_C == 1)
    {
        sum0 = image3d_ld1(C_blob_3d, ivec3(gy0, 0, 0));
        sum1 = sum0;
        sum2 = image3d_ld1(C_blob_3d, ivec3(gy1, 0, 0));
        sum3 = sum2;
    }
    if (broadcast_type_C == 2)
    {
        sum0 = image3d_ld1(C_blob_3d, ivec3(0, gy0, 0));
        sum1 = sum0;
        sum2 = image3d_ld1(C_blob_3d, ivec3(0, gy1, 0));
        sum3 = sum2;
    }
    if (broadcast_type_C == 3)
    {
        sum0 = image3d_ld1(C_blob_3d, ivec3(gx0, gy0, 0));
        sum1 = image3d_ld1(C_blob_3d, ivec3(gx1, gy0, 0));
        sum2 = image3d_ld1(C_blob_3d, ivec3(gx0, gy1, 0));
        sum3 = image3d_ld1(C_blob_3d, ivec3(gx1, gy1, 0));
    }
    if (broadcast_type_C == 4)
    {
        sum0 = image3d_ld1(C_blob_3d, ivec3(gx0, 0, 0));
        sum1 = image3d_ld1(C_blob_3d, ivec3(gx1, 0, 0));
        sum2 = sum0;
        sum3 = sum1;
    }
#else
    if (broadcast_type_C == 0)
    {
        sum0 = buffer_ld1(C_blob_data, 0);
        sum1 = sum0;
        sum2 = sum0;
        sum3 = sum0;
    }
    if (broadcast_type_C == 1 || broadcast_type_C == 2)
    {
        sum0 = buffer_ld1(C_blob_data, gy0);
        sum1 = sum0;
        sum2 = buffer_ld1(C_blob_data, gy1);
        sum3 = sum2;
    }
    if (broadcast_type_C == 3)
    {
        // C is read as a dense M x N matrix with row stride N.
        const int ci0 = gy0 * psc(N);
        const int ci1 = gy1 * psc(N);
        sum0 = buffer_ld1(C_blob_data, ci0 + gx0);
        sum1 = buffer_ld1(C_blob_data, ci0 + gx1);
        sum2 = buffer_ld1(C_blob_data, ci1 + gx0);
        sum3 = buffer_ld1(C_blob_data, ci1 + gx1);
    }
    if (broadcast_type_C == 4)
    {
        sum0 = buffer_ld1(C_blob_data, gx0);
        sum1 = buffer_ld1(C_blob_data, gx1);
        sum2 = sum0;
        sum3 = sum1;
    }
#endif

    sum0 *= afp(beta);
    sum1 *= afp(beta);
    sum2 *= afp(beta);
    sum3 *= afp(beta);

#if !NCNN_image_shader && NCNN_shader_local_memory
    // Shared-memory path: the workgroup walks K in chunks of
    // LOCAL_MEMORY_UNROLL_INCH.
    //   - Invocation (lx, ly) stages A column k + lx for its row pair.
    //   - It also stages B row k + ly for its column pair.
    //   - After the barrier, every invocation consumes the whole chunk from
    //     shared memory.
    // NOTE(review): the indexing assumes an 8 x 8 workgroup with
    // LOCAL_MEMORY_UNROLL_INCH == 8. Confirm against the host-side dispatch.
    const int NN = psc(K);

    const int lx = int(gl_LocalInvocationID.x);
    const int ly = int(gl_LocalInvocationID.y);

    int k = 0;
    for (; k + (LOCAL_MEMORY_UNROLL_INCH - 1) < NN; k += LOCAL_MEMORY_UNROLL_INCH)
    {
        gemm_stage_a(k + lx, lx, ly, gy0, gy1);
        gemm_stage_b(k + ly, lx, ly, gx0, gx1);

        barrier();

        for (int k4 = 0; k4 < LOCAL_MEMORY_UNROLL_INCH; k4++)
        {
            afp a0 = lfp2afp(tmp_a[ly][k4][0]);
            afp a1 = lfp2afp(tmp_a[ly][k4][1]);

            afp b0 = lfp2afp(tmp_b[lx][k4][0]);
            afp b1 = lfp2afp(tmp_b[lx][k4][1]);

            sum0 += a0 * b0;
            sum1 += a0 * b1;
            sum2 += a1 * b0;
            sum3 += a1 * b1;
        }

        // Prevent the next chunk's staging from overwriting tiles still being read.
        barrier();
    }

    if (k < NN)
    {
        // Tail chunk: only the first `remain` K offsets are staged and consumed.
        const int remain = NN - k;

        if (lx < remain)
        {
            gemm_stage_a(k + lx, lx, ly, gy0, gy1);
        }

        if (ly < remain)
        {
            gemm_stage_b(k + ly, lx, ly, gx0, gx1);
        }

        barrier();

        for (int k4 = 0; k4 < remain; k4++)
        {
            afp a0 = lfp2afp(tmp_a[ly][k4][0]);
            afp a1 = lfp2afp(tmp_a[ly][k4][1]);

            afp b0 = lfp2afp(tmp_b[lx][k4][0]);
            afp b1 = lfp2afp(tmp_b[lx][k4][1]);

            sum0 += a0 * b0;
            sum1 += a0 * b1;
            sum2 += a1 * b0;
            sum3 += a1 * b1;
        }
    }
#else
    for (int k = 0; k < psc(K); k++)
    {
        afp a0;
        afp a1;
        afp b0;
        afp b1;
#if NCNN_image_shader
        if (transA == 1)
        {
            if (p.A_dims == 3)
            {
                a0 = image3d_ld1(A_blob_3d, ivec3(gy0, 0, k));
                a1 = image3d_ld1(A_blob_3d, ivec3(gy1, 0, k));
            }
            else
            {
                a0 = image3d_ld1(A_blob_3d, ivec3(gy0, k, 0));
                a1 = image3d_ld1(A_blob_3d, ivec3(gy1, k, 0));
            }
        }
        else
        {
            if (p.A_dims == 3)
            {
                a0 = image3d_ld1(A_blob_3d, ivec3(k, 0, gy0));
                a1 = image3d_ld1(A_blob_3d, ivec3(k, 0, gy1));
            }
            else
            {
                a0 = image3d_ld1(A_blob_3d, ivec3(k, gy0, 0));
                a1 = image3d_ld1(A_blob_3d, ivec3(k, gy1, 0));
            }
        }

        if (transB == 1)
        {
            if (p.B_dims == 3)
            {
                b0 = image3d_ld1(B_blob_3d, ivec3(k, 0, gx0));
                b1 = image3d_ld1(B_blob_3d, ivec3(k, 0, gx1));
            }
            else
            {
                b0 = image3d_ld1(B_blob_3d, ivec3(k, gx0, 0));
                b1 = image3d_ld1(B_blob_3d, ivec3(k, gx1, 0));
            }
        }
        else
        {
            if (p.B_dims == 3)
            {
                b0 = image3d_ld1(B_blob_3d, ivec3(gx0, 0, k));
                b1 = image3d_ld1(B_blob_3d, ivec3(gx1, 0, k));
            }
            else
            {
                b0 = image3d_ld1(B_blob_3d, ivec3(gx0, k, 0));
                b1 = image3d_ld1(B_blob_3d, ivec3(gx1, k, 0));
            }
        }
#else
        if (transA == 1)
        {
            // Transposed layout: logical A(r, k) lives at k * A_hstep + r.
            const int ai = k * p.A_hstep;
            a0 = buffer_ld1(A_blob_data, ai + gy0);
            a1 = buffer_ld1(A_blob_data, ai + gy1);
        }
        else
        {
            a0 = buffer_ld1(A_blob_data, gy0 * p.A_hstep + k);
            a1 = buffer_ld1(A_blob_data, gy1 * p.A_hstep + k);
        }

        if (transB == 1)
        {
            // Transposed layout: logical B(k, c) lives at c * B_hstep + k.
            b0 = buffer_ld1(B_blob_data, gx0 * p.B_hstep + k);
            b1 = buffer_ld1(B_blob_data, gx1 * p.B_hstep + k);
        }
        else
        {
            const int bi = k * p.B_hstep;
            b0 = buffer_ld1(B_blob_data, bi + gx0);
            b1 = buffer_ld1(B_blob_data, bi + gx1);
        }
#endif

        sum0 += a0 * b0;
        sum1 += a0 * b1;
        sum2 += a1 * b0;
        sum3 += a1 * b1;
    }
#endif

#if NCNN_shader_local_memory
    // With local memory enabled, out-of-range invocations are retired only
    // here, after the K loop. This lets all of them take part in the
    // barrier() calls on the shared-memory path.
    if (gx >= psc(N) || gy >= psc(M) || gz >= 1)
        return;
#endif

    sum0 *= afp(alpha);
    sum1 *= afp(alpha);
    sum2 *= afp(alpha);
    sum3 *= afp(alpha);

#if NCNN_image_shader
    if (output_transpose == 1)
    {
        if (output_N1M == 1)
        {
            image3d_st1(top_blob_3d, ivec3(gy, 0, gx), sum0);
            if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, 0, gx), sum2);
            if (gx + 1 < psc(N))
            {
                image3d_st1(top_blob_3d, ivec3(gy, 0, gx + 1), sum1);
                if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, 0, gx + 1), sum3);
            }
        }
        else
        {
            image3d_st1(top_blob_3d, ivec3(gy, gx, 0), sum0);
            if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, gx, 0), sum2);
            if (gx + 1 < psc(N))
            {
                image3d_st1(top_blob_3d, ivec3(gy, gx + 1, 0), sum1);
                if (gy + 1 < psc(M)) image3d_st1(top_blob_3d, ivec3(gy + 1, gx + 1, 0), sum3);
            }
        }
    }
    else
    {
        if (output_N1M == 1)
        {
            image3d_st1(top_blob_3d, ivec3(gx, 0, gy), sum0);
            if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, 0, gy), sum1);
            if (gy + 1 < psc(M))
            {
                image3d_st1(top_blob_3d, ivec3(gx, 0, gy + 1), sum2);
                if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, 0, gy + 1), sum3);
            }
        }
        else
        {
            image3d_st1(top_blob_3d, ivec3(gx, gy, 0), sum0);
            if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, gy, 0), sum1);
            if (gy + 1 < psc(M))
            {
                image3d_st1(top_blob_3d, ivec3(gx, gy + 1, 0), sum2);
                if (gx + 1 < psc(N)) image3d_st1(top_blob_3d, ivec3(gx + 1, gy + 1, 0), sum3);
            }
        }
    }
#else
    if (output_transpose == 1)
    {
        const int gi = gx * p.outhstep + gy;

        buffer_st1(top_blob_data, gi, sum0);
        if (gy + 1 < psc(M)) buffer_st1(top_blob_data, gi + 1, sum2);
        if (gx + 1 < psc(N))
        {
            buffer_st1(top_blob_data, gi + p.outhstep, sum1);
            if (gy + 1 < psc(M)) buffer_st1(top_blob_data, gi + p.outhstep + 1, sum3);
        }
    }
    else
    {
        const int gi = gy * p.outhstep + gx;

        buffer_st1(top_blob_data, gi, sum0);
        if (gx + 1 < psc(N)) buffer_st1(top_blob_data, gi + 1, sum1);
        if (gy + 1 < psc(M))
        {
            buffer_st1(top_blob_data, gi + p.outhstep, sum2);
            if (gx + 1 < psc(N)) buffer_st1(top_blob_data, gi + p.outhstep + 1, sum3);
        }
    }
#endif
}

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.