// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "rnn.h"

namespace ncnn {

RNN::RNN()
{
    one_blob_only = false;
    support_inplace = false;
}
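
// Layer parameters read from the param file (ParamDict slots, as consumed below):
//   0 = num_output        size of the hidden state
//   1 = weight_data_size  total element count of weight_xc_data
//   2 = direction         0 = forward, 1 = reverse, 2 = bidirectional
//   8 = int8_scale_term   non-zero selects the int8 inference path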

int RNN::load_param(const ParamDict& pd)
{
    num_output = pd.get(0, 0);
    weight_data_size = pd.get(1, 0);
    direction = pd.get(2, 0);
    int8_scale_term = pd.get(8, 0);

    if (int8_scale_term)
    {
#if !NCNN_INT8
        NCNN_LOGE("please build ncnn with NCNN_INT8 enabled for int8 inference");
        return -1;
#endif
    }

    return 0;
}
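
// Weight blobs, one channel per direction:
//   weight_xc_data: size x num_output            input-to-hidden weights
//   bias_c_data:    num_output x 1               bias
//   weight_hc_data: num_output x num_output      hidden-to-hidden weights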

int RNN::load_model(const ModelBin& mb)
{
    int num_directions = direction == 2 ? 2 : 1;

    int size = weight_data_size / num_directions / num_output;

    // raw weight data
    weight_xc_data = mb.load(size, num_output, num_directions, 0);
    if (weight_xc_data.empty())
        return -100;

    bias_c_data = mb.load(num_output, 1, num_directions, 0);
    if (bias_c_data.empty())
        return -100;

    weight_hc_data = mb.load(num_output, num_output, num_directions, 0);
    if (weight_hc_data.empty())
        return -100;

#if NCNN_INT8
    if (int8_scale_term)
    {
        weight_xc_data_int8_scales = mb.load(num_output, num_directions, 1);
        weight_hc_data_int8_scales = mb.load(num_output, num_directions, 1);
    }
#endif // NCNN_INT8

    return 0;
}
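
// One timestep of the vanilla (Elman) RNN recurrence computed below:
//   h_t = tanh(W_xc * x_t + b_c + W_hc * h_(t-1))
// Each h_t is written to the output row and carried into the next timestep.
// Results are staged in the gates buffer so hidden_state is not overwritten
// while other threads are still reading it, then copied out in a second pass.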

static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // num_output
    Mat gates(num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll
    for (int t = 0; t < T; t++)
    {
        int ti = reverse ? T - 1 - t : t;

        const float* x = bottom_blob.row(ti);
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < num_output; q++)
        {
            const float* weight_xc_ptr = weight_xc.row(q);
            const float* weight_hc_ptr = weight_hc.row(q);

            float H = bias_c[q];

            for (int i = 0; i < size; i++)
            {
                H += weight_xc_ptr[i] * x[i];
            }

            for (int i = 0; i < num_output; i++)
            {
                H += weight_hc_ptr[i] * hidden_state[i];
            }

            H = tanhf(H);

            gates[q] = H;
        }

        float* output_data = top_blob.row(ti);
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < num_output; q++)
        {
            float H = gates[q];

            hidden_state[q] = H;
            output_data[q] = H;
        }
    }

    return 0;
}
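
// Int8 path: the weights come pre-quantized with one scale per output row,
// while activations are quantized on the fly with scale 127 / absmax
// (input rows are assumed non-zero here; the hidden state path below guards
// absmax == 0 explicitly). Dot products accumulate in int32 and are
// dequantized with the product of the two descale factors.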

#if NCNN_INT8
static int rnn_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const float* weight_xc_int8_scales, const Mat& bias_c, const Mat& weight_hc_int8, const float* weight_hc_int8_scales, Mat& hidden_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // num_output
    Mat gates(num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // dynamic quantize bottom_blob
    Mat bottom_blob_int8(size, T, (size_t)1u, 1, opt.workspace_allocator);
    Mat bottom_blob_int8_scales(T, (size_t)4u, 1, opt.workspace_allocator);
    {
        for (int t = 0; t < T; t++)
        {
            const float* x = bottom_blob.row(t);

            float absmax = 0.f;
            for (int i = 0; i < size; i++)
            {
                absmax = std::max(absmax, (float)fabs(x[i]));
            }

            bottom_blob_int8_scales[t] = 127.f / absmax;
        }

        Option opt_quant = opt;
        opt_quant.blob_allocator = opt.workspace_allocator;
        opt_quant.use_packing_layout = false;
        quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_quant);
    }

    Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator);
    Mat hidden_state_int8_scales(1, (size_t)4u, 1, opt.workspace_allocator);

    // unroll
    for (int t = 0; t < T; t++)
    {
        int ti = reverse ? T - 1 - t : t;

        // dynamic quantize hidden_state
        {
            float absmax = 0.f;
            for (int i = 0; i < num_output; i++)
            {
                absmax = std::max(absmax, (float)fabs(hidden_state[i]));
            }

            if (absmax == 0.f)
            {
                hidden_state_int8_scales[0] = 1.f;
                hidden_state_int8.fill<signed char>(0);
            }
            else
            {
                hidden_state_int8_scales[0] = 127.f / absmax;

                Option opt_quant = opt;
                opt_quant.blob_allocator = opt.workspace_allocator;
                opt_quant.use_packing_layout = false;
                quantize_to_int8(hidden_state, hidden_state_int8, hidden_state_int8_scales, opt_quant);
            }
        }

        const signed char* x = bottom_blob_int8.row<const signed char>(ti);
        const signed char* hs = hidden_state_int8;
        const float descale_x = 1.f / bottom_blob_int8_scales[ti];
        const float descale_h = 1.f / hidden_state_int8_scales[0];
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < num_output; q++)
        {
            const signed char* weight_xc_int8_ptr = weight_xc_int8.row<const signed char>(q);
            const signed char* weight_hc_int8_ptr = weight_hc_int8.row<const signed char>(q);

            const float descale_xc = 1.f / weight_xc_int8_scales[q];
            const float descale_hc = 1.f / weight_hc_int8_scales[q];

            int Hx = 0;
            for (int i = 0; i < size; i++)
            {
                Hx += weight_xc_int8_ptr[i] * x[i];
            }

            int Hh = 0;
            for (int i = 0; i < num_output; i++)
            {
                Hh += weight_hc_int8_ptr[i] * hs[i];
            }

            float H = bias_c[q] + Hx * (descale_x * descale_xc) + Hh * (descale_h * descale_hc);

            H = tanhf(H);

            gates[q] = H;
        }

        float* output_data = top_blob.row(ti);
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < num_output; q++)
        {
            float H = gates[q];

            hidden_state[q] = H;
            output_data[q] = H;
        }
    }

    return 0;
}
#endif // NCNN_INT8
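
// Single-input forward: the hidden state starts at zero and is discarded
// afterwards. The output has T rows of num_output * num_directions values;
// for direction == 2 the forward and reverse results are concatenated along w.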

int RNN::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // unidirectional
    if (direction == 0 || direction == 1)
    {
#if NCNN_INT8
        if (int8_scale_term)
        {
            int ret = rnn_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt);
            if (ret != 0)
                return ret;
        }
        else
#endif
        {
            int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
            if (ret != 0)
                return ret;
        }
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

#if NCNN_INT8
        if (int8_scale_term)
        {
            int ret = rnn_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt);
            if (ret != 0)
                return ret;
        }
        else
#endif
        {
            int ret = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
            if (ret != 0)
                return ret;
        }

        hidden.fill(0.0f);

#if NCNN_INT8
        if (int8_scale_term)
        {
            int ret = rnn_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), hidden, opt);
            if (ret != 0)
                return ret;
        }
        else
#endif
        {
            int ret = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, opt);
            if (ret != 0)
                return ret;
        }

        // concat forward and reverse output along w
        for (int i = 0; i < T; i++)
        {
            const float* pf = top_blob_forward.row(i);
            const float* pr = top_blob_reverse.row(i);
            float* ptr = top_blob.row(i);

            memcpy(ptr, pf, num_output * sizeof(float));
            memcpy(ptr + num_output, pr, num_output * sizeof(float));
        }
    }

    return 0;
}
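
// Multi-input forward: bottom_blobs[1], when present, supplies the initial
// hidden state (num_output x num_directions); when top_blobs has two entries,
// the final hidden state is also returned in top_blobs[1].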

int RNN::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    const Mat& bottom_blob = bottom_blobs[0];
    int T = bottom_blob.h;
    int num_directions = direction == 2 ? 2 : 1;

    Mat hidden;
    Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator;
    if (bottom_blobs.size() == 2)
    {
        hidden = bottom_blobs[1].clone(hidden_allocator);
    }
    else
    {
        hidden.create(num_output, num_directions, 4u, hidden_allocator);
        if (hidden.empty())
            return -100;
        hidden.fill(0.f);
    }

    Mat& top_blob = top_blobs[0];
    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // unidirectional
    if (direction == 0 || direction == 1)
    {
#if NCNN_INT8
        if (int8_scale_term)
        {
            int ret = rnn_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt);
            if (ret != 0)
                return ret;
        }
        else
#endif
        {
            int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
            if (ret != 0)
                return ret;
        }
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        Mat hidden0 = hidden.row_range(0, 1);
#if NCNN_INT8
        if (int8_scale_term)
        {
            int ret = rnn_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden0, opt);
            if (ret != 0)
                return ret;
        }
        else
#endif
        {
            int ret = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt);
            if (ret != 0)
                return ret;
        }

        Mat hidden1 = hidden.row_range(1, 1);
#if NCNN_INT8
        if (int8_scale_term)
        {
            int ret = rnn_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), hidden1, opt);
            if (ret != 0)
                return ret;
        }
        else
#endif
        {
            int ret = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt);
            if (ret != 0)
                return ret;
        }

        // concat forward and reverse output along w
        for (int i = 0; i < T; i++)
        {
            const float* pf = top_blob_forward.row(i);
            const float* pr = top_blob_reverse.row(i);
            float* ptr = top_blob.row(i);

            memcpy(ptr, pf, num_output * sizeof(float));
            memcpy(ptr + num_output, pr, num_output * sizeof(float));
        }
    }

    if (top_blobs.size() == 2)
    {
        top_blobs[1] = hidden;
    }

    return 0;
}

} // namespace ncnn
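
// A minimal usage sketch (illustration only, not part of this layer): running
// a network that contains an RNN layer through ncnn's public API. The blob
// names "data" and "output" and the model file names are assumptions.
//
//   #include "net.h"
//
//   ncnn::Net net;
//   net.load_param("model.param");
//   net.load_model("model.bin");
//
//   ncnn::Mat in(size, T); // one row per timestep, as consumed by forward()
//   // ... fill in ...
//
//   ncnn::Extractor ex = net.create_extractor();
//   ex.input("data", in);
//
//   ncnn::Mat out; // T x (num_output * num_directions)
//   ex.extract("output", out);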