// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
24
support_inplace = false;
27
int Embed::load_param(const ParamDict& pd)
29
num_output = pd.get(0, 0);
30
input_dim = pd.get(1, 0);
31
bias_term = pd.get(2, 0);
32
weight_data_size = pd.get(3, 0);
33
int8_scale_term = pd.get(18, 0);
38
int Embed::load_model(const ModelBin& mb)
40
weight_data = mb.load(weight_data_size, 0);
41
if (weight_data.empty())
46
bias_data = mb.load(num_output, 1);
47
if (bias_data.empty())
54
weight_data_int8_scale = mb.load(1, 1)[0];
61
static void embed(const Mat& bottom_blob, const Mat& weight_data, const Mat& bias_data, Mat& top_blob, int input_dim, const Option& opt)
63
const int num_output = top_blob.w;
64
const int words = top_blob.h;
66
const float* bias_ptr = bias_data;
68
#pragma omp parallel for num_threads(opt.num_threads)
69
for (int q = 0; q < words; q++)
71
float* outptr = top_blob.row(q);
73
int word_index = ((const int*)bottom_blob)[q];
77
if (word_index >= input_dim)
78
word_index = input_dim - 1;
80
const float* em = (const float*)weight_data + num_output * word_index;
84
for (int p = 0; p < num_output; p++)
86
outptr[p] = em[p] + bias_ptr[p];
91
memcpy(outptr, em, num_output * sizeof(float));
97
static void embed_int8(const Mat& bottom_blob, const Mat& weight_data, float weight_data_int8_scale, const Mat& bias_data, Mat& top_blob, int input_dim, const Option& opt)
99
const int num_output = top_blob.w;
100
const int words = top_blob.h;
102
const float* bias_ptr = bias_data;
104
#pragma omp parallel for num_threads(opt.num_threads)
105
for (int q = 0; q < words; q++)
107
float* outptr = top_blob.row(q);
109
int word_index = ((const int*)bottom_blob)[q];
113
if (word_index >= input_dim)
114
word_index = input_dim - 1;
116
const float descale_em = 1.f / weight_data_int8_scale;
118
const signed char* em = (const signed char*)weight_data + num_output * word_index;
122
for (int p = 0; p < num_output; p++)
124
outptr[p] = em[p] * descale_em + bias_ptr[p];
129
for (int p = 0; p < num_output; p++)
131
outptr[p] = em[p] * descale_em;
138
int Embed::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
140
int words = static_cast<int>(bottom_blob.total());
142
top_blob.create(num_output, words, 4u, opt.blob_allocator);
143
if (top_blob.empty())
149
embed_int8(bottom_blob, weight_data, weight_data_int8_scale, bias_data, top_blob, input_dim, opt);
154
embed(bottom_blob, weight_data, bias_data, top_blob, input_dim, opt);