You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

paramdict.cpp 9.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "paramdict.h"
  15. #include "datareader.h"
  16. #include "mat.h"
  17. #include "platform.h"
  18. #include <ctype.h>
  19. #if NCNN_STDIO
  20. #include <stdio.h>
  21. #endif
  22. namespace ncnn {
  23. class ParamDictPrivate
  24. {
  25. public:
  26. struct
  27. {
  28. // 0 = null
  29. // 1 = int/float
  30. // 2 = int
  31. // 3 = float
  32. // 4 = array of int/float
  33. // 5 = array of int
  34. // 6 = array of float
  35. int type;
  36. union
  37. {
  38. int i;
  39. float f;
  40. };
  41. Mat v;
  42. } params[NCNN_MAX_PARAM_COUNT];
  43. };
  44. ParamDict::ParamDict()
  45. : d(new ParamDictPrivate)
  46. {
  47. clear();
  48. }
  49. ParamDict::~ParamDict()
  50. {
  51. delete d;
  52. }
  53. ParamDict::ParamDict(const ParamDict& rhs)
  54. : d(new ParamDictPrivate)
  55. {
  56. for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
  57. {
  58. int type = rhs.d->params[i].type;
  59. d->params[i].type = type;
  60. if (type == 1 || type == 2 || type == 3)
  61. {
  62. d->params[i].i = rhs.d->params[i].i;
  63. }
  64. else // if (type == 4 || type == 5 || type == 6)
  65. {
  66. d->params[i].v = rhs.d->params[i].v;
  67. }
  68. }
  69. }
  70. ParamDict& ParamDict::operator=(const ParamDict& rhs)
  71. {
  72. if (this == &rhs)
  73. return *this;
  74. for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
  75. {
  76. int type = rhs.d->params[i].type;
  77. d->params[i].type = type;
  78. if (type == 1 || type == 2 || type == 3)
  79. {
  80. d->params[i].i = rhs.d->params[i].i;
  81. }
  82. else // if (type == 4 || type == 5 || type == 6)
  83. {
  84. d->params[i].v = rhs.d->params[i].v;
  85. }
  86. }
  87. return *this;
  88. }
  89. int ParamDict::type(int id) const
  90. {
  91. return d->params[id].type;
  92. }
  93. // TODO strict type check
  94. int ParamDict::get(int id, int def) const
  95. {
  96. return d->params[id].type ? d->params[id].i : def;
  97. }
  98. float ParamDict::get(int id, float def) const
  99. {
  100. return d->params[id].type ? d->params[id].f : def;
  101. }
  102. Mat ParamDict::get(int id, const Mat& def) const
  103. {
  104. return d->params[id].type ? d->params[id].v : def;
  105. }
  106. void ParamDict::set(int id, int i)
  107. {
  108. d->params[id].type = 2;
  109. d->params[id].i = i;
  110. }
  111. void ParamDict::set(int id, float f)
  112. {
  113. d->params[id].type = 3;
  114. d->params[id].f = f;
  115. }
  116. void ParamDict::set(int id, const Mat& v)
  117. {
  118. d->params[id].type = 4;
  119. d->params[id].v = v;
  120. }
  121. void ParamDict::clear()
  122. {
  123. for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
  124. {
  125. d->params[i].type = 0;
  126. d->params[i].v = Mat();
  127. }
  128. }
  129. #if NCNN_STRING
  // Looks ahead in a value token (at most 16 chars, NUL-terminated) to decide whether it should
  // be parsed as a float. It returns true if the token contains a '.' or an exponent marker
  // ('e' / 'E'); plain integers return false.
  static bool vstr_is_float(const char vstr[16])
  {
      for (int j = 0; j < 16; j++)
      {
          // convert to unsigned char first: passing a negative char to tolower() is undefined behavior
          const unsigned char c = (unsigned char)vstr[j];
          if (c == '\0')
              break;

          if (c == '.' || tolower(c) == 'e')
              return true;
      }

      return false;
  }
  // Parses a decimal float token of the form [+-]digits[.digits][(e|E)[+-]digits].
  // This is a locale-independent replacement for strtod on the short tokens in .param files.
  // Digits are accumulated in double, so long integer or fraction runs such as "0.0000000001"
  // cannot overflow a 32-bit accumulator. Parsing stops at the first unexpected character.
  static float vstr_to_float(const char vstr[16])
  {
      const char* p = vstr;

      // optional sign
      const bool negative = *p == '-';
      if (*p == '+' || *p == '-')
      {
          p++;
      }

      // digits before decimal point or exponent
      double v = 0.0;
      while (isdigit((unsigned char)*p))
      {
          v = v * 10.0 + (*p - '0');
          p++;
      }

      // digits after decimal point
      if (*p == '.')
      {
          p++;

          double frac = 0.0;
          double pow10 = 1.0;
          while (isdigit((unsigned char)*p))
          {
              frac = frac * 10.0 + (*p - '0');
              pow10 *= 10.0;
              p++;
          }

          v += frac / pow10;
      }

      // exponent
      if (*p == 'e' || *p == 'E')
      {
          p++;

          // sign of exponent
          const bool exp_negative = *p == '-';
          if (*p == '+' || *p == '-')
          {
              p++;
          }

          // digits of exponent; stop accumulating once the value is far beyond double range,
          // so the counter cannot overflow and the scaling loops below stay short
          unsigned int expon = 0;
          while (isdigit((unsigned char)*p))
          {
              if (expon < 1000)
                  expon = expon * 10 + (*p - '0');
              p++;
          }

          double scale = 1.0;
          while (expon >= 8)
          {
              scale *= 1e8;
              expon -= 8;
          }
          while (expon > 0)
          {
              scale *= 10.0;
              expon -= 1;
          }

          // leave zero untouched: 0 * inf would otherwise turn e.g. "0e999" into NaN
          if (v != 0.0)
          {
              v = exp_negative ? v / scale : v * scale;
          }
      }

      return negative ? (float)-v : (float)v;
  }
  207. int ParamDict::load_param(const DataReader& dr)
  208. {
  209. clear();
  210. // 0=100 1=1.250000 -23303=5,0.1,0.2,0.4,0.8,1.0
  211. // parse each key=value pair
  212. int id = 0;
  213. while (dr.scan("%d=", &id) == 1)
  214. {
  215. bool is_array = id <= -23300;
  216. if (is_array)
  217. {
  218. id = -id - 23300;
  219. }
  220. if (id >= NCNN_MAX_PARAM_COUNT)
  221. {
  222. NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
  223. return -1;
  224. }
  225. if (is_array)
  226. {
  227. int len = 0;
  228. int nscan = dr.scan("%d", &len);
  229. if (nscan != 1)
  230. {
  231. NCNN_LOGE("ParamDict read array length failed");
  232. return -1;
  233. }
  234. d->params[id].v.create(len);
  235. for (int j = 0; j < len; j++)
  236. {
  237. char vstr[16];
  238. nscan = dr.scan(",%15[^,\n ]", vstr);
  239. if (nscan != 1)
  240. {
  241. NCNN_LOGE("ParamDict read array element failed");
  242. return -1;
  243. }
  244. bool is_float = vstr_is_float(vstr);
  245. if (is_float)
  246. {
  247. float* ptr = d->params[id].v;
  248. ptr[j] = vstr_to_float(vstr);
  249. }
  250. else
  251. {
  252. int* ptr = d->params[id].v;
  253. nscan = sscanf(vstr, "%d", &ptr[j]);
  254. if (nscan != 1)
  255. {
  256. NCNN_LOGE("ParamDict parse array element failed");
  257. return -1;
  258. }
  259. }
  260. d->params[id].type = is_float ? 6 : 5;
  261. }
  262. }
  263. else
  264. {
  265. char vstr[16];
  266. int nscan = dr.scan("%15s", vstr);
  267. if (nscan != 1)
  268. {
  269. NCNN_LOGE("ParamDict read value failed");
  270. return -1;
  271. }
  272. bool is_float = vstr_is_float(vstr);
  273. if (is_float)
  274. {
  275. d->params[id].f = vstr_to_float(vstr);
  276. }
  277. else
  278. {
  279. nscan = sscanf(vstr, "%d", &d->params[id].i);
  280. if (nscan != 1)
  281. {
  282. NCNN_LOGE("ParamDict parse value failed");
  283. return -1;
  284. }
  285. }
  286. d->params[id].type = is_float ? 3 : 2;
  287. }
  288. }
  289. return 0;
  290. }
  291. #endif // NCNN_STRING
  292. int ParamDict::load_param_bin(const DataReader& dr)
  293. {
  294. clear();
  295. // binary 0
  296. // binary 100
  297. // binary 1
  298. // binary 1.250000
  299. // binary 3 | array_bit
  300. // binary 5
  301. // binary 0.1
  302. // binary 0.2
  303. // binary 0.4
  304. // binary 0.8
  305. // binary 1.0
  306. // binary -233(EOP)
  307. int id = 0;
  308. size_t nread;
  309. nread = dr.read(&id, sizeof(int));
  310. if (nread != sizeof(int))
  311. {
  312. NCNN_LOGE("ParamDict read id failed %zd", nread);
  313. return -1;
  314. }
  315. while (id != -233)
  316. {
  317. bool is_array = id <= -23300;
  318. if (is_array)
  319. {
  320. id = -id - 23300;
  321. }
  322. if (id >= NCNN_MAX_PARAM_COUNT)
  323. {
  324. NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
  325. return -1;
  326. }
  327. if (is_array)
  328. {
  329. int len = 0;
  330. nread = dr.read(&len, sizeof(int));
  331. if (nread != sizeof(int))
  332. {
  333. NCNN_LOGE("ParamDict read array length failed %zd", nread);
  334. return -1;
  335. }
  336. d->params[id].v.create(len);
  337. float* ptr = d->params[id].v;
  338. nread = dr.read(ptr, sizeof(float) * len);
  339. if (nread != sizeof(float) * len)
  340. {
  341. NCNN_LOGE("ParamDict read array element failed %zd", nread);
  342. return -1;
  343. }
  344. d->params[id].type = 4;
  345. }
  346. else
  347. {
  348. nread = dr.read(&d->params[id].f, sizeof(float));
  349. if (nread != sizeof(float))
  350. {
  351. NCNN_LOGE("ParamDict read value failed %zd", nread);
  352. return -1;
  353. }
  354. d->params[id].type = 1;
  355. }
  356. nread = dr.read(&id, sizeof(int));
  357. if (nread != sizeof(int))
  358. {
  359. NCNN_LOGE("ParamDict read EOP failed %zd", nread);
  360. return -1;
  361. }
  362. }
  363. return 0;
  364. }
  365. } // namespace ncnn