ncnn

Форк
0
/
sigmoid_riscv.cpp 
151 строка · 3.9 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "sigmoid_riscv.h"
16

17
#if __riscv_vector
18
#include <riscv_vector.h>
19
#include "rvv_mathfun.h"
20
#include "rvv_mathfun_fp16s.h"
21
#endif // __riscv_vector
22

23
namespace ncnn {
24

25
Sigmoid_riscv::Sigmoid_riscv()
26
{
27
#if __riscv_vector
28
    support_packing = true;
29
#if __riscv_zfh
30
    support_fp16_storage = true;
31
#endif
32
#endif // __riscv_vector
33
}
34

35
int Sigmoid_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
36
{
37
#if __riscv_vector && __riscv_zfh
38
    int elembits = bottom_top_blob.elembits();
39

40
    if (opt.use_fp16_storage && elembits == 16)
41
    {
42
        if (opt.use_fp16_arithmetic)
43
            return forward_inplace_fp16sa(bottom_top_blob, opt);
44
        else
45
            return forward_inplace_fp16s(bottom_top_blob, opt);
46
    }
47
#endif
48

49
    int w = bottom_top_blob.w;
50
    int h = bottom_top_blob.h;
51
    int d = bottom_top_blob.d;
52
    int channels = bottom_top_blob.c;
53
    int elempack = bottom_top_blob.elempack;
54
    int size = w * h * d * elempack;
55

56
    #pragma omp parallel for num_threads(opt.num_threads)
57
    for (int q = 0; q < channels; q++)
58
    {
59
        float* ptr = bottom_top_blob.channel(q);
60

61
#if __riscv_vector
62
        int n = size;
63
        while (n > 0)
64
        {
65
            size_t vl = vsetvl_e32m8(n);
66

67
            vfloat32m8_t _p = vle32_v_f32m8(ptr, vl);
68
            _p = sigmoid_ps(_p, vl);
69
            vse32_v_f32m8(ptr, _p, vl);
70

71
            ptr += vl;
72
            n -= vl;
73
        }
74
#else  // __riscv_vector
75
        for (int i = 0; i < size; i++)
76
        {
77
            *ptr = 1.f / (1.f + exp(-*ptr));
78

79
            ptr++;
80
        }
81
#endif // __riscv_vector
82
    }
83

84
    return 0;
85
}
86

87
#if __riscv_vector && __riscv_zfh
88
int Sigmoid_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
89
{
90
    int w = bottom_top_blob.w;
91
    int h = bottom_top_blob.h;
92
    int d = bottom_top_blob.d;
93
    int channels = bottom_top_blob.c;
94
    int elempack = bottom_top_blob.elempack;
95
    int size = w * h * d * elempack;
96

97
    #pragma omp parallel for num_threads(opt.num_threads)
98
    for (int q = 0; q < channels; q++)
99
    {
100
        __fp16* ptr = bottom_top_blob.channel(q);
101

102
        int n = size;
103
        while (n > 0)
104
        {
105
            size_t vl = vsetvl_e16m4(n);
106

107
            vfloat32m8_t _p = vfwcvt_f_f_v_f32m8(vle16_v_f16m4(ptr, vl), vl);
108
            _p = sigmoid_ps(_p, vl);
109
            vse16_v_f16m4(ptr, vfncvt_f_f_w_f16m4(_p, vl), vl);
110

111
            ptr += vl;
112
            n -= vl;
113
        }
114
    }
115

116
    return 0;
117
}
118

119
int Sigmoid_riscv::forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const
120
{
121
    int w = bottom_top_blob.w;
122
    int h = bottom_top_blob.h;
123
    int d = bottom_top_blob.d;
124
    int channels = bottom_top_blob.c;
125
    int elempack = bottom_top_blob.elempack;
126
    int size = w * h * d * elempack;
127

128
    #pragma omp parallel for num_threads(opt.num_threads)
129
    for (int q = 0; q < channels; q++)
130
    {
131
        __fp16* ptr = bottom_top_blob.channel(q);
132

133
        int n = size;
134
        while (n > 0)
135
        {
136
            size_t vl = vsetvl_e16m8(n);
137

138
            vfloat16m8_t _p = vle16_v_f16m8(ptr, vl);
139
            _p = sigmoid_ps(_p, vl);
140
            vse16_v_f16m8(ptr, _p, vl);
141

142
            ptr += vl;
143
            n -= vl;
144
        }
145
    }
146

147
    return 0;
148
}
149
#endif // __riscv_vector && __riscv_zfh
150

151
} // namespace ncnn
152

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.