// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
21
one_blob_only = false;
22
support_inplace = false;
25
int GRU::load_param(const ParamDict& pd)
27
num_output = pd.get(0, 0);
28
weight_data_size = pd.get(1, 0);
29
direction = pd.get(2, 0);
30
int8_scale_term = pd.get(8, 0);
35
NCNN_LOGE("please build ncnn with NCNN_INT8 enabled for int8 inference");
43
int GRU::load_model(const ModelBin& mb)
45
int num_directions = direction == 2 ? 2 : 1;
47
int size = weight_data_size / num_directions / num_output / 3;
50
weight_xc_data = mb.load(size, num_output * 3, num_directions, 0);
51
if (weight_xc_data.empty())
54
bias_c_data = mb.load(num_output, 4, num_directions, 0);
55
if (bias_c_data.empty())
58
weight_hc_data = mb.load(num_output, num_output * 3, num_directions, 0);
59
if (weight_hc_data.empty())
65
weight_xc_data_int8_scales = mb.load(num_output * 3, num_directions, 1);
66
weight_hc_data_int8_scales = mb.load(num_output * 3, num_directions, 1);
73
static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
75
int size = bottom_blob.w;
76
int T = bottom_blob.h;
78
int num_output = top_blob.w;
81
Mat gates(2, num_output, 4u, opt.workspace_allocator);
86
for (int t = 0; t < T; t++)
88
int ti = reverse ? T - 1 - t : t;
90
const float* x = bottom_blob.row(ti);
91
#pragma omp parallel for num_threads(opt.num_threads)
92
for (int q = 0; q < num_output; q++)
94
float* gates_data = gates.row(q);
97
const float* bias_c_R = bias_c.row(0);
98
const float* bias_c_U = bias_c.row(1);
100
const float* weight_xc_R = weight_xc.row(num_output * 0 + q);
101
const float* weight_xc_U = weight_xc.row(num_output * 1 + q);
102
const float* weight_hc_R = weight_hc.row(num_output * 0 + q);
103
const float* weight_hc_U = weight_hc.row(num_output * 1 + q);
105
float R = bias_c_R[q];
106
float U = bias_c_U[q];
108
for (int i = 0; i < size; i++)
112
R += weight_xc_R[i] * xi;
113
U += weight_xc_U[i] * xi;
116
for (int i = 0; i < num_output; i++)
118
float h_cont = hidden_state[i];
120
R += weight_hc_R[i] * h_cont;
121
U += weight_hc_U[i] * h_cont;
126
R = 1.f / (1.f + expf(-R));
127
U = 1.f / (1.f + expf(-U));
130
const float* bias_c_WN = bias_c.row(2);
131
const float* bias_c_BN = bias_c.row(3);
133
const float* weight_xc_N = weight_xc.row(num_output * 2 + q);
134
const float* weight_hc_N = weight_hc.row(num_output * 2 + q);
136
float N = bias_c_BN[q];
138
for (int i = 0; i < num_output; i++)
140
float h_cont = hidden_state[i];
142
N += weight_hc_N[i] * h_cont;
145
N = bias_c_WN[q] + R * N;
147
for (int i = 0; i < size; i++)
151
N += weight_xc_N[i] * xi;
161
// h_t := (1 - update) .* new + update .* h_{t-1}
162
float* output_data = top_blob.row(ti);
163
#pragma omp parallel for num_threads(opt.num_threads)
164
for (int q = 0; q < num_output; q++)
166
const float* gates_data = gates.row(q);
168
float U = gates_data[0];
169
float N = gates_data[1];
171
float H = (1 - U) * N + U * hidden_state[q];
182
static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const float* weight_xc_int8_scales, const Mat& bias_c, const Mat& weight_hc_int8, const float* weight_hc_int8_scales, Mat& hidden_state, const Option& opt)
184
int size = bottom_blob.w;
185
int T = bottom_blob.h;
187
int num_output = top_blob.w;
190
Mat gates(2, num_output, 4u, opt.workspace_allocator);
194
// dynamic quantize bottom_blob
195
Mat bottom_blob_int8(size, T, (size_t)1u, 1, opt.workspace_allocator);
196
Mat bottom_blob_int8_scales(T, (size_t)4u, 1, opt.workspace_allocator);
198
for (int t = 0; t < T; t++)
200
const float* x = bottom_blob.row(t);
203
for (int i = 0; i < size; i++)
205
absmax = std::max(absmax, (float)fabs(x[i]));
208
bottom_blob_int8_scales[t] = 127.f / absmax;
211
Option opt_quant = opt;
212
opt_quant.blob_allocator = opt.workspace_allocator;
213
opt_quant.use_packing_layout = false;
214
quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_quant);
217
Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator);
218
Mat hidden_state_int8_scales(1, (size_t)4u, 1, opt.workspace_allocator);
221
for (int t = 0; t < T; t++)
223
int ti = reverse ? T - 1 - t : t;
225
// dynamic quantize hidden_state
228
for (int i = 0; i < num_output; i++)
230
absmax = std::max(absmax, (float)fabs(hidden_state[i]));
235
hidden_state_int8_scales[0] = 1.f;
236
hidden_state_int8.fill<signed char>(0);
240
hidden_state_int8_scales[0] = 127.f / absmax;
242
Option opt_quant = opt;
243
opt_quant.blob_allocator = opt.workspace_allocator;
244
opt_quant.use_packing_layout = false;
245
quantize_to_int8(hidden_state, hidden_state_int8, hidden_state_int8_scales, opt_quant);
249
const signed char* x = bottom_blob_int8.row<const signed char>(ti);
250
const signed char* hs = hidden_state_int8;
251
const float descale_x = 1.f / bottom_blob_int8_scales[ti];
252
const float descale_h = 1.f / hidden_state_int8_scales[0];
253
#pragma omp parallel for num_threads(opt.num_threads)
254
for (int q = 0; q < num_output; q++)
256
float* gates_data = gates.row(q);
259
const float* bias_c_R = bias_c.row(0);
260
const float* bias_c_U = bias_c.row(1);
262
const signed char* weight_xc_int8_R = weight_xc_int8.row<const signed char>(num_output * 0 + q);
263
const signed char* weight_xc_int8_U = weight_xc_int8.row<const signed char>(num_output * 1 + q);
264
const signed char* weight_hc_int8_R = weight_hc_int8.row<const signed char>(num_output * 0 + q);
265
const signed char* weight_hc_int8_U = weight_hc_int8.row<const signed char>(num_output * 1 + q);
267
const float descale_xc_R = 1.f / weight_xc_int8_scales[num_output * 0 + q];
268
const float descale_xc_U = 1.f / weight_xc_int8_scales[num_output * 1 + q];
269
const float descale_hc_R = 1.f / weight_hc_int8_scales[num_output * 0 + q];
270
const float descale_hc_U = 1.f / weight_hc_int8_scales[num_output * 1 + q];
274
for (int i = 0; i < size; i++)
276
signed char xi = x[i];
278
Rx += weight_xc_int8_R[i] * xi;
279
Ux += weight_xc_int8_U[i] * xi;
284
for (int i = 0; i < num_output; i++)
286
signed char h_cont = hs[i];
288
Rh += weight_hc_int8_R[i] * h_cont;
289
Uh += weight_hc_int8_U[i] * h_cont;
292
float R = bias_c_R[q] + Rx * (descale_x * descale_xc_R) + Rh * (descale_h * descale_hc_R);
293
float U = bias_c_U[q] + Ux * (descale_x * descale_xc_U) + Uh * (descale_h * descale_hc_U);
297
R = 1.f / (1.f + expf(-R));
298
U = 1.f / (1.f + expf(-U));
301
const float* bias_c_WN = bias_c.row(2);
302
const float* bias_c_BN = bias_c.row(3);
304
const signed char* weight_xc_int8_N = weight_xc_int8.row<const signed char>(num_output * 2 + q);
305
const signed char* weight_hc_int8_N = weight_hc_int8.row<const signed char>(num_output * 2 + q);
307
const float descale_xc_N = 1.f / weight_xc_int8_scales[num_output * 2 + q];
308
const float descale_hc_N = 1.f / weight_hc_int8_scales[num_output * 2 + q];
311
for (int i = 0; i < num_output; i++)
313
signed char h_cont = hs[i];
315
Nh += weight_hc_int8_N[i] * h_cont;
319
for (int i = 0; i < size; i++)
321
signed char xi = x[i];
323
Nx += weight_xc_int8_N[i] * xi;
326
float N = bias_c_BN[q] + Nh * (descale_h * descale_hc_N);
327
N = bias_c_WN[q] + R * N + Nx * (descale_x * descale_xc_N);
336
// h_t := (1 - update) .* new + update .* h_{t-1}
337
float* output_data = top_blob.row(ti);
338
#pragma omp parallel for num_threads(opt.num_threads)
339
for (int q = 0; q < num_output; q++)
341
const float* gates_data = gates.row(q);
343
float U = gates_data[0];
344
float N = gates_data[1];
346
float H = (1 - U) * N + U * hidden_state[q];
357
int GRU::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
359
int T = bottom_blob.h;
361
int num_directions = direction == 2 ? 2 : 1;
363
// initial hidden state
364
Mat hidden(num_output, 4u, opt.workspace_allocator);
369
top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
370
if (top_blob.empty())
374
if (direction == 0 || direction == 1)
379
int ret = gru_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt);
386
int ret = gru(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
394
Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
395
if (top_blob_forward.empty())
398
Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
399
if (top_blob_reverse.empty())
405
int ret = gru_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt);
412
int ret = gru(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
422
int ret = gru_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), hidden, opt);
429
int ret = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, opt);
435
for (int i = 0; i < T; i++)
437
const float* pf = top_blob_forward.row(i);
438
const float* pr = top_blob_reverse.row(i);
439
float* ptr = top_blob.row(i);
441
memcpy(ptr, pf, num_output * sizeof(float));
442
memcpy(ptr + num_output, pr, num_output * sizeof(float));
449
int GRU::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
451
const Mat& bottom_blob = bottom_blobs[0];
452
int T = bottom_blob.h;
453
int num_directions = direction == 2 ? 2 : 1;
456
Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator;
457
if (bottom_blobs.size() == 2)
459
hidden = bottom_blobs[1].clone(hidden_allocator);
463
hidden.create(num_output, num_directions, 4u, hidden_allocator);
469
Mat& top_blob = top_blobs[0];
470
top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
471
if (top_blob.empty())
475
if (direction == 0 || direction == 1)
480
int ret = gru_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt);
487
int ret = gru(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
495
Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
496
if (top_blob_forward.empty())
499
Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
500
if (top_blob_reverse.empty())
503
Mat hidden0 = hidden.row_range(0, 1);
507
int ret = gru_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden0, opt);
514
int ret = gru(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt);
519
Mat hidden1 = hidden.row_range(1, 1);
523
int ret = gru_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), hidden1, opt);
530
int ret = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt);
536
for (int i = 0; i < T; i++)
538
const float* pf = top_blob_forward.row(i);
539
const float* pr = top_blob_reverse.row(i);
540
float* ptr = top_blob.row(i);
542
memcpy(ptr, pf, num_output * sizeof(float));
543
memcpy(ptr + num_output, pr, num_output * sizeof(float));
547
if (top_blobs.size() == 2)
549
top_blobs[1] = hidden;