ncnn

Форк
0
/
convolution_3x3s1d1_winograd23_transform_input.comp 
202 строки · 8.1 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#version 450
16

17
// 16-bit storage types are only requested when the build sets
// NCNN_fp16_storage (the macro is supplied from outside this file).
#if NCNN_fp16_storage
18
#extension GL_EXT_shader_16bit_storage: require
19
#endif
20
// Explicit float16 arithmetic types are only requested when the build sets
// NCNN_fp16_arithmetic (the macro is supplied from outside this file).
#if NCNN_fp16_arithmetic
21
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
22
#endif
23

24
// Shape values exposed as specialization constants with ids
// shape_constant_id_offset + 0 .. + 6; every one of them defaults to 0.
// main() reads them through psc(name).
// NOTE(review): psc() is not defined in this file - presumably a constant left
// at 0 makes it fall back to the push-constant field of the same name; confirm.
#define shape_constant_id_offset 0
25
// w, h: input extents used for the bounds checks in the buffer path;
// w is also the row stride used when computing input offsets.
layout (constant_id = shape_constant_id_offset + 0) const int w = 0;
26
layout (constant_id = shape_constant_id_offset + 1) const int h = 0;
27
// c: number of channels; invocations with gz >= c return immediately.
layout (constant_id = shape_constant_id_offset + 2) const int c = 0;
28
// cstep: offset between consecutive channels of the input buffer.
layout (constant_id = shape_constant_id_offset + 3) const int cstep = 0;
29

30
// outcstep: offset between consecutive planes of the 16 output planes.
layout (constant_id = shape_constant_id_offset + 4) const int outcstep = 0;
31

32
// block_x, block_y: number of tiles along x and y; they bound gx and gy.
layout (constant_id = shape_constant_id_offset + 5) const int block_x = 0;
33
layout (constant_id = shape_constant_id_offset + 6) const int block_y = 0;
34

35
// Resources: binding 0 is the input (bottom_blob), binding 1 is the transformed
// output (bottom_tm_blob). The build selects either the image variant or the
// buffer variant through NCNN_image_shader.
// NOTE(review): unfp, imfmtc1 and sfp are not defined in this file; they are
// presumably precision / format macros provided by the build - confirm.
#if NCNN_image_shader
36
layout (binding = 0) uniform unfp sampler3D bottom_blob;
37
layout (binding = 1, imfmtc1) writeonly uniform unfp image3D bottom_tm_blob;
38
#else
39
layout (binding = 0) readonly buffer bottom_blob { sfp bottom_blob_data[]; };
40
layout (binding = 1) writeonly buffer bottom_tm_blob { sfp bottom_tm_blob_data[]; };
41
#endif
42

43
// Push-constant block carrying the same seven shape values that are also
// declared as specialization constants (w, h, c, cstep, outcstep, block_x,
// block_y), in the same order. It is accessed through the instance name p.
// NOTE(review): main() never reads p directly; presumably psc(name) expands to
// p.name when the specialization constant is 0 - confirm against the macro.
layout (push_constant) uniform parameter
44
{
45
    int w;
46
    int h;
47
    int c;
48
    int cstep;
49

50
    int outcstep;
51

52
    int block_x;
53
    int block_y;
54
} p;
55

56
// Entry point: each invocation transforms one 4x4 input window.
//   gx, gy - tile column / row (gl_GlobalInvocationID.x / .y)
//   gz     - channel            (gl_GlobalInvocationID.z)
// The window's top-left element is at (2 * gx, 2 * gy), so neighbouring tiles
// overlap by two elements in each direction. The 16 transformed values are
// written to 16 separate output planes, one value per plane.
// NOTE(review): psc(), afp, buffer_ld1/buffer_st1 and image3d_ld1/image3d_st1
// are not defined in this file; psc(name) presumably resolves to either the
// specialization constant or the push-constant field p.name - confirm.
void main()
57
{
58
    int gx = int(gl_GlobalInvocationID.x);
59
    int gy = int(gl_GlobalInvocationID.y);
60
    int gz = int(gl_GlobalInvocationID.z);
61

62
    // drop invocations outside the tile grid or past the last channel
    if (gx >= psc(block_x) || gy >= psc(block_y) || gz >= psc(c))
63
        return;
64

65
    // load the 4x4 window; vYX holds the element at row sy + Y, column sx + X
66
    int sx = gx * 2;
67
    int sy = gy * 2;
68

69
#if NCNN_image_shader
70
    // image path: no explicit bounds checks here; what a read past the edge
    // yields depends on image3d_ld1 / the sampler, which this file does not show
    afp v00 = image3d_ld1(bottom_blob, ivec3(sx + 0, sy + 0, gz));
71
    afp v01 = image3d_ld1(bottom_blob, ivec3(sx + 1, sy + 0, gz));
72
    afp v02 = image3d_ld1(bottom_blob, ivec3(sx + 2, sy + 0, gz));
73
    afp v03 = image3d_ld1(bottom_blob, ivec3(sx + 3, sy + 0, gz));
74

75
    afp v10 = image3d_ld1(bottom_blob, ivec3(sx + 0, sy + 1, gz));
76
    afp v11 = image3d_ld1(bottom_blob, ivec3(sx + 1, sy + 1, gz));
77
    afp v12 = image3d_ld1(bottom_blob, ivec3(sx + 2, sy + 1, gz));
78
    afp v13 = image3d_ld1(bottom_blob, ivec3(sx + 3, sy + 1, gz));
79

80
    afp v20 = image3d_ld1(bottom_blob, ivec3(sx + 0, sy + 2, gz));
81
    afp v21 = image3d_ld1(bottom_blob, ivec3(sx + 1, sy + 2, gz));
82
    afp v22 = image3d_ld1(bottom_blob, ivec3(sx + 2, sy + 2, gz));
83
    afp v23 = image3d_ld1(bottom_blob, ivec3(sx + 3, sy + 2, gz));
84

85
    afp v30 = image3d_ld1(bottom_blob, ivec3(sx + 0, sy + 3, gz));
86
    afp v31 = image3d_ld1(bottom_blob, ivec3(sx + 1, sy + 3, gz));
87
    afp v32 = image3d_ld1(bottom_blob, ivec3(sx + 2, sy + 3, gz));
88
    afp v33 = image3d_ld1(bottom_blob, ivec3(sx + 3, sy + 3, gz));
89
#else
90
    // buffer path: v_offset.r/g/b/a are the offsets of the first element of
    // window rows 0..3 (consecutive rows are w apart)
    int v_offset_0 = gz * psc(cstep) + sy * psc(w) + sx;
91
    ivec4 v_offset = v_offset_0 + ivec4(0, 1, 2, 3) * psc(w);
92

93
    // elements with sx + X >= w or sy + Y >= h are replaced by zero;
    // v00 is the only element read without a bounds check
    afp v00 = buffer_ld1(bottom_blob_data, v_offset.r + 0);
94
    afp v01 = sx + 1 < psc(w) ? buffer_ld1(bottom_blob_data, v_offset.r + 1) : afp(0.f);
95
    afp v02 = sx + 2 < psc(w) ? buffer_ld1(bottom_blob_data, v_offset.r + 2) : afp(0.f);
96
    afp v03 = sx + 3 < psc(w) ? buffer_ld1(bottom_blob_data, v_offset.r + 3) : afp(0.f);
97

98
    afp v10 = sy + 1 < psc(h) ? buffer_ld1(bottom_blob_data, v_offset.g + 0) : afp(0.f);
99
    afp v11 = sy + 1 < psc(h) && sx + 1 < psc(w) ? buffer_ld1(bottom_blob_data, v_offset.g + 1) : afp(0.f);
100
    afp v12 = sy + 1 < psc(h) && sx + 2 < psc(w) ? buffer_ld1(bottom_blob_data, v_offset.g + 2) : afp(0.f);
101
    afp v13 = sy + 1 < psc(h) && sx + 3 < psc(w) ? buffer_ld1(bottom_blob_data, v_offset.g + 3) : afp(0.f);
102

103
    afp v20 = sy + 2 < psc(h) ? buffer_ld1(bottom_blob_data, v_offset.b + 0) : afp(0.f);
104
    afp v21 = sy + 2 < psc(h) && sx + 1 < psc(w) ? buffer_ld1(bottom_blob_data, v_offset.b + 1) : afp(0.f);
105
    afp v22 = sy + 2 < psc(h) && sx + 2 < psc(w) ? buffer_ld1(bottom_blob_data, v_offset.b + 2) : afp(0.f);
106
    afp v23 = sy + 2 < psc(h) && sx + 3 < psc(w) ? buffer_ld1(bottom_blob_data, v_offset.b + 3) : afp(0.f);
107

108
    afp v30 = sy + 3 < psc(h) ? buffer_ld1(bottom_blob_data, v_offset.a + 0) : afp(0.f);
109
    afp v31 = sy + 3 < psc(h) && sx + 1 < psc(w) ? buffer_ld1(bottom_blob_data, v_offset.a + 1) : afp(0.f);
110
    afp v32 = sy + 3 < psc(h) && sx + 2 < psc(w) ? buffer_ld1(bottom_blob_data, v_offset.a + 2) : afp(0.f);
111
    afp v33 = sy + 3 < psc(h) && sx + 3 < psc(w) ? buffer_ld1(bottom_blob_data, v_offset.a + 3) : afp(0.f);
112
#endif
113

114
    // transform matrix applied in both passes below (kept as a reference only):
    // const float itm[4][4] = {
115
    //     {1.0f,  0.0f, -1.0f,  0.0f},
116
    //     {0.0f,  1.0f,  1.0f,  0.0f},
117
    //     {0.0f, -1.0f,  1.0f,  0.0f},
118
    //     {0.0f, -1.0f,  0.0f,  1.0f}
119
    // };
120

121
    // first pass with implicit transpose: mIJ = (itm row I) . (window row J),
    // i.e. itm is applied along x and the result is stored transposed
    afp m00 = v00 - v02;
122
    afp m01 = v10 - v12;
123
    afp m02 = v20 - v22;
124
    afp m03 = v30 - v32;
125

126
    afp m10 = v02 + v01;
127
    afp m11 = v12 + v11;
128
    afp m12 = v22 + v21;
129
    afp m13 = v32 + v31;
130

131
    afp m20 = v02 - v01;
132
    afp m21 = v12 - v11;
133
    afp m22 = v22 - v21;
134
    afp m23 = v32 - v31;
135

136
    afp m30 = v03 - v01;
137
    afp m31 = v13 - v11;
138
    afp m32 = v23 - v21;
139
    afp m33 = v33 - v31;
140

141
    // second pass, reusing the v names for the result:
    // vIJ = (itm row J) . (m row I); the original window values are overwritten
    v00 = m00 - m02;
142
    v10 = m10 - m12;
143
    v20 = m20 - m22;
144
    v30 = m30 - m32;
145

146
    v01 = m02 + m01;
147
    v11 = m12 + m11;
148
    v21 = m22 + m21;
149
    v31 = m32 + m31;
150

151
    v02 = m02 - m01;
152
    v12 = m12 - m11;
153
    v22 = m22 - m21;
154
    v32 = m32 - m31;
155

156
    v03 = m03 - m01;
157
    v13 = m13 - m11;
158
    v23 = m23 - m21;
159
    v33 = m33 - m31;
160

161
    // store 16: value vIJ goes to output plane 4 * I + J
162
#if NCNN_image_shader
163
    // image path: x is the linear tile index gy * block_x + gx; each value is
    // written at (x, gz, plane)
    int x = gy * psc(block_x) + gx;
164

165
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 0), v00);
166
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 1), v01);
167
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 2), v02);
168
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 3), v03);
169
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 4), v10);
170
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 5), v11);
171
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 6), v12);
172
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 7), v13);
173
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 8), v20);
174
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 9), v21);
175
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 10), v22);
176
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 11), v23);
177
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 12), v30);
178
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 13), v31);
179
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 14), v32);
180
    image3d_st1(bottom_tm_blob, ivec3(x, gz, 15), v33);
181
#else
182
    // buffer path: v_tm_offset is the position of (gz, gy, gx) inside one
    // plane (channel-major, then tile row, then tile column); consecutive
    // planes are outcstep apart
    int v_tm_offset = gz * psc(block_x) * psc(block_y) + gy * psc(block_x) + gx;
183

184
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 0 * psc(outcstep), v00);
185
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 1 * psc(outcstep), v01);
186
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 2 * psc(outcstep), v02);
187
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 3 * psc(outcstep), v03);
188
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 4 * psc(outcstep), v10);
189
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 5 * psc(outcstep), v11);
190
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 6 * psc(outcstep), v12);
191
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 7 * psc(outcstep), v13);
192
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 8 * psc(outcstep), v20);
193
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 9 * psc(outcstep), v21);
194
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 10 * psc(outcstep), v22);
195
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 11 * psc(outcstep), v23);
196
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 12 * psc(outcstep), v30);
197
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 13 * psc(outcstep), v31);
198
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 14 * psc(outcstep), v32);
199
    buffer_st1(bottom_tm_blob_data, v_tm_offset + 15 * psc(outcstep), v33);
200
#endif
201
}
203

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.