ncnn

Форк
0
/
vulkan_activation.comp 
142 строки · 4.3 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#ifndef NCNN_VULKAN_ACTIVATION_COMP
16
#define NCNN_VULKAN_ACTIVATION_COMP
17

18
afp activation_afp(afp v, int activation_type, float activation_param_0, float activation_param_1)
19
{
20
    if (activation_type == 1)
21
    {
22
        v = max(v, afp(0.f));
23
    }
24
    if (activation_type == 2)
25
    {
26
        const afp slope = afp(activation_param_0);
27
        v = v < afp(0.f) ? v * slope : v;
28
    }
29
    if (activation_type == 3)
30
    {
31
        const afp const_min = afp(activation_param_0);
32
        const afp const_max = afp(activation_param_1);
33
        v = clamp(v, const_min, const_max);
34
    }
35
    if (activation_type == 4)
36
    {
37
        v = afp(1.f) / (afp(1.f) + exp(-v));
38
    }
39
    if (activation_type == 5)
40
    {
41
#if NCNN_moltenvk
42
        v = v * afp(tanh(float(log(exp(v) + afp(1.f)))));
43
#else
44
        v = v * tanh(log(exp(v) + afp(1.f)));
45
#endif
46
    }
47
    if (activation_type == 6)
48
    {
49
        const afp alpha = afp(activation_param_0);
50
        const afp beta = afp(activation_param_1);
51
        v = v * clamp(v * afp(alpha) + afp(beta), afp(0.f), afp(1.f));
52
    }
53

54
    return v;
55
}
56

57
afpvec4 activation_afpvec4(afpvec4 v, int activation_type, float activation_param_0, float activation_param_1)
58
{
59
    if (activation_type == 1)
60
    {
61
        v = max(v, afp(0.f));
62
    }
63
    if (activation_type == 2)
64
    {
65
        const afp slope = afp(activation_param_0);
66
        v = mix(v, v * afp(slope), lessThan(v, afpvec4(0.f)));
67
    }
68
    if (activation_type == 3)
69
    {
70
        const afp const_min = afp(activation_param_0);
71
        const afp const_max = afp(activation_param_1);
72
        v = clamp(v, const_min, const_max);
73
    }
74
    if (activation_type == 4)
75
    {
76
        v = afp(1.f) / (afp(1.f) + exp(-v));
77
    }
78
    if (activation_type == 5)
79
    {
80
#if NCNN_moltenvk
81
        v = v * afpvec4(tanh(vec4(log(exp(v) + afp(1.f)))));
82
#else
83
        v = v * tanh(log(exp(v) + afp(1.f)));
84
#endif
85
    }
86
    if (activation_type == 6)
87
    {
88
        const afp alpha = afp(activation_param_0);
89
        const afp beta = afp(activation_param_1);
90
        v = v * clamp(v * afp(alpha) + afp(beta), afp(0.f), afp(1.f));
91
    }
92

93
    return v;
94
}
95

96
afpvec8 activation_afpvec8(afpvec8 v, int activation_type, float activation_param_0, float activation_param_1)
97
{
98
    if (activation_type == 1)
99
    {
100
        v[0] = max(v[0], afp(0.f));
101
        v[1] = max(v[1], afp(0.f));
102
    }
103
    if (activation_type == 2)
104
    {
105
        const afp slope = afp(activation_param_0);
106
        v[0] = mix(v[0], v[0] * afp(slope), lessThan(v[0], afpvec4(0.f)));
107
        v[1] = mix(v[1], v[1] * afp(slope), lessThan(v[1], afpvec4(0.f)));
108
    }
109
    if (activation_type == 3)
110
    {
111
        const afp const_min = afp(activation_param_0);
112
        const afp const_max = afp(activation_param_1);
113
        v[0] = clamp(v[0], const_min, const_max);
114
        v[1] = clamp(v[1], const_min, const_max);
115
    }
116
    if (activation_type == 4)
117
    {
118
        v[0] = afp(1.f) / (afp(1.f) + exp(-v[0]));
119
        v[1] = afp(1.f) / (afp(1.f) + exp(-v[1]));
120
    }
121
    if (activation_type == 5)
122
    {
123
#if NCNN_moltenvk
124
        v[0] = v[0] * afpvec4(tanh(vec4(log(exp(v[0]) + afp(1.f)))));
125
        v[1] = v[1] * afpvec4(tanh(vec4(log(exp(v[1]) + afp(1.f)))));
126
#else
127
        v[0] = v[0] * tanh(log(exp(v[0]) + afp(1.f)));
128
        v[1] = v[1] * tanh(log(exp(v[1]) + afp(1.f)));
129
#endif
130
    }
131
    if (activation_type == 6)
132
    {
133
        const afp alpha = afp(activation_param_0);
134
        const afp beta = afp(activation_param_1);
135
        v[0] = v[0] * clamp(v[0] * afp(alpha) + afp(beta), afp(0.f), afp(1.f));
136
        v[1] = v[1] * clamp(v[1] * afp(alpha) + afp(beta), afp(0.f), afp(1.f));
137
    }
138

139
    return v;
140
}
141

142
#endif // NCNN_VULKAN_ACTIVATION_COMP
143

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.