ncnn

Форк
0
/
instancenorm_coeffs.comp 
92 строки · 2.9 Кб
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Specialization constants, fixed when the pipeline is created.
//   eps    - added to the variance under the square root in main().
//   affine - 0: coefficients use only mean/var; nonzero: gamma/beta are applied too.
//   w      - number of (a, b) pairs to produce. main() reads it through psc(w);
//            presumably 0 means "use the push constant p.w instead" - confirm
//            against ncnn's psc() macro definition.
layout (constant_id = 0) const float eps = 0;
layout (constant_id = 1) const int affine = 0;
layout (constant_id = 2) const int w = 0;

// mean/var are always read at full float precision (highp sampler / float[]),
// while gamma, beta and the output go through the storage-precision macros
// (unfp / sfp), which presumably become fp16 when NCNN_fp16_storage is set.
// The output holds two values per index gx: a at x = 2*gx, b at x = 2*gx + 1.
#if NCNN_image_shader
layout (binding = 0, imfmtc1) writeonly uniform unfp image3D coeffs_blob;
layout (binding = 1) uniform highp sampler3D mean_blob;
layout (binding = 2) uniform highp sampler3D var_blob;
layout (binding = 3) uniform unfp sampler3D gamma_blob;
layout (binding = 4) uniform unfp sampler3D beta_blob;
#else
layout (binding = 0) writeonly buffer coeffs_blob { sfp coeffs_blob_data[]; };
layout (binding = 1) readonly buffer mean_blob { float mean_data[]; };
layout (binding = 2) readonly buffer var_blob { float var_data[]; };
layout (binding = 3) readonly buffer gamma_blob { sfp gamma_data[]; };
layout (binding = 4) readonly buffer beta_blob { sfp beta_data[]; };
#endif

// Values supplied per dispatch; p.w is the runtime counterpart of the
// specialization constant w (selected via psc(w) in main()).
layout (push_constant) uniform parameter
{
    int w;
} p;

// Compute-shader entry point.
//
// For every index gx in [0, psc(w)) this computes two coefficients (a, b) with
//   affine == 0:  a * x + b == (x - mean) / sqrt(var + eps)
//   otherwise:    a * x + b == gamma * (x - mean) / sqrt(var + eps) + beta
// and writes them interleaved into coeffs_blob: a at 2*gx, b at 2*gx + 1.
// Presumably gx indexes the channels of the instance-norm input and a later
// pass evaluates x * a + b per element - confirm against the consuming shader.
//
// The work is one-dimensional: any invocation with gy >= 1 or gz >= 1 (or
// gx past the element count) returns without writing anything.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    // psc(w) is an ncnn macro resolving the element count from either the
    // specialization constant w or the push constant p.w.
    if (gx >= psc(w) || gy >= 1 || gz >= 1)
        return;

    // Statistics for this index, read at full float precision.
#if NCNN_image_shader
    float mean = texelFetch(mean_blob, ivec3(gx, 0, 0), 0).r;
    float var = texelFetch(var_blob, ivec3(gx, 0, 0), 0).r;
#else
    float mean = mean_data[gx];
    float var = var_data[gx];
#endif

    float a;
    float b;
    if (affine == 0)
    {
        // Plain normalization: scale by 1/sqrt(var + eps), shift by -mean * scale.
        a = 1.f / sqrt(var + eps);
        b = - mean * a;
    }
    else
    {
        // Learned affine parameters, widened to float for the arithmetic.
#if NCNN_image_shader
        float gamma = float(image3d_ld1(gamma_blob, ivec3(gx, 0, 0)));
        float beta = float(image3d_ld1(beta_blob, ivec3(gx, 0, 0)));
#else
        float gamma = float(buffer_ld1(gamma_data, gx));
        float beta = float(buffer_ld1(beta_data, gx));
#endif

        // Fold gamma into the scale and beta into the shift.
        a = gamma / sqrt(var + eps);
        b = - mean * a + beta;
    }

    // Interleaved output: slot 2*gx holds the scale, 2*gx + 1 the shift.
#if NCNN_image_shader
    imageStore(coeffs_blob, ivec3(gx*2, 0, 0), vec4(a));
    imageStore(coeffs_blob, ivec3(gx*2 +1, 0, 0), vec4(b));
#else
    buffer_st1(coeffs_blob_data, gx*2, afp(a));
    buffer_st1(coeffs_blob_data, gx*2 +1, afp(b));
#endif
}

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.