You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

paramdict.cpp 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "paramdict.h"
  15. #include "datareader.h"
  16. #include "mat.h"
  17. #include "platform.h"
  18. #include <ctype.h>
  19. #if NCNN_STDIO
  20. #include <stdio.h>
  21. #endif
  22. namespace ncnn {
  23. class ParamDictPrivate
  24. {
  25. public:
  26. struct
  27. {
  28. // 0 = null
  29. // 1 = int/float
  30. // 2 = int
  31. // 3 = float
  32. // 4 = array of int/float
  33. // 5 = array of int
  34. // 6 = array of float
  35. // 7 = string
  36. int type;
  37. union
  38. {
  39. int i;
  40. float f;
  41. };
  42. Mat v;
  43. std::string s;
  44. } params[NCNN_MAX_PARAM_COUNT];
  45. };
  46. ParamDict::ParamDict()
  47. : d(new ParamDictPrivate)
  48. {
  49. clear();
  50. }
  51. ParamDict::~ParamDict()
  52. {
  53. delete d;
  54. }
  55. ParamDict::ParamDict(const ParamDict& rhs)
  56. : d(new ParamDictPrivate)
  57. {
  58. for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
  59. {
  60. int type = rhs.d->params[i].type;
  61. d->params[i].type = type;
  62. if (type == 1 || type == 2 || type == 3)
  63. {
  64. d->params[i].i = rhs.d->params[i].i;
  65. }
  66. else if (type == 7)
  67. {
  68. d->params[i].s = rhs.d->params[i].s;
  69. }
  70. else // if (type == 4 || type == 5 || type == 6)
  71. {
  72. d->params[i].v = rhs.d->params[i].v;
  73. }
  74. }
  75. }
  76. ParamDict& ParamDict::operator=(const ParamDict& rhs)
  77. {
  78. if (this == &rhs)
  79. return *this;
  80. for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
  81. {
  82. int type = rhs.d->params[i].type;
  83. d->params[i].type = type;
  84. if (type == 1 || type == 2 || type == 3)
  85. {
  86. d->params[i].i = rhs.d->params[i].i;
  87. }
  88. else if (type == 7)
  89. {
  90. d->params[i].s = rhs.d->params[i].s;
  91. }
  92. else // if (type == 4 || type == 5 || type == 6)
  93. {
  94. d->params[i].v = rhs.d->params[i].v;
  95. }
  96. }
  97. return *this;
  98. }
  99. int ParamDict::type(int id) const
  100. {
  101. return d->params[id].type;
  102. }
  103. // TODO strict type check
  104. int ParamDict::get(int id, int def) const
  105. {
  106. return d->params[id].type ? d->params[id].i : def;
  107. }
  108. float ParamDict::get(int id, float def) const
  109. {
  110. return d->params[id].type ? d->params[id].f : def;
  111. }
  112. Mat ParamDict::get(int id, const Mat& def) const
  113. {
  114. return d->params[id].type ? d->params[id].v : def;
  115. }
  116. std::string ParamDict::get(int id, const std::string& def) const
  117. {
  118. return d->params[id].type ? d->params[id].s : def;
  119. }
  120. void ParamDict::set(int id, int i)
  121. {
  122. d->params[id].type = 2;
  123. d->params[id].i = i;
  124. }
  125. void ParamDict::set(int id, float f)
  126. {
  127. d->params[id].type = 3;
  128. d->params[id].f = f;
  129. }
  130. void ParamDict::set(int id, const Mat& v)
  131. {
  132. d->params[id].type = 4;
  133. d->params[id].v = v;
  134. }
  135. void ParamDict::set(int id, const std::string& s)
  136. {
  137. d->params[id].type = 7;
  138. d->params[id].s = s;
  139. }
  140. void ParamDict::clear()
  141. {
  142. for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
  143. {
  144. d->params[i].type = 0;
  145. d->params[i].i = 0;
  146. d->params[i].v = Mat();
  147. d->params[i].s.clear();
  148. }
  149. }
  150. #if NCNN_STRING
  // Returns true when the NUL-terminated token in vstr (at most 15 characters
  // plus terminator) looks like a floating point literal, i.e. it contains a
  // '.' or an 'e'/'E' before the terminator.
  static bool vstr_is_float(const char vstr[16])
  {
      // look ahead to determine whether the token is a float
      for (int j = 0; j < 16; j++)
      {
          // <ctype.h> functions require a value representable as unsigned
          // char; passing a negative plain char is undefined behaviour
          const unsigned char c = (unsigned char)vstr[j];

          if (c == '\0')
              break;

          if (c == '.' || tolower(c) == 'e')
              return true;
      }

      return false;
  }
  // Returns true when the token in vstr starts a string value: its first
  // character is alphabetic or a double quote.
  static bool vstr_is_string(const char vstr[16])
  {
      // cast first: isalpha() on a negative plain char is undefined behaviour
      const unsigned char first = (unsigned char)vstr[0];

      return isalpha(first) != 0 || first == '\"';
  }
  // Converts the decimal token in vstr to float.
  //
  // Accepted shape: [+|-] digits [ '.' digits ] [ (e|E) [+|-] digits ].
  // Parsing stops at the first character that does not fit; whatever was
  // parsed up to that point is returned (an empty token yields 0).
  //
  // The mantissa is accumulated in double rather than in 32-bit integers:
  // a 15 character token can carry more than 9 integer or fractional digits,
  // which used to wrap the unsigned accumulators and return a wrong value.
  static float vstr_to_float(const char vstr[16])
  {
      const char* p = vstr;

      // sign
      const bool sign = *p != '-';
      if (*p == '+' || *p == '-')
      {
          p++;
      }

      // digits before decimal point or exponent
      // (ctype functions need an unsigned char value, hence the casts)
      double v = 0.0;
      while (isdigit((unsigned char)*p))
      {
          v = v * 10.0 + (*p - '0');
          p++;
      }

      // digits after decimal point
      if (*p == '.')
      {
          p++;

          double pow10 = 1.0;
          double v2 = 0.0;
          while (isdigit((unsigned char)*p))
          {
              v2 = v2 * 10.0 + (*p - '0');
              pow10 *= 10.0;
              p++;
          }

          v += v2 / pow10;
      }

      // exponent
      if (*p == 'e' || *p == 'E')
      {
          p++;

          // sign of exponent
          const bool fact = *p != '-';
          if (*p == '+' || *p == '-')
          {
              p++;
          }

          // digits of exponent; accumulation saturates once the value is
          // far beyond the double range so that it can neither wrap around
          // nor make the scaling loops below run for a long time
          unsigned int expon = 0;
          while (isdigit((unsigned char)*p))
          {
              if (expon < 10000)
              {
                  expon = expon * 10 + (*p - '0');
              }
              p++;
          }

          double scale = 1.0;
          while (expon >= 8)
          {
              scale *= 1e8;
              expon -= 8;
          }
          while (expon > 0)
          {
              scale *= 10.0;
              expon -= 1;
          }

          v = fact ? v * scale : v / scale;
      }

      return sign ? (float)v : (float)-v;
  }
  232. int ParamDict::load_param(const DataReader& dr)
  233. {
  234. clear();
  235. // 0=100 1=1.250000 -23303=5,0.1,0.2,0.4,0.8,1.0
  236. // 3=0.1,0.2,0.4,0.8,1.0
  237. // parse each key=value pair
  238. int id = 0;
  239. while (dr.scan("%d=", &id) == 1)
  240. {
  241. bool is_array = id <= -23300;
  242. if (is_array)
  243. {
  244. id = -id - 23300;
  245. }
  246. if (id >= NCNN_MAX_PARAM_COUNT)
  247. {
  248. NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
  249. return -1;
  250. }
  251. if (is_array)
  252. {
  253. // old style array
  254. int len = 0;
  255. int nscan = dr.scan("%d", &len);
  256. if (nscan != 1)
  257. {
  258. NCNN_LOGE("ParamDict read array length failed");
  259. return -1;
  260. }
  261. d->params[id].v.create(len);
  262. for (int j = 0; j < len; j++)
  263. {
  264. char vstr[16];
  265. nscan = dr.scan(",%15[^,\n ]", vstr);
  266. if (nscan != 1)
  267. {
  268. NCNN_LOGE("ParamDict read array element failed");
  269. return -1;
  270. }
  271. bool is_float = vstr_is_float(vstr);
  272. if (is_float)
  273. {
  274. float* ptr = d->params[id].v;
  275. ptr[j] = vstr_to_float(vstr);
  276. }
  277. else
  278. {
  279. int* ptr = d->params[id].v;
  280. nscan = sscanf(vstr, "%d", &ptr[j]);
  281. if (nscan != 1)
  282. {
  283. NCNN_LOGE("ParamDict parse array element failed");
  284. return -1;
  285. }
  286. }
  287. d->params[id].type = is_float ? 6 : 5;
  288. }
  289. continue;
  290. }
  291. char vstr[16];
  292. char comma[4];
  293. int nscan = dr.scan("%15[^,\n ]", vstr);
  294. if (nscan != 1)
  295. {
  296. NCNN_LOGE("ParamDict read value failed");
  297. return -1;
  298. }
  299. bool is_string = vstr_is_string(vstr);
  300. if (is_string)
  301. {
  302. // scan the remaining string
  303. char vstr2[256];
  304. vstr2[241] = '\0'; // max 255 = 15 + 240
  305. if (vstr[0] == '\"')
  306. {
  307. nscan = dr.scan("%255[^\"]\"", vstr2);
  308. }
  309. else
  310. {
  311. nscan = dr.scan("%255[^\n ]", vstr2);
  312. }
  313. if (nscan == 1)
  314. {
  315. if (vstr2[241] != '\0')
  316. {
  317. NCNN_LOGE("string too long (id=%d)", id);
  318. return -1;
  319. }
  320. if (vstr[0] == '\"')
  321. d->params[id].s = std::string(&vstr[1]) + vstr2;
  322. else
  323. d->params[id].s = std::string(vstr) + vstr2;
  324. }
  325. else
  326. {
  327. if (vstr[0] == '\"')
  328. d->params[id].s = std::string(&vstr[1]);
  329. else
  330. d->params[id].s = std::string(vstr);
  331. }
  332. if (d->params[id].s[d->params[id].s.size() - 1] == '\"')
  333. d->params[id].s.resize(d->params[id].s.size() - 1);
  334. d->params[id].type = 7;
  335. continue;
  336. }
  337. bool is_float = vstr_is_float(vstr);
  338. nscan = dr.scan("%1[,]", comma);
  339. is_array = nscan == 1;
  340. if (is_array)
  341. {
  342. std::vector<float> af;
  343. std::vector<int> ai;
  344. if (is_float)
  345. {
  346. af.push_back(vstr_to_float(vstr));
  347. }
  348. else
  349. {
  350. int v = 0;
  351. nscan = sscanf(vstr, "%d", &v);
  352. if (nscan != 1)
  353. {
  354. NCNN_LOGE("ParamDict parse value failed");
  355. return -1;
  356. }
  357. ai.push_back(v);
  358. }
  359. while (1)
  360. {
  361. nscan = dr.scan("%15[^,\n ]", vstr);
  362. if (nscan != 1)
  363. {
  364. break;
  365. }
  366. if (is_float)
  367. {
  368. af.push_back(vstr_to_float(vstr));
  369. }
  370. else
  371. {
  372. int v = 0;
  373. nscan = sscanf(vstr, "%d", &v);
  374. if (nscan != 1)
  375. {
  376. NCNN_LOGE("ParamDict parse value failed");
  377. return -1;
  378. }
  379. ai.push_back(v);
  380. }
  381. nscan = dr.scan("%1[,]", comma);
  382. if (nscan != 1)
  383. {
  384. break;
  385. }
  386. }
  387. if (is_float)
  388. {
  389. d->params[id].v.create((int)af.size());
  390. memcpy(d->params[id].v.data, af.data(), af.size() * 4);
  391. }
  392. else
  393. {
  394. d->params[id].v.create((int)ai.size());
  395. memcpy(d->params[id].v.data, ai.data(), ai.size() * 4);
  396. }
  397. d->params[id].type = is_float ? 6 : 5;
  398. }
  399. else
  400. {
  401. if (is_float)
  402. {
  403. d->params[id].f = vstr_to_float(vstr);
  404. }
  405. else
  406. {
  407. nscan = sscanf(vstr, "%d", &d->params[id].i);
  408. if (nscan != 1)
  409. {
  410. NCNN_LOGE("ParamDict parse value failed");
  411. return -1;
  412. }
  413. }
  414. d->params[id].type = is_float ? 3 : 2;
  415. }
  416. }
  417. return 0;
  418. }
  419. #endif // NCNN_STRING
  420. int ParamDict::load_param_bin(const DataReader& dr)
  421. {
  422. clear();
  423. // binary 0
  424. // binary 100
  425. // binary 1
  426. // binary 1.250000
  427. // binary 3 | array_bit
  428. // binary 5
  429. // binary 0.1
  430. // binary 0.2
  431. // binary 0.4
  432. // binary 0.8
  433. // binary 1.0
  434. // binary -233(EOP)
  435. int id = 0;
  436. size_t nread;
  437. nread = dr.read(&id, sizeof(int));
  438. if (nread != sizeof(int))
  439. {
  440. NCNN_LOGE("ParamDict read id failed %zd", nread);
  441. return -1;
  442. }
  443. #if __BIG_ENDIAN__
  444. swap_endianness_32(&id);
  445. #endif
  446. while (id != -233)
  447. {
  448. bool is_array = id <= -23300;
  449. bool is_string = id <= -23400;
  450. if (is_string)
  451. {
  452. id = -id - 23400;
  453. }
  454. else if (is_array)
  455. {
  456. id = -id - 23300;
  457. }
  458. if (id >= NCNN_MAX_PARAM_COUNT)
  459. {
  460. NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
  461. return -1;
  462. }
  463. if (is_string)
  464. {
  465. int len = 0;
  466. nread = dr.read(&len, sizeof(int));
  467. if (nread != sizeof(int))
  468. {
  469. NCNN_LOGE("ParamDict read array length failed %zd", nread);
  470. return -1;
  471. }
  472. #if __BIG_ENDIAN__
  473. swap_endianness_32(&len);
  474. #endif
  475. if (len > 255)
  476. {
  477. NCNN_LOGE("string too long (id=%d)", id);
  478. return -1;
  479. }
  480. size_t len_padded = (len + 3) / 4 * 4;
  481. std::vector<char> tmpstr(len_padded + 1);
  482. char* ptr = (char*)tmpstr.data();
  483. nread = dr.read(ptr, len_padded);
  484. if (nread != len_padded)
  485. {
  486. NCNN_LOGE("ParamDict read string failed %zd", nread);
  487. return -1;
  488. }
  489. tmpstr[len_padded] = '\0';
  490. d->params[id].s = tmpstr.data();
  491. d->params[id].type = 7;
  492. }
  493. else if (is_array)
  494. {
  495. int len = 0;
  496. nread = dr.read(&len, sizeof(int));
  497. if (nread != sizeof(int))
  498. {
  499. NCNN_LOGE("ParamDict read array length failed %zd", nread);
  500. return -1;
  501. }
  502. #if __BIG_ENDIAN__
  503. swap_endianness_32(&len);
  504. #endif
  505. d->params[id].v.create(len);
  506. float* ptr = d->params[id].v;
  507. nread = dr.read(ptr, sizeof(float) * len);
  508. if (nread != sizeof(float) * len)
  509. {
  510. NCNN_LOGE("ParamDict read array element failed %zd", nread);
  511. return -1;
  512. }
  513. #if __BIG_ENDIAN__
  514. for (int i = 0; i < len; i++)
  515. {
  516. swap_endianness_32(ptr + i);
  517. }
  518. #endif
  519. d->params[id].type = 4;
  520. }
  521. else
  522. {
  523. nread = dr.read(&d->params[id].f, sizeof(float));
  524. if (nread != sizeof(float))
  525. {
  526. NCNN_LOGE("ParamDict read value failed %zd", nread);
  527. return -1;
  528. }
  529. #if __BIG_ENDIAN__
  530. swap_endianness_32(&d->params[id].f);
  531. #endif
  532. d->params[id].type = 1;
  533. }
  534. nread = dr.read(&id, sizeof(int));
  535. if (nread != sizeof(int))
  536. {
  537. NCNN_LOGE("ParamDict read EOP failed %zd", nread);
  538. return -1;
  539. }
  540. #if __BIG_ENDIAN__
  541. swap_endianness_32(&id);
  542. #endif
  543. }
  544. return 0;
  545. }
  546. } // namespace ncnn