ncnn

Форк
0
/
instancenorm_reduce_sum4_fp32.comp 
139 строк · 3.6 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#version 450
16

17
#if NCNN_fp16_storage
18
#extension GL_EXT_shader_16bit_storage: require
19
#endif
20
#if NCNN_fp16_arithmetic
21
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
22
#endif
23

24
// Shader resources. Binding 0 is the input that main() reads from, binding 1
// is the output that main() writes to. With NCNN_image_shader they are a 3D
// sampler and a write-only r32f 3D image; otherwise they are storage buffers
// holding plain float arrays.
#if NCNN_image_shader
25
layout (binding = 0) uniform highp sampler3D square_blob;
26
layout (binding = 1, r32f) writeonly uniform highp image3D sqsum_blob;
27
#else
28
layout (binding = 0) readonly buffer square_blob { float square_blob_data[]; };
29
layout (binding = 1) writeonly buffer sqsum_blob { float sqsum_blob_data[]; };
30
#endif
31

32
// Push constants describing the input and output extents, accessed as p.*.
// main() uses w/h to detect windows that start on the last input column/row,
// cstep/outcstep as per-channel element strides in the buffer path, and
// outw/outh/outc as the bounds of the dispatch grid.
layout (push_constant) uniform parameter
33
{
34
    int w; // input extent along x
35
    int h; // input extent along y (only read in the image path)
36
    int c; // not read by main() in this file
37
    int cstep; // input elements between consecutive channels (buffer path)
38

39
    int outw; // output extent along x; invocations with gx >= outw exit early
40
    int outh; // output extent along y; invocations with gy >= outh exit early
41
    int outc; // output extent along z; invocations with gz >= outc exit early
42
    int outcstep; // output elements between consecutive channels (buffer path)
43
} p;
44

45
// Entry point: each invocation produces one partial sum.
//
// Image path (NCNN_image_shader): adds the texels of a 2x2 window whose
// top-left corner is (2*gx, 2*gy) in slice gz of square_blob. The window is
// narrowed to a single column and/or a single row when it starts on the last
// column (p.w - 1) or the last row (p.h - 1).
//
// Buffer path: adds up to four consecutive floats starting at element 4*gx of
// channel gz (channel stride p.cstep). Fewer values are read when fewer than
// four elements remain before p.w.
//
// The result is written to texel (gx, gy, gz) of sqsum_blob, or to element
// gz * p.outcstep + gx of sqsum_blob_data.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    // Invocations outside the output extent have nothing to do.
    if (gx >= p.outw || gy >= p.outh || gz >= p.outc)
        return;

    float sum;

#if NCNN_image_shader
    int sx = gx * 2;
    int sy = gy * 2;

    if (sy == p.h - 1)
    {
        // Window starts on the last row: only one row is available.
        if (sx == p.w - 1)
        {
            // Also on the last column: a single texel.
            float v0 = texelFetch(square_blob, ivec3(sx, sy, gz), 0).r;

            sum = v0;
        }
        else
        {
            float v0 = texelFetch(square_blob, ivec3(sx, sy, gz), 0).r;
            float v1 = texelFetch(square_blob, ivec3(sx + 1, sy, gz), 0).r;

            sum = v0 + v1;
        }
    }
    else
    {
        if (sx == p.w - 1)
        {
            // Window starts on the last column: only one column is available.
            float v0 = texelFetch(square_blob, ivec3(sx, sy, gz), 0).r;
            float v2 = texelFetch(square_blob, ivec3(sx, sy + 1, gz), 0).r;

            sum = v0 + v2;
        }
        else
        {
            // Full 2x2 window.
            float v0 = texelFetch(square_blob, ivec3(sx, sy, gz), 0).r;
            float v1 = texelFetch(square_blob, ivec3(sx + 1, sy, gz), 0).r;
            float v2 = texelFetch(square_blob, ivec3(sx, sy + 1, gz), 0).r;
            float v3 = texelFetch(square_blob, ivec3(sx + 1, sy + 1, gz), 0).r;

            sum = v0 + v1 + v2 + v3;
        }
    }
#else
    // NOTE(review): gy takes no part in buffer addressing; presumably p.outh
    // is 1 for this path - confirm against the dispatch code.
    int sx = gx * 4;

    int v_offset = gz * p.cstep + sx;

    // Number of valid elements from sx up to the end of the row.
    int remain = p.w - sx;

    if (remain == 1)
    {
        float v0 = square_blob_data[v_offset];

        sum = v0;
    }
    else if (remain == 2)
    {
        float v0 = square_blob_data[v_offset];
        float v1 = square_blob_data[v_offset + 1];

        sum = v0 + v1;
    }
    else if (remain == 3)
    {
        // Previously this branch repeated the "two remaining" test and was
        // unreachable, so a three-element tail read one element past p.w.
        float v0 = square_blob_data[v_offset];
        float v1 = square_blob_data[v_offset + 1];
        float v2 = square_blob_data[v_offset + 2];

        sum = v0 + v1 + v2;
    }
    else
    {
        float v0 = square_blob_data[v_offset];
        float v1 = square_blob_data[v_offset + 1];
        float v2 = square_blob_data[v_offset + 2];
        float v3 = square_blob_data[v_offset + 3];

        sum = v0 + v1 + v2 + v3;
    }
#endif

#if NCNN_image_shader
    imageStore(sqsum_blob, ivec3(gx, gy, gz), vec4(sum));
#else
    int gi = gz * p.outcstep + gx;

    sqsum_blob_data[gi] = sum;
#endif
}
140

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.