ncnn

Форк
0
/
paramdict.cpp 
453 строки · 9.7 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "paramdict.h"
16

17
#include "datareader.h"
18
#include "mat.h"
19
#include "platform.h"
20

21
#include <ctype.h>
22

23
#if NCNN_STDIO
24
#include <stdio.h>
25
#endif
26

27
namespace ncnn {
28

29
class ParamDictPrivate
30
{
31
public:
32
    struct
33
    {
34
        // 0 = null
35
        // 1 = int/float
36
        // 2 = int
37
        // 3 = float
38
        // 4 = array of int/float
39
        // 5 = array of int
40
        // 6 = array of float
41
        int type;
42
        union
43
        {
44
            int i;
45
            float f;
46
        };
47
        Mat v;
48
    } params[NCNN_MAX_PARAM_COUNT];
49
};
50

51
// Creates an empty parameter dictionary (every slot has type 0 = null).
ParamDict::ParamDict()
    : d(new ParamDictPrivate())
{
    // The trailing "()" value-initializes ParamDictPrivate, which has no
    // user-provided constructor, so the scalar union of every slot is
    // zero-filled. Reading .i/.f from a slot that never received a scalar
    // then yields 0 instead of an indeterminate value.
    clear();
}
56

57
// Releases the private parameter storage allocated with new by the constructors.
ParamDict::~ParamDict()
{
    delete d;
}
61

62
// Copies every parameter slot of rhs. Scalar slots copy their 32-bit payload;
// null and array slots copy the Mat, which shares rhs's array data.
ParamDict::ParamDict(const ParamDict& rhs)
    : d(new ParamDictPrivate())
{
    // Value-initialization (the "()") zero-fills the scalar union, so slots
    // whose payload is not copied below do not hold indeterminate bits.
    for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
    {
        int type = rhs.d->params[i].type;
        d->params[i].type = type;
        if (type == 1 || type == 2 || type == 3)
        {
            d->params[i].i = rhs.d->params[i].i;
        }
        else // if (type == 0 || type == 4 || type == 5 || type == 6)
        {
            d->params[i].v = rhs.d->params[i].v;
        }
    }
}
79

80
// Replaces every parameter slot with the corresponding slot of rhs.
// Scalar slots copy their 32-bit payload; null and array slots copy the Mat,
// which shares rhs's array data.
ParamDict& ParamDict::operator=(const ParamDict& rhs)
{
    if (this == &rhs)
        return *this;

    for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
    {
        int type = rhs.d->params[i].type;
        d->params[i].type = type;
        if (type == 1 || type == 2 || type == 3)
        {
            d->params[i].i = rhs.d->params[i].i;

            // The slot now holds a scalar: drop any array it held before the
            // assignment so that Mat is released instead of lingering (and
            // cannot be handed out by get(id, Mat)).
            d->params[i].v = Mat();
        }
        else // if (type == 0 || type == 4 || type == 5 || type == 6)
        {
            d->params[i].v = rhs.d->params[i].v;
        }
    }

    return *this;
}
101

102
int ParamDict::type(int id) const
103
{
104
    return d->params[id].type;
105
}
106

107
// TODO strict type check
108
int ParamDict::get(int id, int def) const
109
{
110
    return d->params[id].type ? d->params[id].i : def;
111
}
112

113
float ParamDict::get(int id, float def) const
114
{
115
    return d->params[id].type ? d->params[id].f : def;
116
}
117

118
// Returns the array stored for id, or def when the slot is unset or id is
// outside [0, NCNN_MAX_PARAM_COUNT).
Mat ParamDict::get(int id, const Mat& def) const
{
    // an out-of-range id can never have been set; avoid indexing past params[]
    if (id < 0 || id >= NCNN_MAX_PARAM_COUNT)
        return def;

    return d->params[id].type ? d->params[id].v : def;
}
122

123
void ParamDict::set(int id, int i)
124
{
125
    d->params[id].type = 2;
126
    d->params[id].i = i;
127
}
128

129
void ParamDict::set(int id, float f)
130
{
131
    d->params[id].type = 3;
132
    d->params[id].f = f;
133
}
134

135
void ParamDict::set(int id, const Mat& v)
136
{
137
    d->params[id].type = 4;
138
    d->params[id].v = v;
139
}
140

141
void ParamDict::clear()
142
{
143
    for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
144
    {
145
        d->params[i].type = 0;
146
        d->params[i].v = Mat();
147
    }
148
}
149

150
#if NCNN_STRING
151
// Returns true when the token looks like a floating-point literal, i.e. it
// contains a decimal point or an exponent marker ('e' / 'E') before its
// terminating NUL. At most the first 16 characters are inspected.
static bool vstr_is_float(const char vstr[16])
{
    for (int j = 0; j < 16; j++)
    {
        // ctype functions require a value representable as unsigned char;
        // passing a negative plain char (bytes >= 0x80) is undefined behavior
        const unsigned char c = (unsigned char)vstr[j];
        if (c == '\0')
            break;

        if (c == '.' || tolower(c) == 'e')
            return true;
    }

    return false;
}
165

166
// Parses a decimal floating-point token of the form
//   [+|-] digits [. digits] [(e|E) [+|-] digits]
// and returns it as float. Parsing stops at the first character that does not
// fit this grammar; no error is reported.
static float vstr_to_float(const char vstr[16])
{
    const char* p = vstr;

    // sign
    const bool negative = *p == '-';
    if (*p == '+' || *p == '-')
    {
        p++;
    }

    // digits before decimal point or exponent; accumulating in double keeps
    // long digit runs from overflowing (a 32-bit unsigned wraps above ~4.29e9)
    double v = 0.0;
    while (isdigit((unsigned char)*p))
    {
        v = v * 10.0 + (*p - '0');
        p++;
    }

    // digits after decimal point; double accumulators stay exact for the
    // <= 15 digits a 16-byte token can hold, whereas a 32-bit pow10 wraps
    // after 10 fractional digits
    if (*p == '.')
    {
        p++;

        double pow10 = 1.0;
        double frac = 0.0;
        while (isdigit((unsigned char)*p))
        {
            frac = frac * 10.0 + (*p - '0');
            pow10 *= 10.0;
            p++;
        }

        v += frac / pow10;
    }

    // exponent
    if (*p == 'e' || *p == 'E')
    {
        p++;

        // sign of exponent
        const bool positive_exp = *p != '-';
        if (*p == '+' || *p == '-')
        {
            p++;
        }

        // digits of exponent; stop growing once the magnitude is already far
        // beyond what a double can represent so the value cannot overflow and
        // the scaling loops below stay short
        unsigned int expon = 0;
        while (isdigit((unsigned char)*p))
        {
            if (expon < 1000)
                expon = expon * 10 + (*p - '0');
            p++;
        }

        double scale = 1.0;
        while (expon >= 8)
        {
            scale *= 1e8;
            expon -= 8;
        }
        while (expon > 0)
        {
            scale *= 10.0;
            expon -= 1;
        }

        // skip zero so a saturated (infinite) scale cannot turn 0 into NaN
        if (v != 0.0)
        {
            v = positive_exp ? v * scale : v / scale;
        }
    }

    return negative ? (float)-v : (float)v;
}
245

246
int ParamDict::load_param(const DataReader& dr)
247
{
248
    clear();
249

250
    //     0=100 1=1.250000 -23303=5,0.1,0.2,0.4,0.8,1.0
251

252
    // parse each key=value pair
253
    int id = 0;
254
    while (dr.scan("%d=", &id) == 1)
255
    {
256
        bool is_array = id <= -23300;
257
        if (is_array)
258
        {
259
            id = -id - 23300;
260
        }
261

262
        if (id >= NCNN_MAX_PARAM_COUNT)
263
        {
264
            NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
265
            return -1;
266
        }
267

268
        if (is_array)
269
        {
270
            int len = 0;
271
            int nscan = dr.scan("%d", &len);
272
            if (nscan != 1)
273
            {
274
                NCNN_LOGE("ParamDict read array length failed");
275
                return -1;
276
            }
277

278
            d->params[id].v.create(len);
279

280
            for (int j = 0; j < len; j++)
281
            {
282
                char vstr[16];
283
                nscan = dr.scan(",%15[^,\n ]", vstr);
284
                if (nscan != 1)
285
                {
286
                    NCNN_LOGE("ParamDict read array element failed");
287
                    return -1;
288
                }
289

290
                bool is_float = vstr_is_float(vstr);
291

292
                if (is_float)
293
                {
294
                    float* ptr = d->params[id].v;
295
                    ptr[j] = vstr_to_float(vstr);
296
                }
297
                else
298
                {
299
                    int* ptr = d->params[id].v;
300
                    nscan = sscanf(vstr, "%d", &ptr[j]);
301
                    if (nscan != 1)
302
                    {
303
                        NCNN_LOGE("ParamDict parse array element failed");
304
                        return -1;
305
                    }
306
                }
307

308
                d->params[id].type = is_float ? 6 : 5;
309
            }
310
        }
311
        else
312
        {
313
            char vstr[16];
314
            int nscan = dr.scan("%15s", vstr);
315
            if (nscan != 1)
316
            {
317
                NCNN_LOGE("ParamDict read value failed");
318
                return -1;
319
            }
320

321
            bool is_float = vstr_is_float(vstr);
322

323
            if (is_float)
324
            {
325
                d->params[id].f = vstr_to_float(vstr);
326
            }
327
            else
328
            {
329
                nscan = sscanf(vstr, "%d", &d->params[id].i);
330
                if (nscan != 1)
331
                {
332
                    NCNN_LOGE("ParamDict parse value failed");
333
                    return -1;
334
                }
335
            }
336

337
            d->params[id].type = is_float ? 3 : 2;
338
        }
339
    }
340

341
    return 0;
342
}
343
#endif // NCNN_STRING
344

345
int ParamDict::load_param_bin(const DataReader& dr)
346
{
347
    clear();
348

349
    //     binary 0
350
    //     binary 100
351
    //     binary 1
352
    //     binary 1.250000
353
    //     binary 3 | array_bit
354
    //     binary 5
355
    //     binary 0.1
356
    //     binary 0.2
357
    //     binary 0.4
358
    //     binary 0.8
359
    //     binary 1.0
360
    //     binary -233(EOP)
361

362
    int id = 0;
363
    size_t nread;
364
    nread = dr.read(&id, sizeof(int));
365
    if (nread != sizeof(int))
366
    {
367
        NCNN_LOGE("ParamDict read id failed %zd", nread);
368
        return -1;
369
    }
370

371
#if __BIG_ENDIAN__
372
    swap_endianness_32(&id);
373
#endif
374

375
    while (id != -233)
376
    {
377
        bool is_array = id <= -23300;
378
        if (is_array)
379
        {
380
            id = -id - 23300;
381
        }
382

383
        if (id >= NCNN_MAX_PARAM_COUNT)
384
        {
385
            NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
386
            return -1;
387
        }
388

389
        if (is_array)
390
        {
391
            int len = 0;
392
            nread = dr.read(&len, sizeof(int));
393
            if (nread != sizeof(int))
394
            {
395
                NCNN_LOGE("ParamDict read array length failed %zd", nread);
396
                return -1;
397
            }
398

399
#if __BIG_ENDIAN__
400
            swap_endianness_32(&len);
401
#endif
402

403
            d->params[id].v.create(len);
404

405
            float* ptr = d->params[id].v;
406
            nread = dr.read(ptr, sizeof(float) * len);
407
            if (nread != sizeof(float) * len)
408
            {
409
                NCNN_LOGE("ParamDict read array element failed %zd", nread);
410
                return -1;
411
            }
412

413
#if __BIG_ENDIAN__
414
            for (int i = 0; i < len; i++)
415
            {
416
                swap_endianness_32(ptr + i);
417
            }
418
#endif
419

420
            d->params[id].type = 4;
421
        }
422
        else
423
        {
424
            nread = dr.read(&d->params[id].f, sizeof(float));
425
            if (nread != sizeof(float))
426
            {
427
                NCNN_LOGE("ParamDict read value failed %zd", nread);
428
                return -1;
429
            }
430

431
#if __BIG_ENDIAN__
432
            swap_endianness_32(&d->params[id].f);
433
#endif
434

435
            d->params[id].type = 1;
436
        }
437

438
        nread = dr.read(&id, sizeof(int));
439
        if (nread != sizeof(int))
440
        {
441
            NCNN_LOGE("ParamDict read EOP failed %zd", nread);
442
            return -1;
443
        }
444

445
#if __BIG_ENDIAN__
446
        swap_endianness_32(&id);
447
#endif
448
    }
449

450
    return 0;
451
}
452

453
} // namespace ncnn
454

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.