// ncnn
// 301 lines · 14.6 KB
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

#extension GL_GOOGLE_include_directive: enable
#include "vulkan_activation.comp"

// Specialization constants configuring the epilogue applied to every output value.
layout (constant_id = 0) const int bias_term = 0;       // main() adds bias[channel] only when this is 1
layout (constant_id = 1) const int activation_type = 0; // forwarded to activation_afp()
layout (constant_id = 2) const float activation_param_0 = 0; // forwarded to activation_afp()
layout (constant_id = 3) const float activation_param_1 = 0; // forwarded to activation_afp()

// Shape specialization constants. Each one has a push-constant twin of the same
// name in `p` below, and main() reads them through psc().
// NOTE(review): psc() is not defined in this file -- presumably it returns the
// specialization constant when it is non-zero and the push constant otherwise; confirm.
#define shape_constant_id_offset 4
layout (constant_id = shape_constant_id_offset + 0) const int c = 0;     // channel count (bounds gz)
layout (constant_id = shape_constant_id_offset + 1) const int cstep = 0; // stride between the 36 tile-element planes (buffer path)

layout (constant_id = shape_constant_id_offset + 2) const int block_x = 0; // tiles per row (bounds gx)
layout (constant_id = shape_constant_id_offset + 3) const int block_y = 0; // tile rows (bounds gy)

layout (constant_id = shape_constant_id_offset + 4) const int outw = 0;     // output width, used to clip buffer stores
layout (constant_id = shape_constant_id_offset + 5) const int outh = 0;     // output height, used to clip buffer stores
layout (constant_id = shape_constant_id_offset + 6) const int outcstep = 0; // per-channel stride of the output buffer

// Descriptor bindings: 0 = transformed tiles (input), 1 = output, 2 = per-channel bias.
// unfp / sfp / imfmtc1 are not defined in this file (supplied by ncnn's shader build).
#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D top_tm_blob;
layout (binding = 1, imfmtc1) writeonly uniform unfp image3D top_blob;
layout (binding = 2) uniform unfp sampler3D bias_blob;
#else
layout (binding = 0) readonly buffer top_tm_blob { sfp top_tm_blob_data[]; };
layout (binding = 1) writeonly buffer top_blob { sfp top_blob_data[]; };
layout (binding = 2) readonly buffer bias_blob { sfp bias_data[]; };
#endif

// Push-constant copies of the shape values above (same names, same order).
layout (push_constant) uniform parameter
{
    int c;
    int cstep;

    int block_x;
    int block_y;

    int outw;
    int outh;
    int outcstep;
} p;

// Winograd output transform: one invocation turns one 6x6 transformed tile of
// one channel into a 4x4 block of the output.
//
//   gx, gy - tile column / row; the block is written at x = gx * 4, y = gy * 4
//   gz     - channel index
//
// Steps: load the 36 tile values, apply the otm matrix (listed below) in two
// passes, optionally add the channel bias, apply the fused activation, store.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    if (gx >= psc(block_x) || gy >= psc(block_y) || gz >= psc(c))
        return;

    // Load the 36 tile values. Element k (0..35) is named vRC with R = k / 6, C = k % 6.
#if NCNN_image_shader
    // Image layout: x = tile index (row-major over the tile grid), y = channel, z = element k.
    int sx = gy * psc(block_x) + gx;

    afp v00 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 0));
    afp v01 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 1));
    afp v02 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 2));
    afp v03 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 3));
    afp v04 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 4));
    afp v05 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 5));
    afp v10 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 6));
    afp v11 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 7));
    afp v12 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 8));
    afp v13 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 9));
    afp v14 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 10));
    afp v15 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 11));
    afp v20 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 12));
    afp v21 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 13));
    afp v22 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 14));
    afp v23 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 15));
    afp v24 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 16));
    afp v25 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 17));
    afp v30 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 18));
    afp v31 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 19));
    afp v32 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 20));
    afp v33 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 21));
    afp v34 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 22));
    afp v35 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 23));
    afp v40 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 24));
    afp v41 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 25));
    afp v42 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 26));
    afp v43 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 27));
    afp v44 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 28));
    afp v45 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 29));
    afp v50 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 30));
    afp v51 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 31));
    afp v52 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 32));
    afp v53 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 33));
    afp v54 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 34));
    afp v55 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 35));
#else
    // Buffer layout: element k lives in plane k (plane stride cstep); inside a plane
    // the data is ordered by channel, then tile row, then tile column.
    int v_tm_offset = gz * psc(block_x) * psc(block_y) + gy * psc(block_x) + gx;

    afp v00 = buffer_ld1(top_tm_blob_data, v_tm_offset + 0 * psc(cstep));
    afp v01 = buffer_ld1(top_tm_blob_data, v_tm_offset + 1 * psc(cstep));
    afp v02 = buffer_ld1(top_tm_blob_data, v_tm_offset + 2 * psc(cstep));
    afp v03 = buffer_ld1(top_tm_blob_data, v_tm_offset + 3 * psc(cstep));
    afp v04 = buffer_ld1(top_tm_blob_data, v_tm_offset + 4 * psc(cstep));
    afp v05 = buffer_ld1(top_tm_blob_data, v_tm_offset + 5 * psc(cstep));
    afp v10 = buffer_ld1(top_tm_blob_data, v_tm_offset + 6 * psc(cstep));
    afp v11 = buffer_ld1(top_tm_blob_data, v_tm_offset + 7 * psc(cstep));
    afp v12 = buffer_ld1(top_tm_blob_data, v_tm_offset + 8 * psc(cstep));
    afp v13 = buffer_ld1(top_tm_blob_data, v_tm_offset + 9 * psc(cstep));
    afp v14 = buffer_ld1(top_tm_blob_data, v_tm_offset + 10 * psc(cstep));
    afp v15 = buffer_ld1(top_tm_blob_data, v_tm_offset + 11 * psc(cstep));
    afp v20 = buffer_ld1(top_tm_blob_data, v_tm_offset + 12 * psc(cstep));
    afp v21 = buffer_ld1(top_tm_blob_data, v_tm_offset + 13 * psc(cstep));
    afp v22 = buffer_ld1(top_tm_blob_data, v_tm_offset + 14 * psc(cstep));
    afp v23 = buffer_ld1(top_tm_blob_data, v_tm_offset + 15 * psc(cstep));
    afp v24 = buffer_ld1(top_tm_blob_data, v_tm_offset + 16 * psc(cstep));
    afp v25 = buffer_ld1(top_tm_blob_data, v_tm_offset + 17 * psc(cstep));
    afp v30 = buffer_ld1(top_tm_blob_data, v_tm_offset + 18 * psc(cstep));
    afp v31 = buffer_ld1(top_tm_blob_data, v_tm_offset + 19 * psc(cstep));
    afp v32 = buffer_ld1(top_tm_blob_data, v_tm_offset + 20 * psc(cstep));
    afp v33 = buffer_ld1(top_tm_blob_data, v_tm_offset + 21 * psc(cstep));
    afp v34 = buffer_ld1(top_tm_blob_data, v_tm_offset + 22 * psc(cstep));
    afp v35 = buffer_ld1(top_tm_blob_data, v_tm_offset + 23 * psc(cstep));
    afp v40 = buffer_ld1(top_tm_blob_data, v_tm_offset + 24 * psc(cstep));
    afp v41 = buffer_ld1(top_tm_blob_data, v_tm_offset + 25 * psc(cstep));
    afp v42 = buffer_ld1(top_tm_blob_data, v_tm_offset + 26 * psc(cstep));
    afp v43 = buffer_ld1(top_tm_blob_data, v_tm_offset + 27 * psc(cstep));
    afp v44 = buffer_ld1(top_tm_blob_data, v_tm_offset + 28 * psc(cstep));
    afp v45 = buffer_ld1(top_tm_blob_data, v_tm_offset + 29 * psc(cstep));
    afp v50 = buffer_ld1(top_tm_blob_data, v_tm_offset + 30 * psc(cstep));
    afp v51 = buffer_ld1(top_tm_blob_data, v_tm_offset + 31 * psc(cstep));
    afp v52 = buffer_ld1(top_tm_blob_data, v_tm_offset + 32 * psc(cstep));
    afp v53 = buffer_ld1(top_tm_blob_data, v_tm_offset + 33 * psc(cstep));
    afp v54 = buffer_ld1(top_tm_blob_data, v_tm_offset + 34 * psc(cstep));
    afp v55 = buffer_ld1(top_tm_blob_data, v_tm_offset + 35 * psc(cstep));
#endif

    // sqrt(2)-based transform coefficients. Each expansion is parenthesized so it
    // stays a single operand wherever the macro is used.
#define sq2 1.41421356237
#define sq2_m2 (1.41421356237 * 2)
#define sq2_d2 (1.41421356237 / 2)
#define sq2_d4 (1.41421356237 / 4)

    // const float otm[4][6] = {
    //     {1.0f,  1.0f,    1.0f,   1.0f,    1.0f,   0.0f},
    //     {0.0f,  sq2/2,  -sq2/2,  sq2,    -sq2,    0.0f},
    //     {0.0f,  0.5f,    0.5f,   2.0f,    2.0f,   0.0f},
    //     {0.0f,  sq2/4,  -sq2/4,  sq2*2,  -sq2*2,  1.0f}
    // };

    // First pass (implicit transpose): mIR = otm[I] . (row R of v).
    afp m00 = v00 + v01 + v02 + v03 + v04;
    afp m01 = v10 + v11 + v12 + v13 + v14;
    afp m02 = v20 + v21 + v22 + v23 + v24;
    afp m03 = v30 + v31 + v32 + v33 + v34;
    afp m04 = v40 + v41 + v42 + v43 + v44;
    afp m05 = v50 + v51 + v52 + v53 + v54;

    afp m10 = (v01 - v02) * afp(sq2_d2) + (v03 - v04) * afp(sq2);
    afp m11 = (v11 - v12) * afp(sq2_d2) + (v13 - v14) * afp(sq2);
    afp m12 = (v21 - v22) * afp(sq2_d2) + (v23 - v24) * afp(sq2);
    afp m13 = (v31 - v32) * afp(sq2_d2) + (v33 - v34) * afp(sq2);
    afp m14 = (v41 - v42) * afp(sq2_d2) + (v43 - v44) * afp(sq2);
    afp m15 = (v51 - v52) * afp(sq2_d2) + (v53 - v54) * afp(sq2);

    afp m20 = (v01 + v02) * afp(0.5) + (v03 + v04) * afp(2);
    afp m21 = (v11 + v12) * afp(0.5) + (v13 + v14) * afp(2);
    afp m22 = (v21 + v22) * afp(0.5) + (v23 + v24) * afp(2);
    afp m23 = (v31 + v32) * afp(0.5) + (v33 + v34) * afp(2);
    afp m24 = (v41 + v42) * afp(0.5) + (v43 + v44) * afp(2);
    afp m25 = (v51 + v52) * afp(0.5) + (v53 + v54) * afp(2);

    afp m30 = v05 + (v01 - v02) * afp(sq2_d4) + (v03 - v04) * afp(sq2_m2);
    afp m31 = v15 + (v11 - v12) * afp(sq2_d4) + (v13 - v14) * afp(sq2_m2);
    afp m32 = v25 + (v21 - v22) * afp(sq2_d4) + (v23 - v24) * afp(sq2_m2);
    afp m33 = v35 + (v31 - v32) * afp(sq2_d4) + (v33 - v34) * afp(sq2_m2);
    afp m34 = v45 + (v41 - v42) * afp(sq2_d4) + (v43 - v44) * afp(sq2_m2);
    afp m35 = v55 + (v51 - v52) * afp(sq2_d4) + (v53 - v54) * afp(sq2_m2);

    // Second pass: vIJ = otm[J] . (row I of m); vIJ is stored at (x + J, y + I).
    // The v registers are reused for the 4x4 result.
    v00 = m00 + m01 + m02 + m03 + m04;
    v10 = m10 + m11 + m12 + m13 + m14;
    v20 = m20 + m21 + m22 + m23 + m24;
    v30 = m30 + m31 + m32 + m33 + m34;

    v01 = (m01 - m02) * afp(sq2_d2) + (m03 - m04) * afp(sq2);
    v11 = (m11 - m12) * afp(sq2_d2) + (m13 - m14) * afp(sq2);
    v21 = (m21 - m22) * afp(sq2_d2) + (m23 - m24) * afp(sq2);
    v31 = (m31 - m32) * afp(sq2_d2) + (m33 - m34) * afp(sq2);

    v02 = (m01 + m02) * afp(0.5) + (m03 + m04) * afp(2);
    v12 = (m11 + m12) * afp(0.5) + (m13 + m14) * afp(2);
    v22 = (m21 + m22) * afp(0.5) + (m23 + m24) * afp(2);
    v32 = (m31 + m32) * afp(0.5) + (m33 + m34) * afp(2);

    v03 = m05 + (m01 - m02) * afp(sq2_d4) + (m03 - m04) * afp(sq2_m2);
    v13 = m15 + (m11 - m12) * afp(sq2_d4) + (m13 - m14) * afp(sq2_m2);
    v23 = m25 + (m21 - m22) * afp(sq2_d4) + (m23 - m24) * afp(sq2_m2);
    v33 = m35 + (m31 - m32) * afp(sq2_d4) + (m33 - m34) * afp(sq2_m2);

    // The coefficients are not needed past this point; keep them out of any later code.
#undef sq2
#undef sq2_m2
#undef sq2_d2
#undef sq2_d4

    // Per-channel bias, one value for the whole 4x4 block.
    if (bias_term == 1)
    {
#if NCNN_image_shader
        const afp bias_value = image3d_ld1(bias_blob, ivec3(gz, 0, 0));
#else
        const afp bias_value = buffer_ld1(bias_data, gz);
#endif

        v00 = bias_value + v00;
        v01 = bias_value + v01;
        v02 = bias_value + v02;
        v03 = bias_value + v03;
        v10 = bias_value + v10;
        v11 = bias_value + v11;
        v12 = bias_value + v12;
        v13 = bias_value + v13;
        v20 = bias_value + v20;
        v21 = bias_value + v21;
        v22 = bias_value + v22;
        v23 = bias_value + v23;
        v30 = bias_value + v30;
        v31 = bias_value + v31;
        v32 = bias_value + v32;
        v33 = bias_value + v33;
    }

    // Fused activation (activation_afp comes from vulkan_activation.comp).
    v00 = activation_afp(v00, activation_type, activation_param_0, activation_param_1);
    v01 = activation_afp(v01, activation_type, activation_param_0, activation_param_1);
    v02 = activation_afp(v02, activation_type, activation_param_0, activation_param_1);
    v03 = activation_afp(v03, activation_type, activation_param_0, activation_param_1);
    v10 = activation_afp(v10, activation_type, activation_param_0, activation_param_1);
    v11 = activation_afp(v11, activation_type, activation_param_0, activation_param_1);
    v12 = activation_afp(v12, activation_type, activation_param_0, activation_param_1);
    v13 = activation_afp(v13, activation_type, activation_param_0, activation_param_1);
    v20 = activation_afp(v20, activation_type, activation_param_0, activation_param_1);
    v21 = activation_afp(v21, activation_type, activation_param_0, activation_param_1);
    v22 = activation_afp(v22, activation_type, activation_param_0, activation_param_1);
    v23 = activation_afp(v23, activation_type, activation_param_0, activation_param_1);
    v30 = activation_afp(v30, activation_type, activation_param_0, activation_param_1);
    v31 = activation_afp(v31, activation_type, activation_param_0, activation_param_1);
    v32 = activation_afp(v32, activation_type, activation_param_0, activation_param_1);
    v33 = activation_afp(v33, activation_type, activation_param_0, activation_param_1);

    // Store the 4x4 block with its top-left corner at (x, y).
    int x = gx * 4;
    int y = gy * 4;

#if NCNN_image_shader
    // NOTE(review): unlike the buffer path, these stores are not clipped to outw/outh.
    // Presumably top_blob covers whole 4x4 tiles, or out-of-bounds image stores are
    // discarded; confirm against the host-side allocation.
    image3d_st1(top_blob, ivec3(x, y, gz), v00);
    image3d_st1(top_blob, ivec3(x + 1, y, gz), v01);
    image3d_st1(top_blob, ivec3(x + 2, y, gz), v02);
    image3d_st1(top_blob, ivec3(x + 3, y, gz), v03);
    image3d_st1(top_blob, ivec3(x, y + 1, gz), v10);
    image3d_st1(top_blob, ivec3(x + 1, y + 1, gz), v11);
    image3d_st1(top_blob, ivec3(x + 2, y + 1, gz), v12);
    image3d_st1(top_blob, ivec3(x + 3, y + 1, gz), v13);
    image3d_st1(top_blob, ivec3(x, y + 2, gz), v20);
    image3d_st1(top_blob, ivec3(x + 1, y + 2, gz), v21);
    image3d_st1(top_blob, ivec3(x + 2, y + 2, gz), v22);
    image3d_st1(top_blob, ivec3(x + 3, y + 2, gz), v23);
    image3d_st1(top_blob, ivec3(x, y + 3, gz), v30);
    image3d_st1(top_blob, ivec3(x + 1, y + 3, gz), v31);
    image3d_st1(top_blob, ivec3(x + 2, y + 3, gz), v32);
    image3d_st1(top_blob, ivec3(x + 3, y + 3, gz), v33);
#else
    // Offsets of the starts of output rows y .. y + 3 in channel gz.
    ivec4 v_offset = gz * psc(outcstep) + y * psc(outw) + x + ivec4(0, 1, 2, 3) * psc(outw);

    // Right/bottom edge tiles are clipped to outw/outh. The (x, y) store is
    // unconditional; NOTE(review): this assumes block_x == ceil(outw / 4) and
    // block_y == ceil(outh / 4) so that x < outw and y < outh -- confirm with the host.
    buffer_st1(top_blob_data, v_offset.r + 0, v00);
    if (x + 1 < psc(outw)) buffer_st1(top_blob_data, v_offset.r + 1, v01);
    if (x + 2 < psc(outw)) buffer_st1(top_blob_data, v_offset.r + 2, v02);
    if (x + 3 < psc(outw)) buffer_st1(top_blob_data, v_offset.r + 3, v03);
    if (y + 1 < psc(outh)) buffer_st1(top_blob_data, v_offset.g + 0, v10);
    if (y + 1 < psc(outh) && x + 1 < psc(outw)) buffer_st1(top_blob_data, v_offset.g + 1, v11);
    if (y + 1 < psc(outh) && x + 2 < psc(outw)) buffer_st1(top_blob_data, v_offset.g + 2, v12);
    if (y + 1 < psc(outh) && x + 3 < psc(outw)) buffer_st1(top_blob_data, v_offset.g + 3, v13);
    if (y + 2 < psc(outh)) buffer_st1(top_blob_data, v_offset.b + 0, v20);
    if (y + 2 < psc(outh) && x + 1 < psc(outw)) buffer_st1(top_blob_data, v_offset.b + 1, v21);
    if (y + 2 < psc(outh) && x + 2 < psc(outw)) buffer_st1(top_blob_data, v_offset.b + 2, v22);
    if (y + 2 < psc(outh) && x + 3 < psc(outw)) buffer_st1(top_blob_data, v_offset.b + 3, v23);
    if (y + 3 < psc(outh)) buffer_st1(top_blob_data, v_offset.a + 0, v30);
    if (y + 3 < psc(outh) && x + 1 < psc(outw)) buffer_st1(top_blob_data, v_offset.a + 1, v31);
    if (y + 3 < psc(outh) && x + 2 < psc(outw)) buffer_st1(top_blob_data, v_offset.a + 2, v32);
    if (y + 3 < psc(outh) && x + 3 < psc(outw)) buffer_st1(top_blob_data, v_offset.a + 3, v33);
#endif
}