ncnn

Форк
0
/
cumulativesum.cpp 
171 строка · 4.3 Кб
1
// Copyright (c) 2023 Xiaomi Corp.        (author: Fangjun Kuang)
2
//
3
// Licensed under the BSD 3-Clause License (the "License"); you may not use this
4
// file except in compliance with the License. You may obtain a copy of the
5
// License at
6
//
7
// https://opensource.org/licenses/BSD-3-Clause
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12
// License for the specific language governing permissions and limitations under
13
// the License.
14

15
#include "cumulativesum.h"
16

17
namespace ncnn {
18

19
// Constructs the layer and sets its two layer flags:
// the work is done by forward_inplace(), which takes a single blob
// and overwrites it with the result.
CumulativeSum::CumulativeSum()
{
    // the input blob is overwritten with the output
    support_inplace = true;

    // a single blob is passed to forward_inplace()
    one_blob_only = true;
}
24

25
int CumulativeSum::load_param(const ParamDict& pd)
26
{
27
    axis = pd.get(0, 0);
28

29
    return 0;
30
}
31

32
// Replaces data[0..count) with its running sum, in place:
// data[i] becomes data[0] + data[1] + ... + data[i].
static void cumulativesum_running_sum(float* data, int count)
{
    for (int idx = 1; idx < count; ++idx)
    {
        data[idx] += data[idx - 1];
    }
}

// Adds the `count` floats of `previous` element-wise onto `current`.
static void cumulativesum_add_previous(float* current, const float* previous, int count)
{
    for (int idx = 0; idx < count; ++idx)
    {
        current[idx] += previous[idx];
    }
}

// Computes the cumulative sum of bottom_top_blob along `axis`, in place.
//
// bottom_top_blob  blob with 1, 2 or 3 dims, read and written as float
// opt              opt.num_threads is handed to the OpenMP pragmas
//
// A negative `axis` is resolved as dims + axis. For a 1-dim blob the axis
// is ignored.
//
// Returns 0 on success, -100 when the dims / axis combination is not one
// of the handled cases.
int CumulativeSum::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
    const int dims = bottom_top_blob.dims;
    const int w = bottom_top_blob.w;
    const int h = bottom_top_blob.h;
    const int channels = bottom_top_blob.c;

    const int resolved_axis = axis >= 0 ? axis : dims + axis;

    switch (dims)
    {
    case 1:
    {
        // axis is ignored, there is only one direction to sum along
        float* data = bottom_top_blob;
        cumulativesum_running_sum(data, w);
        return 0;
    }
    case 2:
    {
        if (resolved_axis == 0)
        {
            // down the rows: each row picks up the already accumulated
            // row above it, so the rows are visited in order
            for (int y = 1; y < h; ++y)
            {
                const float* row_above = bottom_top_blob.row(y - 1);
                float* row_here = bottom_top_blob.row(y);

                cumulativesum_add_previous(row_here, row_above, w);
            }

            return 0;
        }

        if (resolved_axis == 1)
        {
            // along each row; a row only touches its own elements
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int y = 0; y < h; ++y)
            {
                float* row_here = bottom_top_blob.row(y);

                cumulativesum_running_sum(row_here, w);
            }

            return 0;
        }

        break;
    }
    case 3:
    {
        if (resolved_axis == 0)
        {
            // across channels: each channel picks up the already
            // accumulated previous channel, w * h elements at a time
            const int plane_size = w * h;

            for (int q = 1; q < channels; ++q)
            {
                const float* plane_before = bottom_top_blob.channel(q - 1);
                float* plane_here = bottom_top_blob.channel(q);

                cumulativesum_add_previous(plane_here, plane_before, plane_size);
            }

            return 0;
        }

        if (resolved_axis == 1)
        {
            // down the rows inside every channel; a channel only touches
            // its own rows
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; ++q)
            {
                Mat plane = bottom_top_blob.channel(q);

                for (int y = 1; y < h; ++y)
                {
                    const float* row_above = plane.row(y - 1);
                    float* row_here = plane.row(y);

                    cumulativesum_add_previous(row_here, row_above, w);
                }
            }

            return 0;
        }

        if (resolved_axis == 2)
        {
            // along each row inside every channel
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; ++q)
            {
                Mat plane = bottom_top_blob.channel(q);

                for (int y = 0; y < h; ++y)
                {
                    float* row_here = plane.row(y);

                    cumulativesum_running_sum(row_here, w);
                }
            }

            return 0;
        }

        break;
    }
    default:
        break;
    }

    // unsupported dims / axis combination
    return -100;
}
170

171
} // namespace ncnn
172

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.