ncnn

Форк
0
/
tanh_riscv.cpp 
150 строк · 3.9 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "tanh_riscv.h"
16

17
#if __riscv_vector
18
#include <riscv_vector.h>
19
#include "rvv_mathfun.h"
20
#include "rvv_mathfun_fp16s.h"
21
#endif // __riscv_vector
22

23
namespace ncnn {
24

25
TanH_riscv::TanH_riscv()
26
{
27
#if __riscv_vector
28
    support_packing = true;
29
#if __riscv_zfh
30
    support_fp16_storage = true;
31
#endif
32
#endif // __riscv_vector
33
}
34

35
int TanH_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
36
{
37
#if __riscv_vector && __riscv_zfh
38
    int elembits = bottom_top_blob.elembits();
39

40
    if (opt.use_fp16_storage && elembits == 16)
41
    {
42
        if (opt.use_fp16_arithmetic)
43
            return forward_inplace_fp16sa(bottom_top_blob, opt);
44
        else
45
            return forward_inplace_fp16s(bottom_top_blob, opt);
46
    }
47
#endif
48

49
    int w = bottom_top_blob.w;
50
    int h = bottom_top_blob.h;
51
    int d = bottom_top_blob.d;
52
    int channels = bottom_top_blob.c;
53
    int elempack = bottom_top_blob.elempack;
54
    int size = w * h * d * elempack;
55

56
    #pragma omp parallel for num_threads(opt.num_threads)
57
    for (int q = 0; q < channels; q++)
58
    {
59
        float* ptr = bottom_top_blob.channel(q);
60

61
#if __riscv_vector
62
        int n = size;
63
        while (n > 0)
64
        {
65
            size_t vl = vsetvl_e32m8(n);
66

67
            vfloat32m8_t _p = vle32_v_f32m8(ptr, vl);
68
            _p = tanh_ps(_p, vl);
69
            vse32_v_f32m8(ptr, _p, vl);
70

71
            ptr += vl;
72
            n -= vl;
73
        }
74
#else  // __riscv_vector
75
        for (int i = 0; i < size; i++)
76
        {
77
            *ptr = tanh(*ptr);
78
            ptr++;
79
        }
80
#endif // __riscv_vector
81
    }
82

83
    return 0;
84
}
85

86
#if __riscv_vector && __riscv_zfh
87
int TanH_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
88
{
89
    int w = bottom_top_blob.w;
90
    int h = bottom_top_blob.h;
91
    int d = bottom_top_blob.d;
92
    int channels = bottom_top_blob.c;
93
    int elempack = bottom_top_blob.elempack;
94
    int size = w * h * d * elempack;
95

96
    #pragma omp parallel for num_threads(opt.num_threads)
97
    for (int q = 0; q < channels; q++)
98
    {
99
        __fp16* ptr = bottom_top_blob.channel(q);
100

101
        int n = size;
102
        while (n > 0)
103
        {
104
            size_t vl = vsetvl_e16m4(n);
105

106
            vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr, vl), vl);
107
            _p = tanh_ps(_p, vl);
108
            vse16_v_f16m4(ptr, vfncvt_f_f_w_f16m4(_p, vl), vl);
109

110
            ptr += vl;
111
            n -= vl;
112
        }
113
    }
114

115
    return 0;
116
}
117

118
int TanH_riscv::forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const
119
{
120
    int w = bottom_top_blob.w;
121
    int h = bottom_top_blob.h;
122
    int d = bottom_top_blob.d;
123
    int channels = bottom_top_blob.c;
124
    int elempack = bottom_top_blob.elempack;
125
    int size = w * h * d * elempack;
126

127
    #pragma omp parallel for num_threads(opt.num_threads)
128
    for (int q = 0; q < channels; q++)
129
    {
130
        __fp16* ptr = bottom_top_blob.channel(q);
131

132
        int n = size;
133
        while (n > 0)
134
        {
135
            size_t vl = vsetvl_e16m8(n);
136

137
            vfloat16m8_t _p = vle16_v_f16m8(ptr, vl);
138
            _p = tanh_ps(_p, vl);
139
            vse16_v_f16m8(ptr, _p, vl);
140

141
            ptr += vl;
142
            n -= vl;
143
        }
144
    }
145

146
    return 0;
147
}
148
#endif // __riscv_vector && __riscv_zfh
149

150
} // namespace ncnn
151

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.