ncnn

Форк
0
/
softmax_mips.cpp 
174 строки · 4.1 Кб
1
// Leo is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2020 Leo <leo@nullptr.com.cn>. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "softmax_mips.h"
16

17
#include <float.h>
18

19
#if __mips_msa
20
#include <msa.h>
21
#include "msa_mathfun.h"
22
#endif // __mips_msa
23

24
namespace ncnn {
25

26
int Softmax_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
27
{
28
    int dims = bottom_top_blob.dims;
29
    size_t elemsize = bottom_top_blob.elemsize;
30
    int positive_axis = axis < 0 ? dims + axis : axis;
31

32
    if (dims != 3 || positive_axis != 0)
33
        return Softmax::forward_inplace(bottom_top_blob, opt);
34

35
    // value = exp( value - global max value )
36
    // sum all value
37
    // value = value / sum
38

39
    int w = bottom_top_blob.w;
40
    int h = bottom_top_blob.h;
41
    int channels = bottom_top_blob.c;
42
    int size = w * h;
43

44
    Mat max;
45
    max.create(w, h, elemsize, opt.workspace_allocator);
46
    if (max.empty())
47
        return -100;
48
    max.fill(-FLT_MAX);
49
    for (int q = 0; q < channels; q++)
50
    {
51
        float* ptr = bottom_top_blob.channel(q);
52
        float* maxptr = max;
53

54
        for (int i = 0; i < size; i++)
55
        {
56
            maxptr[i] = std::max(maxptr[i], ptr[i]);
57
        }
58
    }
59

60
    #pragma omp parallel for num_threads(opt.num_threads)
61
    for (int q = 0; q < channels; q++)
62
    {
63
        float* ptr = bottom_top_blob.channel(q);
64
        float* maxptr = max;
65

66
#if __mips_msa
67
        int nn = size >> 2;
68
        int remain = size - (nn << 2);
69
#else
70
        int remain = size;
71
#endif // __mips_msa
72

73
#if __mips_msa
74
        for (; nn > 0; nn--)
75
        {
76
            v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
77
            v4f32 _max = (v4f32)__msa_ld_w(maxptr, 0);
78

79
            _p = exp_ps(__msa_fsub_w(_p, _max));
80

81
            __msa_st_w((v4i32)_p, ptr, 0);
82

83
            ptr += 4;
84
            maxptr += 4;
85
        }
86
#endif // __mips_msa
87

88
        for (; remain > 0; remain--)
89
        {
90
            *ptr = exp(*ptr - *maxptr);
91

92
            ptr++;
93
            maxptr++;
94
        }
95
    }
96

97
    Mat sum;
98
    sum.create(w, h, elemsize, opt.workspace_allocator);
99
    if (sum.empty())
100
        return -100;
101
    sum.fill(0.f);
102
    for (int q = 0; q < channels; q++)
103
    {
104
        float* ptr = bottom_top_blob.channel(q);
105
        float* sumptr = sum;
106

107
#if __mips_msa
108
        int nn = size >> 2;
109
        int remain = size - (nn << 2);
110
#else
111
        int remain = size;
112
#endif // __mips_msa
113

114
#if __mips_msa
115
        for (; nn > 0; nn--)
116
        {
117
            v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
118
            v4f32 _sum = (v4f32)__msa_ld_w(sumptr, 0);
119
            _sum = __msa_fadd_w(_sum, _p);
120
            __msa_st_w((v4i32)_sum, sumptr, 0);
121

122
            ptr += 4;
123
            sumptr += 4;
124
        }
125
#endif // __mips_msa
126

127
        for (; remain > 0; remain--)
128
        {
129
            *sumptr += *ptr;
130

131
            ptr++;
132
            sumptr++;
133
        }
134
    }
135

136
    #pragma omp parallel for num_threads(opt.num_threads)
137
    for (int q = 0; q < channels; q++)
138
    {
139
        float* ptr = bottom_top_blob.channel(q);
140
        float* sumptr = sum;
141

142
#if __mips_msa
143
        int nn = size >> 2;
144
        int remain = size - (nn << 2);
145
#else
146
        int remain = size;
147
#endif // __mips_msa
148

149
#if __mips_msa
150
        for (; nn > 0; nn--)
151
        {
152
            v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
153
            v4f32 _sum = (v4f32)__msa_ld_w(sumptr, 0);
154
            _p = __msa_fdiv_w(_p, _sum);
155
            __msa_st_w((v4i32)_p, ptr, 0);
156

157
            ptr += 4;
158
            sumptr += 4;
159
        }
160
#endif // __mips_msa
161

162
        for (; remain > 0; remain--)
163
        {
164
            *ptr /= *sumptr;
165

166
            ptr++;
167
            sumptr++;
168
        }
169
    }
170

171
    return 0;
172
}
173

174
} // namespace ncnn
175

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.