1
// Tencent is pleased to support the open source community by making ncnn available.
3
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
8
// https://opensource.org/licenses/BSD-3-Clause
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
17
#include "datareader.h"
38
// 4 = array of int/float
48
} params[NCNN_MAX_PARAM_COUNT];
52
: d(new ParamDictPrivate)
57
ParamDict::~ParamDict()
62
ParamDict::ParamDict(const ParamDict& rhs)
63
: d(new ParamDictPrivate)
65
for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
67
int type = rhs.d->params[i].type;
68
d->params[i].type = type;
69
if (type == 1 || type == 2 || type == 3)
71
d->params[i].i = rhs.d->params[i].i;
73
else // if (type == 4 || type == 5 || type == 6)
75
d->params[i].v = rhs.d->params[i].v;
80
ParamDict& ParamDict::operator=(const ParamDict& rhs)
85
for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
87
int type = rhs.d->params[i].type;
88
d->params[i].type = type;
89
if (type == 1 || type == 2 || type == 3)
91
d->params[i].i = rhs.d->params[i].i;
93
else // if (type == 4 || type == 5 || type == 6)
95
d->params[i].v = rhs.d->params[i].v;
102
int ParamDict::type(int id) const
104
return d->params[id].type;
107
// TODO strict type check
108
int ParamDict::get(int id, int def) const
110
return d->params[id].type ? d->params[id].i : def;
113
float ParamDict::get(int id, float def) const
115
return d->params[id].type ? d->params[id].f : def;
118
Mat ParamDict::get(int id, const Mat& def) const
120
return d->params[id].type ? d->params[id].v : def;
123
void ParamDict::set(int id, int i)
125
d->params[id].type = 2;
129
void ParamDict::set(int id, float f)
131
d->params[id].type = 3;
135
void ParamDict::set(int id, const Mat& v)
137
d->params[id].type = 4;
141
void ParamDict::clear()
143
for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
145
d->params[i].type = 0;
146
d->params[i].v = Mat();
151
static bool vstr_is_float(const char vstr[16])
153
// look ahead for determine isfloat
154
for (int j = 0; j < 16; j++)
159
if (vstr[j] == '.' || tolower(vstr[j]) == 'e')
166
static float vstr_to_float(const char vstr[16])
170
const char* p = vstr;
173
bool sign = *p != '-';
174
if (*p == '+' || *p == '-')
179
// digits before decimal point or exponent
183
v1 = v1 * 10 + (*p - '0');
189
// digits after decimal point
194
unsigned int pow10 = 1;
199
v2 = v2 * 10 + (*p - '0');
204
v += v2 / (double)pow10;
208
if (*p == 'e' || *p == 'E')
213
bool fact = *p != '-';
214
if (*p == '+' || *p == '-')
219
// digits of exponent
220
unsigned int expon = 0;
223
expon = expon * 10 + (*p - '0');
239
v = fact ? v * scale : v / scale;
242
// fprintf(stderr, "v = %f\n", v);
243
return sign ? (float)v : (float)-v;
246
int ParamDict::load_param(const DataReader& dr)
250
// 0=100 1=1.250000 -23303=5,0.1,0.2,0.4,0.8,1.0
252
// parse each key=value pair
254
while (dr.scan("%d=", &id) == 1)
256
bool is_array = id <= -23300;
262
if (id >= NCNN_MAX_PARAM_COUNT)
264
NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
271
int nscan = dr.scan("%d", &len);
274
NCNN_LOGE("ParamDict read array length failed");
278
d->params[id].v.create(len);
280
for (int j = 0; j < len; j++)
283
nscan = dr.scan(",%15[^,\n ]", vstr);
286
NCNN_LOGE("ParamDict read array element failed");
290
bool is_float = vstr_is_float(vstr);
294
float* ptr = d->params[id].v;
295
ptr[j] = vstr_to_float(vstr);
299
int* ptr = d->params[id].v;
300
nscan = sscanf(vstr, "%d", &ptr[j]);
303
NCNN_LOGE("ParamDict parse array element failed");
308
d->params[id].type = is_float ? 6 : 5;
314
int nscan = dr.scan("%15s", vstr);
317
NCNN_LOGE("ParamDict read value failed");
321
bool is_float = vstr_is_float(vstr);
325
d->params[id].f = vstr_to_float(vstr);
329
nscan = sscanf(vstr, "%d", &d->params[id].i);
332
NCNN_LOGE("ParamDict parse value failed");
337
d->params[id].type = is_float ? 3 : 2;
345
int ParamDict::load_param_bin(const DataReader& dr)
353
// binary 3 | array_bit
364
nread = dr.read(&id, sizeof(int));
365
if (nread != sizeof(int))
367
NCNN_LOGE("ParamDict read id failed %zd", nread);
372
swap_endianness_32(&id);
377
bool is_array = id <= -23300;
383
if (id >= NCNN_MAX_PARAM_COUNT)
385
NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
392
nread = dr.read(&len, sizeof(int));
393
if (nread != sizeof(int))
395
NCNN_LOGE("ParamDict read array length failed %zd", nread);
400
swap_endianness_32(&len);
403
d->params[id].v.create(len);
405
float* ptr = d->params[id].v;
406
nread = dr.read(ptr, sizeof(float) * len);
407
if (nread != sizeof(float) * len)
409
NCNN_LOGE("ParamDict read array element failed %zd", nread);
414
for (int i = 0; i < len; i++)
416
swap_endianness_32(ptr + i);
420
d->params[id].type = 4;
424
nread = dr.read(&d->params[id].f, sizeof(float));
425
if (nread != sizeof(float))
427
NCNN_LOGE("ParamDict read value failed %zd", nread);
432
swap_endianness_32(&d->params[id].f);
435
d->params[id].type = 1;
438
nread = dr.read(&id, sizeof(int));
439
if (nread != sizeof(int))
441
NCNN_LOGE("ParamDict read EOP failed %zd", nread);
446
swap_endianness_32(&id);