You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mxnet2ncnn.cpp 29 kB

8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include <stdio.h>
  15. #include <stdint.h>
  16. #include <string.h>
  17. #include <map>
  18. #include <set>
  19. #include <string>
  20. #include <vector>
  class MXNetParam;

  // One operator or variable node of an MXNet symbol graph, as read from the
  // "nodes" list of a -symbol.json file.
  class MXNetNode
  {
  public:
      // Start with null graph back-references so a node that has not been wired
      // up yet holds well-defined pointers instead of indeterminate ones.
      MXNetNode() : nodes(nullptr), params(nullptr) {}

      bool has_attr(const char* key) const;

      // Deferred-typed view of one attribute: the conversion operator selected
      // by the destination type forwards to attr_i / attr_f / attr_s / attr_ai.
      class AttrProxy
      {
          MXNetNode const* _n;
          const char* const _key;
      public:
          AttrProxy( MXNetNode const* n, const char* key ) : _n(n), _key(key) {}
          operator int() const { return _n->attr_i(_key); }
          operator float() const { return _n->attr_f(_key); }
          operator std::string() const { return _n->attr_s(_key); }
          operator std::vector<int>() const { return _n->attr_ai(_key); }
      };

      AttrProxy attr(const char* key) const { return AttrProxy(this, key); }

      int attr_i(const char* key) const;
      float attr_f(const char* key) const;
      std::string attr_s(const char* key) const;
      std::vector<int> attr_ai(const char* key) const;

  public:
      bool is_weight() const;
      bool has_weight(int i) const;
      std::vector<float> weight(int i, int init_len = 0) const;

      // Non-owning references to the whole node list and the parameter list.
      // They stay null until the owner assigns them; is_weight(), has_weight()
      // and weight() dereference them without checking.
      std::vector<MXNetNode>* nodes;
      std::vector<MXNetParam>* params;

  public:
      std::string op;                           // operator name; "null" for variables
      std::string name;
      std::map<std::string, std::string> attrs; // raw attribute strings
      std::vector<int> inputs;                  // indices into the node list
      std::vector<int> weights;                 // indices into the node list
  };
  // A named tensor loaded from an MXNet .params file, or a placeholder built
  // from a variable node's "__init__" attribute when no data was stored.
  class MXNetParam
  {
  public:
      // compared against node names to decide whether a node is a weight
      std::string name;

      // element values; left empty for an init-only placeholder
      std::vector<float> data;

      // "__init__" attribute text, consulted by MXNetNode::weight() when data is empty
      std::string init;
  };
  62. bool MXNetNode::has_attr(const char* key) const
  63. {
  64. const std::map<std::string, std::string>::const_iterator it = attrs.find(key);
  65. return it != attrs.end();
  66. }
  67. int MXNetNode::attr_i(const char* key) const
  68. {
  69. const std::map<std::string, std::string>::const_iterator it = attrs.find(key);
  70. if (it == attrs.end())
  71. return 0;
  72. if (it->second == "False")
  73. return 0;
  74. if (it->second == "True")
  75. return 1;
  76. int i = 0;
  77. int nscan = sscanf(it->second.c_str(), "%d", &i);
  78. if (nscan != 1)
  79. return 0;
  80. return i;
  81. }
  82. float MXNetNode::attr_f(const char* key) const
  83. {
  84. const std::map<std::string, std::string>::const_iterator it = attrs.find(key);
  85. if (it == attrs.end())
  86. return 0.f;
  87. float f = 0;
  88. int nscan = sscanf(it->second.c_str(), "%f", &f);
  89. if (nscan != 1)
  90. return 0.f;
  91. return f;
  92. }
  93. std::string MXNetNode::attr_s(const char* key) const
  94. {
  95. const std::map<std::string, std::string>::const_iterator it = attrs.find(key);
  96. if (it == attrs.end())
  97. return std::string();
  98. return it->second;
  99. }
  100. std::vector<int> MXNetNode::attr_ai(const char* key) const
  101. {
  102. const std::map<std::string, std::string>::const_iterator it = attrs.find(key);
  103. if (it == attrs.end())
  104. return std::vector<int>();
  105. // (1,2,3)
  106. std::vector<int> list;
  107. int i = 0;
  108. int c = 0;
  109. int nconsumed = 0;
  110. int nscan = sscanf(it->second.c_str() + c, "%*[(,]%d%n", &i, &nconsumed);
  111. while (nscan == 1)
  112. {
  113. list.push_back(i);
  114. // fprintf(stderr, "%d\n", i);
  115. i = 0;
  116. c += nconsumed;
  117. nscan = sscanf(it->second.c_str() + c, "%*[(,]%d%n", &i, &nconsumed);
  118. }
  119. return list;
  120. }
  121. bool MXNetNode::is_weight() const
  122. {
  123. for (int i=0; i<(int)(*params).size(); i++)
  124. {
  125. const MXNetParam& p = (*params)[i];
  126. if (p.name == name)
  127. return true;
  128. }
  129. return false;
  130. }
  131. bool MXNetNode::has_weight(int i) const
  132. {
  133. if (i < 0 || i >= (int)weights.size())
  134. return false;
  135. const std::string& name = (*nodes)[ weights[i] ].name;
  136. for (int i=0; i<(int)(*params).size(); i++)
  137. {
  138. const MXNetParam& p = (*params)[i];
  139. if (p.name == name)
  140. return true;
  141. }
  142. return false;
  143. }
  144. std::vector<float> MXNetNode::weight(int i, int init_len) const
  145. {
  146. if (i < 0 || i >= (int)weights.size())
  147. return std::vector<float>();
  148. const std::string& name = (*nodes)[ weights[i] ].name;
  149. for (int i=0; i<(int)(*params).size(); i++)
  150. {
  151. const MXNetParam& p = (*params)[i];
  152. if (p.name != name)
  153. continue;
  154. if (!p.data.empty())
  155. return p.data;
  156. std::vector<float> data;
  157. if (!p.init.empty() && init_len != 0)
  158. {
  159. if (p.init == "[\\$zero\\$, {}]")
  160. {
  161. data.resize(init_len, 0.f);
  162. }
  163. else if (p.init == "[\\$one\\$, {}]")
  164. {
  165. data.resize(init_len, 1.f);
  166. }
  167. }
  168. return data;
  169. }
  170. return std::vector<float>();
  171. }
  // Rewrite every escaped double quote (\") in s to \$, in place.
  // The attribute sscanf patterns use "%[^\"]" scansets, which would otherwise
  // stop at the escaped quote inside values such as "[\"zero\", {}]".
  static void replace_backslash_doublequote_dollar(char* s)
  {
      if (s[0] == '\0')
          return;

      for (char* cur = s + 1; *cur != '\0'; ++cur)
      {
          if (cur[-1] == '\\' && *cur == '"')
              *cur = '$';
      }
  }
  // Parse the value of a node's "inputs" key, e.g. "[[0, 0, 0], [3, 0, 0]]",
  // and return the first number of every entry (the index of the producing
  // node). "[]", an empty string, or text without a complete entry yields an
  // empty vector. Parsing stops at the first entry that is not closed by ']',
  // so a truncated list never moves the cursor past the end of the string.
  static std::vector<int> parse_input_list(const char* s)
  {
      std::vector<int> inputs;

      // strncmp (unlike memcmp) never reads past the terminator of a short s,
      // and the emptiness test keeps s + 1 inside the string below
      if (s[0] == '\0' || strncmp(s, "[]", 2) == 0)
          return inputs;

      int id = 0;
      int nconsumed = 0;
      int c = 1; // skip leading [
      int nscan = sscanf(s + c, "[%d, %*[^]]]%n", &id, &nconsumed);

      // %n is stored only after the closing ']' matched; a zero count means the
      // entry was incomplete and must not be used to advance c
      while (nscan == 1 && nconsumed > 0)
      {
          inputs.push_back(id);
          c += nconsumed;
          nconsumed = 0;
          nscan = sscanf(s + c, "%*[^[][%d, %*[^]]]%n", &id, &nconsumed);
      }
      return inputs;
  }
  205. static bool read_mxnet_json(const char* jsonpath, std::vector<MXNetNode>& nodes)
  206. {
  207. FILE* fp = fopen(jsonpath, "rb");
  208. if (!fp)
  209. {
  210. fprintf(stderr, "fopen %s failed\n", jsonpath);
  211. return false;
  212. }
  213. int internal_unknown = 0;
  214. char line[1024];
  215. //{
  216. fgets(line, 1024, fp);
  217. MXNetNode n;
  218. bool in_nodes_list = false;
  219. bool in_node_block = false;
  220. bool in_attr_block = false;
  221. while (!feof(fp))
  222. {
  223. char* s = fgets(line, 1024, fp);
  224. if (!s)
  225. break;
  226. if (in_attr_block)
  227. {
  228. // },
  229. if (memcmp(line, " }", 7) == 0)
  230. {
  231. in_attr_block = false;
  232. continue;
  233. }
  234. // replace \" with \$
  235. replace_backslash_doublequote_dollar(line);
  236. // "kernel": "(7,7)",
  237. char key[256] = {0};
  238. char value[256] = {0};
  239. int nscan = sscanf(line, " \"%255[^\"]\": \"%255[^\"]\"", key, value);
  240. if (nscan == 2)
  241. {
  242. n.attrs[key] = value;
  243. // fprintf(stderr, "# %s = %s\n", key, value);
  244. continue;
  245. }
  246. }
  247. if (in_node_block)
  248. {
  249. // },
  250. if (memcmp(line, " }", 5) == 0)
  251. {
  252. // new node
  253. if (n.name.empty())
  254. {
  255. // assign default unknown name
  256. char unknownname[256];
  257. sprintf(unknownname, "unknownncnn_%d", internal_unknown);
  258. n.name = unknownname;
  259. internal_unknown++;
  260. }
  261. nodes.push_back(n);
  262. in_node_block = false;
  263. continue;
  264. }
  265. int nscan;
  266. // "op": "Convolution",
  267. char op[256] = {0};
  268. nscan = sscanf(line, " \"op\": \"%255[^\"]\",", op);
  269. if (nscan == 1)
  270. {
  271. n.op = op;
  272. // fprintf(stderr, "op = %s\n", op);
  273. continue;
  274. }
  275. // "name": "conv0",
  276. char name[256] = {0};
  277. nscan = sscanf(line, " \"name\": \"%255[^\"]\",", name);
  278. if (nscan == 1)
  279. {
  280. n.name = name;
  281. // fprintf(stderr, "name = %s\n", name);
  282. continue;
  283. }
  284. // "inputs": []
  285. char inputs[256] = {0};
  286. nscan = sscanf(line, " \"inputs\": %255[^\n]", inputs);
  287. if (nscan == 1)
  288. {
  289. n.inputs = parse_input_list(inputs);
  290. // fprintf(stderr, "inputs = %s\n", inputs);
  291. continue;
  292. }
  293. // "param": {},
  294. if (memcmp(line, " \"param\": {}", 17) == 0)
  295. {
  296. continue;
  297. }
  298. // replace \" with \$
  299. replace_backslash_doublequote_dollar(line);
  300. // "attr": {"__init__": "[\"zero\", {}]"},
  301. char key[256] = {0};
  302. char value[256] = {0};
  303. nscan = sscanf(line, " \"attr\": {\"%255[^\"]\": \"%255[^\"]\"}", key, value);
  304. if (nscan == 2)
  305. {
  306. n.attrs[key] = value;
  307. // fprintf(stderr, "# %s = %s\n", key, value);
  308. continue;
  309. }
  310. // "attrs": {"__init__": "[\"zero\", {}]"},
  311. nscan = sscanf(line, " \"attrs\": {\"%255[^\"]\": \"%255[^\"]\"}", key, value);
  312. if (nscan == 2)
  313. {
  314. n.attrs[key] = value;
  315. // fprintf(stderr, "# %s = %s\n", key, value);
  316. continue;
  317. }
  318. // "param": {"p": "0.5"},
  319. nscan = sscanf(line, " \"param\": {\"%255[^\"]\": \"%255[^\"]\"}", key, value);
  320. if (nscan == 2)
  321. {
  322. n.attrs[key] = value;
  323. // fprintf(stderr, "# %s = %s\n", key, value);
  324. continue;
  325. }
  326. // "attr": {
  327. if (memcmp(line, " \"attr\": {", 15) == 0)
  328. {
  329. in_attr_block = true;
  330. continue;
  331. }
  332. // "attrs": {
  333. if (memcmp(line, " \"attrs\": {", 15) == 0)
  334. {
  335. in_attr_block = true;
  336. continue;
  337. }
  338. // "param": {
  339. if (memcmp(line, " \"param\": {", 16) == 0)
  340. {
  341. in_attr_block = true;
  342. continue;
  343. }
  344. }
  345. if (in_nodes_list)
  346. {
  347. // ],
  348. if (memcmp(line, " ],", 4) == 0)
  349. {
  350. in_nodes_list = false;
  351. // all nodes parsed
  352. break;
  353. }
  354. // {
  355. if (memcmp(line, " {", 5) == 0)
  356. {
  357. n = MXNetNode();
  358. in_node_block = true;
  359. continue;
  360. }
  361. }
  362. // "nodes": [
  363. if (memcmp(line, " \"nodes\": [", 12) == 0)
  364. {
  365. in_nodes_list = true;
  366. continue;
  367. }
  368. }
  369. fclose(fp);
  370. return true;
  371. }
  372. static bool read_mxnet_param(const char* parampath, std::vector<MXNetParam>& params)
  373. {
  374. FILE* fp = fopen(parampath, "rb");
  375. if (!fp)
  376. {
  377. fprintf(stderr, "fopen %s failed\n", parampath);
  378. return false;
  379. }
  380. uint64_t header;
  381. uint64_t reserved;
  382. fread(&header, 1, sizeof(uint64_t), fp);
  383. fread(&reserved, 1, sizeof(uint64_t), fp);
  384. // NDArray vec
  385. // each data
  386. uint64_t data_count;
  387. fread(&data_count, 1, sizeof(uint64_t), fp);
  388. // fprintf(stderr, "data count = %d\n", (int)data_count);
  389. for (int i = 0; i < (int)data_count; i++)
  390. {
  391. uint32_t magic;// 0xF993FAC9
  392. fread(&magic, 1, sizeof(uint32_t), fp);
  393. // shape
  394. uint32_t ndim;
  395. std::vector<int64_t> shape;
  396. if (magic == 0xF993FAC9)
  397. {
  398. int32_t stype;
  399. fread(&stype, 1, sizeof(int32_t), fp);
  400. fread(&ndim, 1, sizeof(uint32_t), fp);
  401. shape.resize(ndim);
  402. fread(&shape[0], 1, ndim * sizeof(int64_t), fp);
  403. }
  404. else if (magic == 0xF993FAC8)
  405. {
  406. fread(&ndim, 1, sizeof(uint32_t), fp);
  407. shape.resize(ndim);
  408. fread(&shape[0], 1, ndim * sizeof(int64_t), fp);
  409. }
  410. else
  411. {
  412. ndim = magic;
  413. shape.resize(ndim);
  414. std::vector<uint32_t> shape32;
  415. shape32.resize(ndim);
  416. fread(&shape32[0], 1, ndim * sizeof(uint32_t), fp);
  417. for (int j=0; j<(int)ndim; j++)
  418. {
  419. shape[j] = shape32[j];
  420. }
  421. }
  422. // context
  423. int32_t dev_type;
  424. int32_t dev_id;
  425. fread(&dev_type, 1, sizeof(int32_t), fp);
  426. fread(&dev_id, 1, sizeof(int32_t), fp);
  427. int32_t type_flag;
  428. fread(&type_flag, 1, sizeof(int32_t), fp);
  429. // data
  430. size_t len = 0;
  431. if (shape.size() == 1) len = shape[0];
  432. if (shape.size() == 2) len = shape[0] * shape[1];
  433. if (shape.size() == 3) len = shape[0] * shape[1] * shape[2];
  434. if (shape.size() == 4) len = shape[0] * shape[1] * shape[2] * shape[3];
  435. MXNetParam p;
  436. p.data.resize(len);
  437. fread(&p.data[0], 1, len * sizeof(float), fp);
  438. params.push_back(p);
  439. // fprintf(stderr, "%u read\n", len);
  440. }
  441. // each name
  442. uint64_t name_count;
  443. fread(&name_count, 1, sizeof(uint64_t), fp);
  444. // fprintf(stderr, "name count = %d\n", (int)name_count);
  445. for (int i = 0; i < (int)name_count; i++)
  446. {
  447. uint64_t len;
  448. fread(&len, 1, sizeof(uint64_t), fp);
  449. MXNetParam& p = params[i];
  450. p.name.resize(len);
  451. fread((char*)p.name.data(), 1, len, fp);
  452. // cut leading arg:
  453. if (memcmp(p.name.c_str(), "arg:", 4) == 0)
  454. {
  455. p.name = std::string(p.name.c_str() + 4);
  456. }
  457. if (memcmp(p.name.c_str(), "aux:", 4) == 0)
  458. {
  459. p.name = std::string(p.name.c_str() + 4);
  460. }
  461. // fprintf(stderr, "%s read\n", p.name.c_str());
  462. }
  463. fclose(fp);
  464. return true;
  465. }
  466. int main(int argc, char** argv)
  467. {
  468. const char* jsonpath = argv[1];
  469. const char* parampath = argv[2];
  470. const char* ncnn_prototxt = argc >= 5 ? argv[3] : "ncnn.proto";
  471. const char* ncnn_modelbin = argc >= 5 ? argv[4] : "ncnn.bin";
  472. std::vector<MXNetNode> nodes;
  473. std::vector<MXNetParam> params;
  474. read_mxnet_json(jsonpath, nodes);
  475. read_mxnet_param(parampath, params);
  476. FILE* pp = fopen(ncnn_prototxt, "wb");
  477. FILE* bp = fopen(ncnn_modelbin, "wb");
  478. // magic
  479. fprintf(pp, "7767517\n");
  480. int node_count = nodes.size();
  481. // node reference
  482. std::map<int, int> node_reference;
  483. // weight node
  484. std::vector<int> weight_nodes;
  485. // global definition line
  486. // [layer count] [blob count]
  487. std::set<std::string> blob_names;
  488. for (int i=0; i<node_count; i++)
  489. {
  490. MXNetNode& n = nodes[i];
  491. // assign global param reference
  492. n.nodes = &nodes;
  493. n.params = &params;
  494. const std::string& output_name = n.name;
  495. if (n.op == "null")
  496. {
  497. if (n.is_weight())
  498. {
  499. weight_nodes.push_back(i);
  500. }
  501. else
  502. {
  503. if (n.has_attr("__init__"))
  504. {
  505. // init weight param
  506. MXNetParam pi;
  507. pi.name = n.name;
  508. pi.init = (std::string)n.attr("__init__");
  509. params.push_back(pi);
  510. weight_nodes.push_back(i);
  511. }
  512. else
  513. {
  514. // null node without data, treat it as network input
  515. }
  516. }
  517. continue;
  518. }
  519. // distinguish weights and inputs
  520. std::vector<int> weights;
  521. std::vector<int> inputs;
  522. for (int j=0; j<(int)n.inputs.size(); j++)
  523. {
  524. int input_index = n.inputs[j];
  525. if (nodes[input_index].is_weight())
  526. {
  527. weights.push_back(input_index);
  528. continue;
  529. }
  530. inputs.push_back(input_index);
  531. }
  532. n.inputs = inputs;
  533. n.weights = weights;
  534. // input
  535. for (int j=0; j<(int)n.inputs.size(); j++)
  536. {
  537. int input_index = n.inputs[j];
  538. const std::string& input_name = nodes[input_index].name;
  539. // fprintf(stderr, "input = %s\n", input_name.c_str());
  540. blob_names.insert(input_name);
  541. if (node_reference.find(input_index) == node_reference.end())
  542. {
  543. node_reference[input_index] = 1;
  544. }
  545. else
  546. {
  547. node_reference[input_index] = node_reference[input_index] + 1;
  548. }
  549. }
  550. // output
  551. // fprintf(stderr, "output = %s\n", output_name.c_str());
  552. blob_names.insert(output_name);
  553. }
  554. // remove node_reference entry with reference equals to one
  555. int splitncnn_blob_count = 0;
  556. std::map<int, int>::iterator it = node_reference.begin();
  557. while (it != node_reference.end())
  558. {
  559. if (it->second == 1)
  560. {
  561. node_reference.erase(it++);
  562. }
  563. else
  564. {
  565. splitncnn_blob_count += it->second;
  566. // fprintf(stderr, "%s %d\n", it->first.c_str(), it->second);
  567. ++it;
  568. }
  569. }
  570. fprintf(pp, "%lu %lu\n", node_count + node_reference.size() - weight_nodes.size(), blob_names.size() + splitncnn_blob_count);
  571. int internal_split = 0;
  572. for (int i=0; i<node_count; i++)
  573. {
  574. const MXNetNode& n = nodes[i];
  575. if (n.op == "null")
  576. {
  577. if (n.is_weight())
  578. {
  579. continue;
  580. }
  581. fprintf(pp, "%-16s", "Input");
  582. }
  583. else if (n.op == "Activation")
  584. {
  585. std::string type = n.attr("act_type");
  586. if (type == "relu")
  587. {
  588. fprintf(pp, "%-16s", "ReLU");
  589. }
  590. else if (type == "sigmoid")
  591. {
  592. fprintf(pp, "%-16s", "Sigmoid");
  593. }
  594. else if (type == "tanh")
  595. {
  596. fprintf(pp, "%-16s", "TanH");
  597. }
  598. }
  599. else if (n.op == "BatchNorm")
  600. {
  601. fprintf(pp, "%-16s", "BatchNorm");
  602. }
  603. else if (n.op == "Concat")
  604. {
  605. fprintf(pp, "%-16s", "Concat");
  606. }
  607. else if (n.op == "Convolution")
  608. {
  609. int num_group = n.attr("num_group");
  610. if (num_group > 0) {
  611. fprintf(pp, "%-16s", "ConvolutionDepthWise");
  612. } else {
  613. fprintf(pp, "%-16s", "Convolution");
  614. }
  615. }
  616. else if (n.op == "Dropout")
  617. {
  618. fprintf(pp, "%-16s", "Dropout");
  619. }
  620. else if (n.op == "elemwise_add")
  621. {
  622. fprintf(pp, "%-16s", "Eltwise");
  623. }
  624. else if (n.op == "Flatten")
  625. {
  626. fprintf(pp, "%-16s", "Flatten");
  627. }
  628. else if (n.op == "FullyConnected")
  629. {
  630. fprintf(pp, "%-16s", "InnerProduct");
  631. }
  632. else if (n.op == "LeakyReLU")
  633. {
  634. std::string type = n.attr("act_type");
  635. if (type == "elu")
  636. {
  637. fprintf(pp, "%-16s", "ELU");
  638. }
  639. else if (type == "leaky")
  640. {
  641. fprintf(pp, "%-16s", "ReLU");
  642. }
  643. else if (type == "prelu")
  644. {
  645. fprintf(pp, "%-16s", "PReLU");
  646. }
  647. }
  648. else if (n.op == "Pooling")
  649. {
  650. fprintf(pp, "%-16s", "Pooling");
  651. }
  652. else if (n.op == "SoftmaxOutput")
  653. {
  654. fprintf(pp, "%-16s", "Softmax");
  655. }
  656. else
  657. {
  658. fprintf(stderr, "%s not supported yet!\n", n.op.c_str());
  659. fprintf(pp, "%-16s", n.op.c_str());
  660. }
  661. int input_size = n.inputs.size();
  662. for (int j=0; j<(int)n.inputs.size(); j++)
  663. {
  664. int input_index = n.inputs[j];
  665. if (nodes[input_index].is_weight())
  666. {
  667. input_size--;
  668. }
  669. }
  670. if (n.op == "SoftmaxOutput")
  671. {
  672. // drop label
  673. input_size--;
  674. }
  675. fprintf(pp, " %-32s %d 1", n.name.c_str(), input_size);
  676. for (int j=0; j<(int)n.inputs.size(); j++)
  677. {
  678. int input_index = n.inputs[j];
  679. if (nodes[input_index].is_weight())
  680. {
  681. continue;
  682. }
  683. if (n.op == "SoftmaxOutput")
  684. {
  685. // drop label
  686. if (nodes[input_index].op == "null")
  687. continue;
  688. }
  689. std::string input_name = nodes[input_index].name;
  690. if (node_reference.find(input_index) != node_reference.end())
  691. {
  692. int refidx = node_reference[input_index] - 1;
  693. node_reference[input_index] = refidx;
  694. char splitsuffix[256];
  695. sprintf(splitsuffix, "_splitncnn_%d", refidx);
  696. input_name = input_name + splitsuffix;
  697. }
  698. fprintf(pp, " %s", input_name.c_str());
  699. }
  700. fprintf(pp, " %s", n.name.c_str());
  701. if (n.op == "null")
  702. {
  703. // dummy input shape
  704. // fprintf(pp, " 0 0 0");
  705. }
  706. else if (n.op == "Activation")
  707. {
  708. std::string type = n.attr("act_type");
  709. if (type == "relu")
  710. {
  711. // fprintf(pp, " 0=%f", 0.f);
  712. }
  713. }
  714. else if (n.op == "BatchNorm")
  715. {
  716. float eps = 1e-3;
  717. if (n.has_attr("eps")) {
  718. eps = n.attr("eps");
  719. }
  720. std::vector<float> slope_data = n.weight(0);
  721. std::vector<float> bias_data = n.weight(1);
  722. int channels = slope_data.size();
  723. std::vector<float> mean_data = n.weight(2, channels);
  724. std::vector<float> var_data = n.weight(3, channels);
  725. for (int j=0; j<(int)var_data.size(); j++)
  726. {
  727. var_data[j] += eps;
  728. }
  729. fprintf(pp, " 0=%d", channels);
  730. fwrite(slope_data.data(), sizeof(float), slope_data.size(), bp);
  731. fwrite(mean_data.data(), sizeof(float), mean_data.size(), bp);
  732. fwrite(var_data.data(), sizeof(float), var_data.size(), bp);
  733. fwrite(bias_data.data(), sizeof(float), bias_data.size(), bp);
  734. }
  735. else if (n.op == "Concat")
  736. {
  737. int dim = n.attr("dim");
  738. fprintf(pp, " 0=%d", dim-1);
  739. }
  740. else if (n.op == "Convolution")
  741. {
  742. int num_filter = n.attr("num_filter");
  743. std::vector<int> kernel = n.attr("kernel");
  744. std::vector<int> dilate = n.attr("dilate");
  745. std::vector<int> stride = n.attr("stride");
  746. std::vector<int> pad = n.attr("pad");
  747. int no_bias = n.attr("no_bias");
  748. int num_group = n.attr("num_group");//TODO depthwise
  749. std::vector<float> weight_data = n.weight(0);
  750. std::vector<float> bias_data = n.weight(1);
  751. fprintf(pp, " 0=%d", num_filter);
  752. if (kernel.size() == 1) {
  753. fprintf(pp, " 1=%d", kernel[0]);
  754. } else if (kernel.size() == 2) {
  755. fprintf(pp, " 1=%d", kernel[1]);
  756. fprintf(pp, " 11=%d", kernel[0]);
  757. }
  758. if (dilate.size() == 1) {
  759. fprintf(pp, " 2=%d", dilate[0]);
  760. } else if (dilate.size() == 2) {
  761. fprintf(pp, " 2=%d", dilate[1]);
  762. fprintf(pp, " 12=%d", dilate[0]);
  763. }
  764. if (stride.size() == 1) {
  765. fprintf(pp, " 3=%d", stride[0]);
  766. } else if (stride.size() == 2) {
  767. fprintf(pp, " 3=%d", stride[1]);
  768. fprintf(pp, " 13=%d", stride[0]);
  769. }
  770. if (pad.size() == 1) {
  771. fprintf(pp, " 4=%d", pad[0]);
  772. } else if (pad.size() == 2) {
  773. fprintf(pp, " 4=%d", pad[1]);
  774. fprintf(pp, " 14=%d", pad[0]);
  775. }
  776. fprintf(pp, " 5=%d", no_bias == 1 ? 0 : 1);
  777. fprintf(pp, " 6=%d", (int)weight_data.size());
  778. if (num_group > 0) {
  779. fprintf(pp, " 7=%d", num_group);
  780. }
  781. int quantize_tag = 0;
  782. fwrite(&quantize_tag, sizeof(int), 1, bp);
  783. fwrite(weight_data.data(), sizeof(float), weight_data.size(), bp);
  784. fwrite(bias_data.data(), sizeof(float), bias_data.size(), bp);
  785. }
  786. else if (n.op == "Dropout")
  787. {
  788. // float p = n.attr("p");
  789. // fprintf(pp, " 0=%d", p);
  790. }
  791. else if (n.op == "elemwise_add")
  792. {
  793. int op_type = 1;
  794. fprintf(pp, " 0=%d", op_type);
  795. }
  796. else if (n.op == "Flatten")
  797. {
  798. }
  799. else if (n.op == "FullyConnected")
  800. {
  801. int num_hidden = n.attr("num_hidden");
  802. int no_bias = n.attr("no_bias");
  803. // int flatten = n.attr("flatten");
  804. // TODO flatten
  805. std::vector<float> weight_data = n.weight(0);
  806. std::vector<float> bias_data = n.weight(1);
  807. fprintf(pp, " 0=%d", num_hidden);
  808. fprintf(pp, " 1=%d", no_bias == 1 ? 0 : 1);
  809. fprintf(pp, " 2=%d", (int)weight_data.size());
  810. int quantize_tag = 0;
  811. fwrite(&quantize_tag, sizeof(int), 1, bp);
  812. fwrite(weight_data.data(), sizeof(float), weight_data.size(), bp);
  813. fwrite(bias_data.data(), sizeof(float), bias_data.size(), bp);
  814. }
  815. else if (n.op == "LeakyReLU")
  816. {
  817. std::string type = n.attr("act_type");
  818. if (type == "elu")
  819. {
  820. }
  821. else if (type == "leaky")
  822. {
  823. }
  824. else if (type == "prelu")
  825. {
  826. std::vector<float> weight_data = n.weight(0);
  827. fprintf(pp, " 0=%d", (int)weight_data.size());
  828. fwrite(weight_data.data(), sizeof(float), weight_data.size(), bp);
  829. }
  830. }
  831. else if (n.op == "Pooling")
  832. {
  833. std::string pool_type = n.attr("pool_type");
  834. std::vector<int> kernel = n.attr("kernel");
  835. std::vector<int> stride = n.attr("stride");
  836. std::vector<int> pad = n.attr("pad");
  837. std::string pooling_convention = n.attr("pooling_convention");
  838. int global_pool = n.attr("global_pool");
  839. int pool = 0;
  840. if (pool_type == "max")
  841. {
  842. pool = 0;
  843. }
  844. else if (pool_type == "avg")
  845. {
  846. pool = 1;
  847. }
  848. if (pooling_convention == "valid")
  849. {
  850. // TODO valid and full mode
  851. }
  852. fprintf(pp, " 0=%d", pool);
  853. if (!kernel.empty())
  854. fprintf(pp, " 1=%d", kernel[0]);
  855. if (!stride.empty())
  856. fprintf(pp, " 2=%d", stride[0]);
  857. if (!pad.empty())
  858. fprintf(pp, " 3=%d", pad[0]);
  859. fprintf(pp, " 4=%d", global_pool);
  860. }
  861. else if (n.op == "SoftmaxOutput")
  862. {
  863. }
  864. else
  865. {
  866. // TODO op specific params
  867. std::map<std::string, std::string>::const_iterator it = n.attrs.begin();
  868. for (; it != n.attrs.end(); it++)
  869. {
  870. fprintf(stderr, "# %s=%s\n", it->first.c_str(), it->second.c_str());
  871. // fprintf(pp, " %s=%s", it->first.c_str(), it->second.c_str());
  872. }
  873. }
  874. fprintf(pp, "\n");
  875. if (node_reference.find(i) != node_reference.end())
  876. {
  877. int refcount = node_reference[i];
  878. if (refcount > 1)
  879. {
  880. std::string output_name = n.name;
  881. char splitname[256];
  882. sprintf(splitname, "splitncnn_%d", internal_split);
  883. fprintf(pp, "%-16s %-32s %d %d", "Split", splitname, 1, refcount);
  884. fprintf(pp, " %s", output_name.c_str());
  885. for (int j=0; j<refcount; j++)
  886. {
  887. fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), j);
  888. }
  889. fprintf(pp, "\n");
  890. internal_split++;
  891. }
  892. }
  893. }
  894. fclose(pp);
  895. fclose(bp);
  896. return 0;
  897. }