ncnn

embed.cpp 
160 lines · 4.1 KB
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "embed.h"

#include <string.h>

namespace ncnn {
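
// Embed turns a blob of integer word indices into dense float vectors by table
// lookup: one input blob in, one freshly allocated output blob out (no in-place
// path, since the output shape differs from the input).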
Embed::Embed()
{
    one_blob_only = true;
    support_inplace = false;
}
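
// Param ids follow the ncnn .param convention:
//   0 = num_output (embedding vector length), 1 = input_dim (dictionary size),
//   2 = bias_term, 3 = weight_data_size, 18 = int8_scale_term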
int Embed::load_param(const ParamDict& pd)
{
    num_output = pd.get(0, 0);
    input_dim = pd.get(1, 0);
    bias_term = pd.get(2, 0);
    weight_data_size = pd.get(3, 0);
    int8_scale_term = pd.get(18, 0);

    return 0;
}
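
// weight_data holds the embedding table, one num_output-wide row per dictionary
// entry; an optional num_output bias vector and, for quantized models, a single
// weight scale are loaded after it.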
int Embed::load_model(const ModelBin& mb)
{
    weight_data = mb.load(weight_data_size, 0);
    if (weight_data.empty())
        return -100;

    if (bias_term)
    {
        bias_data = mb.load(num_output, 1);
        if (bias_data.empty())
            return -100;
    }

#if NCNN_INT8
    if (int8_scale_term)
    {
        weight_data_int8_scale = mb.load(1, 1)[0];
    }
#endif // NCNN_INT8

    return 0;
}
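
// fp32 path: for each word, clamp the index into [0, input_dim - 1], then copy
// the matching row of the table into the output, adding the bias when present.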
static void embed(const Mat& bottom_blob, const Mat& weight_data, const Mat& bias_data, Mat& top_blob, int input_dim, const Option& opt)
{
    const int num_output = top_blob.w;
    const int words = top_blob.h;

    const float* bias_ptr = bias_data;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < words; q++)
    {
        float* outptr = top_blob.row(q);

        int word_index = ((const int*)bottom_blob)[q];

        if (word_index < 0)
            word_index = 0;
        if (word_index >= input_dim)
            word_index = input_dim - 1;

        const float* em = (const float*)weight_data + num_output * word_index;

        if (bias_ptr)
        {
            for (int p = 0; p < num_output; p++)
            {
                outptr[p] = em[p] + bias_ptr[p];
            }
        }
        else
        {
            memcpy(outptr, em, num_output * sizeof(float));
        }
    }
}
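
// int8 path: table rows are stored as signed 8-bit values and dequantized on
// the fly by multiplying with 1 / weight_data_int8_scale.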
#if NCNN_INT8
static void embed_int8(const Mat& bottom_blob, const Mat& weight_data, float weight_data_int8_scale, const Mat& bias_data, Mat& top_blob, int input_dim, const Option& opt)
{
    const int num_output = top_blob.w;
    const int words = top_blob.h;

    const float* bias_ptr = bias_data;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < words; q++)
    {
        float* outptr = top_blob.row(q);

        int word_index = ((const int*)bottom_blob)[q];

        if (word_index < 0)
            word_index = 0;
        if (word_index >= input_dim)
            word_index = input_dim - 1;

        const float descale_em = 1.f / weight_data_int8_scale;

        const signed char* em = (const signed char*)weight_data + num_output * word_index;

        if (bias_ptr)
        {
            for (int p = 0; p < num_output; p++)
            {
                outptr[p] = em[p] * descale_em + bias_ptr[p];
            }
        }
        else
        {
            for (int p = 0; p < num_output; p++)
            {
                outptr[p] = em[p] * descale_em;
            }
        }
    }
}
#endif // NCNN_INT8
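
// forward reads every element of bottom_blob as a 32-bit word index and fills
// top_blob with num_output columns and words rows of fp32 values.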
int Embed::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int words = static_cast<int>(bottom_blob.total());

    top_blob.create(num_output, words, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

#if NCNN_INT8
    if (int8_scale_term)
    {
        embed_int8(bottom_blob, weight_data, weight_data_int8_scale, bias_data, top_blob, input_dim, opt);
    }
    else
#endif // NCNN_INT8
    {
        embed(bottom_blob, weight_data, bias_data, top_blob, input_dim, opt);
    }

    return 0;
}

} // namespace ncnn
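
For reference, the lookup above is just a clamped row gather. Below is a minimal standalone sketch of the same semantics in plain C++, with no ncnn dependency; the table contents and index values are made up purely for illustration.

#include <cstdio>
#include <cstring>
#include <vector>

int main()
{
    // Hypothetical dictionary of 4 words with 3-dim embeddings
    // (input_dim = 4, num_output = 3 in the layer's terms).
    const int input_dim = 4;
    const int num_output = 3;
    const float table[4 * 3] = {
        0.0f, 0.1f, 0.2f, // word 0
        1.0f, 1.1f, 1.2f, // word 1
        2.0f, 2.1f, 2.2f, // word 2
        3.0f, 3.1f, 3.2f  // word 3
    };

    const int indices[] = {2, 0, 7}; // 7 is out of range and gets clamped to 3
    const int words = 3;

    std::vector<float> out(words * num_output);
    for (int q = 0; q < words; q++)
    {
        int word_index = indices[q];
        if (word_index < 0)
            word_index = 0;
        if (word_index >= input_dim)
            word_index = input_dim - 1;

        // Same core operation as embed() above, minus OpenMP and bias.
        std::memcpy(&out[q * num_output], table + num_output * word_index, num_output * sizeof(float));
    }

    for (int q = 0; q < words; q++)
        std::printf("%.1f %.1f %.1f\n", out[q * 3 + 0], out[q * 3 + 1], out[q * 3 + 2]);
    return 0;
}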