ncnn

Форк
0
/
convolution_3x3s1d1_winograd43_transform_output.comp 
301 строка · 14.6 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#version 450
16

17
#if NCNN_fp16_storage
18
#extension GL_EXT_shader_16bit_storage: require
19
#endif
20
#if NCNN_fp16_arithmetic
21
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
22
#endif
23

24
#extension GL_GOOGLE_include_directive: enable
25
#include "vulkan_activation.comp"
26

27
// ---------------------------------------------------------------------------
// Specialization constants
// ---------------------------------------------------------------------------

// Post-processing switches.
//   bias_term          main() adds a per-channel bias when this equals 1.
//   activation_type,
//   activation_param_* passed through unchanged to activation_afp().
layout (constant_id = 0) const int bias_term = 0;
layout (constant_id = 1) const int activation_type = 0;
layout (constant_id = 2) const float activation_param_0 = 0;
layout (constant_id = 3) const float activation_param_1 = 0;

// Shape constants occupy the ids that follow the four switches above.
// Every one of them is read through psc() in main(); psc() is not defined in
// this file.
// NOTE(review): a default of 0 presumably makes psc() fall back to the
// push-constant member of the same name - confirm against its definition.
#define shape_constant_id_offset 4

// Transformed input blob: bound on the channel index and the stride between
// the 36 coefficient planes (buffer path).
layout (constant_id = shape_constant_id_offset + 0) const int c = 0;
layout (constant_id = shape_constant_id_offset + 1) const int cstep = 0;

// Number of tiles along x and y; bounds gl_GlobalInvocationID.xy.
layout (constant_id = shape_constant_id_offset + 2) const int block_x = 0;
layout (constant_id = shape_constant_id_offset + 3) const int block_y = 0;

// Output blob: width, height and per-channel stride.
layout (constant_id = shape_constant_id_offset + 4) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 5) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 6) const int outcstep = 0;
42

43
// ---------------------------------------------------------------------------
// Resource bindings (same slots in both storage modes)
//   binding 0  top_tm_blob  read-only source of the 36 values per tile
//   binding 1  top_blob     write-only destination of the 4x4 output tile
//   binding 2  bias_blob    read-only bias, indexed by channel in main()
// The unfp / imfmtc1 / sfp tokens are supplied from outside this file.
// ---------------------------------------------------------------------------
#if NCNN_image_shader
// Image path: 3D textures for the inputs, a 3D storage image for the output.
layout (binding = 0) uniform unfp sampler3D top_tm_blob;
layout (binding = 1, imfmtc1) writeonly uniform unfp image3D top_blob;
layout (binding = 2) uniform unfp sampler3D bias_blob;
#else
// Buffer path: flat arrays of sfp elements.
layout (binding = 0) readonly buffer top_tm_blob { sfp top_tm_blob_data[]; };
layout (binding = 1) writeonly buffer top_blob { sfp top_blob_data[]; };
layout (binding = 2) readonly buffer bias_blob { sfp bias_data[]; };
#endif
52

53
// Push-constant block with one member per shape specialization constant,
// using the same names and the same order. main() never reads `p` directly;
// all shape values go through psc().
// NOTE(review): presumably psc(name) selects p.name when the specialization
// constant `name` was left at 0 - confirm against the psc() definition.
layout (push_constant) uniform parameter
{
    // transformed input: channel bound, stride between coefficient planes
    int c, cstep;

    // tile grid size
    int block_x, block_y;

    // output blob: width, height, per-channel stride
    int outw, outh, outcstep;
} p;
65

66
// Coefficients of the 4x6 output-transform matrix applied by main():
//
//     otm = { { 1,     1,      1,     1,      1, 0 },
//             { 0, sq2/2, -sq2/2,   sq2,   -sq2, 0 },
//             { 0,   0.5,    0.5,     2,      2, 0 },
//             { 0, sq2/4, -sq2/4, sq2*2, -sq2*2, 1 } }
#define sq2 1.41421356237
#define sq2_m2 (1.41421356237 * 2)
#define sq2_d2 (1.41421356237 / 2)
#define sq2_d4 (1.41421356237 / 4)

// Output transform for one tile of one channel.
//
// Each invocation (gx, gy, gz) = (tile column, tile row, channel):
//   1. loads the 36 values of its tile from top_tm_blob (viewed as 6x6),
//   2. applies otm along both dimensions, which leaves a 4x4 result,
//   3. optionally adds the channel bias and then applies the activation,
//   4. writes the 4x4 result to top_blob at (4 * gx, 4 * gy, gz).
//
// All loops below have compile-time trip counts.
void main()
{
    const ivec3 gid = ivec3(gl_GlobalInvocationID);
    const int gx = gid.x;
    const int gy = gid.y;
    const int gz = gid.z;

    // Invocations outside the tile grid / channel range have nothing to do.
    const bool in_range = gx < psc(block_x) && gy < psc(block_y) && gz < psc(c);
    if (!in_range)
        return;

    // ---- load 36 -----------------------------------------------------------
    // tile[r][q] holds coefficient plane r * 6 + q.
#if NCNN_image_shader
    const int sx = gy * psc(block_x) + gx;
#else
    const int v_tm_offset = gz * psc(block_x) * psc(block_y) + gy * psc(block_x) + gx;
#endif

    afp tile[6][6];
    for (int r = 0; r < 6; r++)
    {
        for (int q = 0; q < 6; q++)
        {
            const int plane = r * 6 + q;
#if NCNN_image_shader
            tile[r][q] = image3d_ld1(top_tm_blob, ivec3(sx, gz, plane));
#else
            tile[r][q] = buffer_ld1(top_tm_blob_data, v_tm_offset + plane * psc(cstep));
#endif
        }
    }

    // ---- first pass (implicit transpose) ----------------------------------
    // mid[j][i] is row j of otm applied to tile[i][0..5].
    afp mid[4][6];
    for (int i = 0; i < 6; i++)
    {
        mid[0][i] = tile[i][0] + tile[i][1] + tile[i][2] + tile[i][3] + tile[i][4];
        mid[1][i] = (tile[i][1] - tile[i][2]) * afp(sq2_d2) + (tile[i][3] - tile[i][4]) * afp(sq2);
        mid[2][i] = (tile[i][1] + tile[i][2]) * afp(0.5) + (tile[i][3] + tile[i][4]) * afp(2);
        mid[3][i] = tile[i][5] + (tile[i][1] - tile[i][2]) * afp(sq2_d4) + (tile[i][3] - tile[i][4]) * afp(sq2_m2);
    }

    // ---- second pass -------------------------------------------------------
    // res[j][k] is row k of otm applied to mid[j][0..5]; it is the output
    // value stored at (x + k, y + j).
    afp res[4][4];
    for (int j = 0; j < 4; j++)
    {
        res[j][0] = mid[j][0] + mid[j][1] + mid[j][2] + mid[j][3] + mid[j][4];
        res[j][1] = (mid[j][1] - mid[j][2]) * afp(sq2_d2) + (mid[j][3] - mid[j][4]) * afp(sq2);
        res[j][2] = (mid[j][1] + mid[j][2]) * afp(0.5) + (mid[j][3] + mid[j][4]) * afp(2);
        res[j][3] = mid[j][5] + (mid[j][1] - mid[j][2]) * afp(sq2_d4) + (mid[j][3] - mid[j][4]) * afp(sq2_m2);
    }

    // ---- bias --------------------------------------------------------------
    if (bias_term == 1)
    {
        // One bias value per channel, shared by all 16 outputs of the tile.
#if NCNN_image_shader
        const afp bias_value = image3d_ld1(bias_blob, ivec3(gz, 0, 0));
#else
        const afp bias_value = buffer_ld1(bias_data, gz);
#endif

        for (int j = 0; j < 4; j++)
        {
            for (int k = 0; k < 4; k++)
            {
                res[j][k] = bias_value + res[j][k];
            }
        }
    }

    // ---- activation --------------------------------------------------------
    for (int j = 0; j < 4; j++)
    {
        for (int k = 0; k < 4; k++)
        {
            res[j][k] = activation_afp(res[j][k], activation_type, activation_param_0, activation_param_1);
        }
    }

    // ---- store 4x4 ---------------------------------------------------------
    const int x = gx * 4;
    const int y = gy * 4;

#if NCNN_image_shader
    // The image path writes all 16 texels without any bounds test.
    for (int j = 0; j < 4; j++)
    {
        for (int k = 0; k < 4; k++)
        {
            image3d_st1(top_blob, ivec3(x + k, y + j, gz), res[j][k]);
        }
    }
#else
    // Component j is the flat offset of output row y + j at column x.
    const ivec4 v_offset = gz * psc(outcstep) + y * psc(outw) + x + ivec4(0, 1, 2, 3) * psc(outw);

    for (int j = 0; j < 4; j++)
    {
        for (int k = 0; k < 4; k++)
        {
            // Element (0, 0) is always written. The other elements are
            // written only while their row and column stay inside
            // outh x outw, so a partial tile at the right or bottom edge
            // never writes past the valid region.
            const bool row_ok = (j == 0) || (y + j < psc(outh));
            const bool col_ok = (k == 0) || (x + k < psc(outw));
            if (row_ok && col_ok)
            {
                buffer_st1(top_blob_data, v_offset[j] + k, res[j][k]);
            }
        }
    }
#endif
}
302

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.