ncnn

Форк
0
/
gru.cpp 
555 строк · 17.3 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "gru.h"
16

17
namespace ncnn {
18

19
GRU::GRU()
{
    // forward() can take an optional initial hidden state as a second input and
    // can emit the final hidden state as a second output, so this is a
    // multi-blob layer. Results are always written to freshly created blobs.
    support_inplace = false;
    one_blob_only = false;
}
24

25
int GRU::load_param(const ParamDict& pd)
26
{
27
    num_output = pd.get(0, 0);
28
    weight_data_size = pd.get(1, 0);
29
    direction = pd.get(2, 0);
30
    int8_scale_term = pd.get(8, 0);
31

32
    if (int8_scale_term)
33
    {
34
#if !NCNN_INT8
35
        NCNN_LOGE("please build ncnn with NCNN_INT8 enabled for int8 inference");
36
        return -1;
37
#endif
38
    }
39

40
    return 0;
41
}
42

43
int GRU::load_model(const ModelBin& mb)
44
{
45
    int num_directions = direction == 2 ? 2 : 1;
46

47
    int size = weight_data_size / num_directions / num_output / 3;
48

49
    // raw weight data
50
    weight_xc_data = mb.load(size, num_output * 3, num_directions, 0);
51
    if (weight_xc_data.empty())
52
        return -100;
53

54
    bias_c_data = mb.load(num_output, 4, num_directions, 0);
55
    if (bias_c_data.empty())
56
        return -100;
57

58
    weight_hc_data = mb.load(num_output, num_output * 3, num_directions, 0);
59
    if (weight_hc_data.empty())
60
        return -100;
61

62
#if NCNN_INT8
63
    if (int8_scale_term)
64
    {
65
        weight_xc_data_int8_scales = mb.load(num_output * 3, num_directions, 1);
66
        weight_hc_data_int8_scales = mb.load(num_output * 3, num_directions, 1);
67
    }
68
#endif // NCNN_INT8
69

70
    return 0;
71
}
72

73
static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
74
{
75
    int size = bottom_blob.w;
76
    int T = bottom_blob.h;
77

78
    int num_output = top_blob.w;
79

80
    // 2 x num_output
81
    Mat gates(2, num_output, 4u, opt.workspace_allocator);
82
    if (gates.empty())
83
        return -100;
84

85
    // unroll
86
    for (int t = 0; t < T; t++)
87
    {
88
        int ti = reverse ? T - 1 - t : t;
89

90
        const float* x = bottom_blob.row(ti);
91
        #pragma omp parallel for num_threads(opt.num_threads)
92
        for (int q = 0; q < num_output; q++)
93
        {
94
            float* gates_data = gates.row(q);
95

96
            // gate reset update
97
            const float* bias_c_R = bias_c.row(0);
98
            const float* bias_c_U = bias_c.row(1);
99

100
            const float* weight_xc_R = weight_xc.row(num_output * 0 + q);
101
            const float* weight_xc_U = weight_xc.row(num_output * 1 + q);
102
            const float* weight_hc_R = weight_hc.row(num_output * 0 + q);
103
            const float* weight_hc_U = weight_hc.row(num_output * 1 + q);
104

105
            float R = bias_c_R[q];
106
            float U = bias_c_U[q];
107

108
            for (int i = 0; i < size; i++)
109
            {
110
                float xi = x[i];
111

112
                R += weight_xc_R[i] * xi;
113
                U += weight_xc_U[i] * xi;
114
            }
115

116
            for (int i = 0; i < num_output; i++)
117
            {
118
                float h_cont = hidden_state[i];
119

120
                R += weight_hc_R[i] * h_cont;
121
                U += weight_hc_U[i] * h_cont;
122
            }
123

124
            // sigmoid(R)
125
            // sigmoid(U)
126
            R = 1.f / (1.f + expf(-R));
127
            U = 1.f / (1.f + expf(-U));
128

129
            // gate new
130
            const float* bias_c_WN = bias_c.row(2);
131
            const float* bias_c_BN = bias_c.row(3);
132

133
            const float* weight_xc_N = weight_xc.row(num_output * 2 + q);
134
            const float* weight_hc_N = weight_hc.row(num_output * 2 + q);
135

136
            float N = bias_c_BN[q];
137

138
            for (int i = 0; i < num_output; i++)
139
            {
140
                float h_cont = hidden_state[i];
141

142
                N += weight_hc_N[i] * h_cont;
143
            }
144

145
            N = bias_c_WN[q] + R * N;
146

147
            for (int i = 0; i < size; i++)
148
            {
149
                float xi = x[i];
150

151
                N += weight_xc_N[i] * xi;
152
            }
153

154
            // tanh(N)
155
            N = tanhf(N);
156

157
            gates_data[0] = U;
158
            gates_data[1] = N;
159
        }
160

161
        // h_t := (1 - update) .* new + update .* h_{t-1}
162
        float* output_data = top_blob.row(ti);
163
        #pragma omp parallel for num_threads(opt.num_threads)
164
        for (int q = 0; q < num_output; q++)
165
        {
166
            const float* gates_data = gates.row(q);
167

168
            float U = gates_data[0];
169
            float N = gates_data[1];
170

171
            float H = (1 - U) * N + U * hidden_state[q];
172

173
            hidden_state[q] = H;
174
            output_data[q] = H;
175
        }
176
    }
177

178
    return 0;
179
}
180

181
#if NCNN_INT8
182
static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const float* weight_xc_int8_scales, const Mat& bias_c, const Mat& weight_hc_int8, const float* weight_hc_int8_scales, Mat& hidden_state, const Option& opt)
183
{
184
    int size = bottom_blob.w;
185
    int T = bottom_blob.h;
186

187
    int num_output = top_blob.w;
188

189
    // 2 x num_output
190
    Mat gates(2, num_output, 4u, opt.workspace_allocator);
191
    if (gates.empty())
192
        return -100;
193

194
    // dynamic quantize bottom_blob
195
    Mat bottom_blob_int8(size, T, (size_t)1u, 1, opt.workspace_allocator);
196
    Mat bottom_blob_int8_scales(T, (size_t)4u, 1, opt.workspace_allocator);
197
    {
198
        for (int t = 0; t < T; t++)
199
        {
200
            const float* x = bottom_blob.row(t);
201

202
            float absmax = 0.f;
203
            for (int i = 0; i < size; i++)
204
            {
205
                absmax = std::max(absmax, (float)fabs(x[i]));
206
            }
207

208
            bottom_blob_int8_scales[t] = 127.f / absmax;
209
        }
210

211
        Option opt_quant = opt;
212
        opt_quant.blob_allocator = opt.workspace_allocator;
213
        opt_quant.use_packing_layout = false;
214
        quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_quant);
215
    }
216

217
    Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator);
218
    Mat hidden_state_int8_scales(1, (size_t)4u, 1, opt.workspace_allocator);
219

220
    // unroll
221
    for (int t = 0; t < T; t++)
222
    {
223
        int ti = reverse ? T - 1 - t : t;
224

225
        // dynamic quantize hidden_state
226
        {
227
            float absmax = 0.f;
228
            for (int i = 0; i < num_output; i++)
229
            {
230
                absmax = std::max(absmax, (float)fabs(hidden_state[i]));
231
            }
232

233
            if (absmax == 0.f)
234
            {
235
                hidden_state_int8_scales[0] = 1.f;
236
                hidden_state_int8.fill<signed char>(0);
237
            }
238
            else
239
            {
240
                hidden_state_int8_scales[0] = 127.f / absmax;
241

242
                Option opt_quant = opt;
243
                opt_quant.blob_allocator = opt.workspace_allocator;
244
                opt_quant.use_packing_layout = false;
245
                quantize_to_int8(hidden_state, hidden_state_int8, hidden_state_int8_scales, opt_quant);
246
            }
247
        }
248

249
        const signed char* x = bottom_blob_int8.row<const signed char>(ti);
250
        const signed char* hs = hidden_state_int8;
251
        const float descale_x = 1.f / bottom_blob_int8_scales[ti];
252
        const float descale_h = 1.f / hidden_state_int8_scales[0];
253
        #pragma omp parallel for num_threads(opt.num_threads)
254
        for (int q = 0; q < num_output; q++)
255
        {
256
            float* gates_data = gates.row(q);
257

258
            // gate reset update
259
            const float* bias_c_R = bias_c.row(0);
260
            const float* bias_c_U = bias_c.row(1);
261

262
            const signed char* weight_xc_int8_R = weight_xc_int8.row<const signed char>(num_output * 0 + q);
263
            const signed char* weight_xc_int8_U = weight_xc_int8.row<const signed char>(num_output * 1 + q);
264
            const signed char* weight_hc_int8_R = weight_hc_int8.row<const signed char>(num_output * 0 + q);
265
            const signed char* weight_hc_int8_U = weight_hc_int8.row<const signed char>(num_output * 1 + q);
266

267
            const float descale_xc_R = 1.f / weight_xc_int8_scales[num_output * 0 + q];
268
            const float descale_xc_U = 1.f / weight_xc_int8_scales[num_output * 1 + q];
269
            const float descale_hc_R = 1.f / weight_hc_int8_scales[num_output * 0 + q];
270
            const float descale_hc_U = 1.f / weight_hc_int8_scales[num_output * 1 + q];
271

272
            int Rx = 0;
273
            int Ux = 0;
274
            for (int i = 0; i < size; i++)
275
            {
276
                signed char xi = x[i];
277

278
                Rx += weight_xc_int8_R[i] * xi;
279
                Ux += weight_xc_int8_U[i] * xi;
280
            }
281

282
            int Rh = 0;
283
            int Uh = 0;
284
            for (int i = 0; i < num_output; i++)
285
            {
286
                signed char h_cont = hs[i];
287

288
                Rh += weight_hc_int8_R[i] * h_cont;
289
                Uh += weight_hc_int8_U[i] * h_cont;
290
            }
291

292
            float R = bias_c_R[q] + Rx * (descale_x * descale_xc_R) + Rh * (descale_h * descale_hc_R);
293
            float U = bias_c_U[q] + Ux * (descale_x * descale_xc_U) + Uh * (descale_h * descale_hc_U);
294

295
            // sigmoid(R)
296
            // sigmoid(U)
297
            R = 1.f / (1.f + expf(-R));
298
            U = 1.f / (1.f + expf(-U));
299

300
            // gate new
301
            const float* bias_c_WN = bias_c.row(2);
302
            const float* bias_c_BN = bias_c.row(3);
303

304
            const signed char* weight_xc_int8_N = weight_xc_int8.row<const signed char>(num_output * 2 + q);
305
            const signed char* weight_hc_int8_N = weight_hc_int8.row<const signed char>(num_output * 2 + q);
306

307
            const float descale_xc_N = 1.f / weight_xc_int8_scales[num_output * 2 + q];
308
            const float descale_hc_N = 1.f / weight_hc_int8_scales[num_output * 2 + q];
309

310
            int Nh = 0;
311
            for (int i = 0; i < num_output; i++)
312
            {
313
                signed char h_cont = hs[i];
314

315
                Nh += weight_hc_int8_N[i] * h_cont;
316
            }
317

318
            int Nx = 0;
319
            for (int i = 0; i < size; i++)
320
            {
321
                signed char xi = x[i];
322

323
                Nx += weight_xc_int8_N[i] * xi;
324
            }
325

326
            float N = bias_c_BN[q] + Nh * (descale_h * descale_hc_N);
327
            N = bias_c_WN[q] + R * N + Nx * (descale_x * descale_xc_N);
328

329
            // tanh(N)
330
            N = tanhf(N);
331

332
            gates_data[0] = U;
333
            gates_data[1] = N;
334
        }
335

336
        // h_t := (1 - update) .* new + update .* h_{t-1}
337
        float* output_data = top_blob.row(ti);
338
        #pragma omp parallel for num_threads(opt.num_threads)
339
        for (int q = 0; q < num_output; q++)
340
        {
341
            const float* gates_data = gates.row(q);
342

343
            float U = gates_data[0];
344
            float N = gates_data[1];
345

346
            float H = (1 - U) * N + U * hidden_state[q];
347

348
            hidden_state[q] = H;
349
            output_data[q] = H;
350
        }
351
    }
352

353
    return 0;
354
}
355
#endif // NCNN_INT8
356

357
int GRU::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
358
{
359
    int T = bottom_blob.h;
360

361
    int num_directions = direction == 2 ? 2 : 1;
362

363
    // initial hidden state
364
    Mat hidden(num_output, 4u, opt.workspace_allocator);
365
    if (hidden.empty())
366
        return -100;
367
    hidden.fill(0.f);
368

369
    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
370
    if (top_blob.empty())
371
        return -100;
372

373
    // Uni directional
374
    if (direction == 0 || direction == 1)
375
    {
376
#if NCNN_INT8
377
        if (int8_scale_term)
378
        {
379
            int ret = gru_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt);
380
            if (ret != 0)
381
                return ret;
382
        }
383
        else
384
#endif
385
        {
386
            int ret = gru(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
387
            if (ret != 0)
388
                return ret;
389
        }
390
    }
391

392
    if (direction == 2)
393
    {
394
        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
395
        if (top_blob_forward.empty())
396
            return -100;
397

398
        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
399
        if (top_blob_reverse.empty())
400
            return -100;
401

402
#if NCNN_INT8
403
        if (int8_scale_term)
404
        {
405
            int ret = gru_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt);
406
            if (ret != 0)
407
                return ret;
408
        }
409
        else
410
#endif
411
        {
412
            int ret = gru(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
413
            if (ret != 0)
414
                return ret;
415
        }
416

417
        hidden.fill(0.0f);
418

419
#if NCNN_INT8
420
        if (int8_scale_term)
421
        {
422
            int ret = gru_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), hidden, opt);
423
            if (ret != 0)
424
                return ret;
425
        }
426
        else
427
#endif
428
        {
429
            int ret = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, opt);
430
            if (ret != 0)
431
                return ret;
432
        }
433

434
        // concat w
435
        for (int i = 0; i < T; i++)
436
        {
437
            const float* pf = top_blob_forward.row(i);
438
            const float* pr = top_blob_reverse.row(i);
439
            float* ptr = top_blob.row(i);
440

441
            memcpy(ptr, pf, num_output * sizeof(float));
442
            memcpy(ptr + num_output, pr, num_output * sizeof(float));
443
        }
444
    }
445

446
    return 0;
447
}
448

449
int GRU::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
450
{
451
    const Mat& bottom_blob = bottom_blobs[0];
452
    int T = bottom_blob.h;
453
    int num_directions = direction == 2 ? 2 : 1;
454

455
    Mat hidden;
456
    Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator;
457
    if (bottom_blobs.size() == 2)
458
    {
459
        hidden = bottom_blobs[1].clone(hidden_allocator);
460
    }
461
    else
462
    {
463
        hidden.create(num_output, num_directions, 4u, hidden_allocator);
464
        if (hidden.empty())
465
            return -100;
466
        hidden.fill(0.f);
467
    }
468

469
    Mat& top_blob = top_blobs[0];
470
    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
471
    if (top_blob.empty())
472
        return -100;
473

474
    // Uni directional
475
    if (direction == 0 || direction == 1)
476
    {
477
#if NCNN_INT8
478
        if (int8_scale_term)
479
        {
480
            int ret = gru_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt);
481
            if (ret != 0)
482
                return ret;
483
        }
484
        else
485
#endif
486
        {
487
            int ret = gru(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
488
            if (ret != 0)
489
                return ret;
490
        }
491
    }
492

493
    if (direction == 2)
494
    {
495
        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
496
        if (top_blob_forward.empty())
497
            return -100;
498

499
        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
500
        if (top_blob_reverse.empty())
501
            return -100;
502

503
        Mat hidden0 = hidden.row_range(0, 1);
504
#if NCNN_INT8
505
        if (int8_scale_term)
506
        {
507
            int ret = gru_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden0, opt);
508
            if (ret != 0)
509
                return ret;
510
        }
511
        else
512
#endif
513
        {
514
            int ret = gru(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt);
515
            if (ret != 0)
516
                return ret;
517
        }
518

519
        Mat hidden1 = hidden.row_range(1, 1);
520
#if NCNN_INT8
521
        if (int8_scale_term)
522
        {
523
            int ret = gru_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), hidden1, opt);
524
            if (ret != 0)
525
                return ret;
526
        }
527
        else
528
#endif
529
        {
530
            int ret = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt);
531
            if (ret != 0)
532
                return ret;
533
        }
534

535
        // concat w
536
        for (int i = 0; i < T; i++)
537
        {
538
            const float* pf = top_blob_forward.row(i);
539
            const float* pr = top_blob_reverse.row(i);
540
            float* ptr = top_blob.row(i);
541

542
            memcpy(ptr, pf, num_output * sizeof(float));
543
            memcpy(ptr + num_output, pr, num_output * sizeof(float));
544
        }
545
    }
546

547
    if (top_blobs.size() == 2)
548
    {
549
        top_blobs[1] = hidden;
550
    }
551

552
    return 0;
553
}
554

555
} // namespace ncnn
556

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.