You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

save_ncnn.cpp 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "save_ncnn.h"
  15. namespace pnnx {
  // Tell whether a numeric operand type code denotes an integer-like dtype.
  //
  // Codes 4 through 9 are reported as integer; going by type_to_dtype_string
  // in this file these are torch.int, torch.long, torch.short, torch.int8,
  // torch.uint8 and torch.bool. Every other code (floating point, complex,
  // or unknown) yields false.
  static bool type_is_integer(int type)
  {
      const int first_integer_code = 4;
      const int last_integer_code = 9;

      return type >= first_integer_code && type <= last_integer_code;
  }
  // Map a numeric operand type code (1..12) to the matching torch dtype
  // expression used in the generated Python script.
  //
  // Returns the literal "null" for any code outside the known range.
  static const char* type_to_dtype_string(int type)
  {
      // entry k holds the name for type code k + 1
      static const char* const dtype_names[] = {
          "torch.float",
          "torch.double",
          "torch.half",
          "torch.int",
          "torch.long",
          "torch.short",
          "torch.int8",
          "torch.uint8",
          "torch.bool",
          "torch.complex64",
          "torch.complex128",
          "torch.complex32",
      };
      const int known_count = (int)(sizeof(dtype_names) / sizeof(dtype_names[0]));

      if (type < 1 || type > known_count)
          return "null";

      return dtype_names[type - 1];
  }
  // Tell whether `t` is a non-empty string made only of the ASCII digits
  // '0'..'9' (so "0" and "007" qualify; "-1", "1a" and "" do not).
  //
  // The empty string is rejected explicitly: callers pass accepted strings
  // to std::stoi, which throws std::invalid_argument on empty input.
  static bool string_is_positive_integer(const std::string& t)
  {
      if (t.empty())
          return false;

      for (size_t i = 0; i < t.size(); i++)
      {
          if (t[i] < '0' || t[i] > '9')
              return false;
      }

      return true;
  }
  // Convert an fp32 value to the bit pattern of an fp16 value.
  //
  // fp32 layout is sign:exponent:significand = 1:8:23, fp16 is 1:5:10.
  // The conversion truncates the significand (no rounding) and never
  // produces fp16 subnormals:
  //   - fp32 zero / subnormal            -> signed zero
  //   - fp32 infinity                    -> signed infinity
  //   - fp32 NaN                         -> signed NaN with significand 0x200
  //   - rebiased exponent >= 31          -> signed infinity (overflow)
  //   - rebiased exponent <= 0           -> signed zero (underflow)
  //   - otherwise                        -> normal fp16, low 13 significand bits dropped
  static unsigned short float32_to_float16(float value)
  {
      // view the float's storage as a 32-bit integer
      union
      {
          unsigned int u;
          float f;
      } bits;
      bits.f = value;

      const unsigned int sign_bit = (bits.u >> 31) & 0x1;
      const unsigned int exp32 = (bits.u >> 23) & 0xFF;
      const unsigned int frac32 = bits.u & 0x7FFFFF;

      // pre-built fp16 patterns sharing the input's sign
      const unsigned short signed_zero = (unsigned short)(sign_bit << 15);
      const unsigned short signed_inf = (unsigned short)(signed_zero | (0x1F << 10));

      if (exp32 == 0)
      {
          // zero or denormal, always underflow
          return signed_zero;
      }

      if (exp32 == 0xFF)
      {
          // infinity keeps an empty significand, NaN gets a fixed quiet pattern
          return frac32 ? (unsigned short)(signed_inf | 0x200) : signed_inf;
      }

      // rebias the exponent from fp32 (127) to fp16 (15)
      const int exp16 = (int)exp32 - 127 + 15;

      if (exp16 >= 31)
      {
          // too large for fp16, saturate to infinity
          return signed_inf;
      }

      if (exp16 <= 0)
      {
          // too small for a normal fp16, flush to zero
          return signed_zero;
      }

      // normal fp16: keep the top 10 significand bits
      return (unsigned short)(signed_zero | (exp16 << 10) | (frac32 >> 13));
  }
  // Round `sz` up to the next multiple of `n`.
  //
  // The result is only a true multiple when `n` is a power of two, because
  // the rounding is done by masking off the low bits.
  static size_t alignSize(size_t sz, int n)
  {
      const size_t alignment = (size_t)n;
      const size_t low_bits_cleared = ~(alignment - 1);

      return (sz + alignment - 1) & low_bits_cleared;
  }
  109. int save_ncnn(const Graph& g, const std::string& parampath, const std::string& binpath, const std::string& pypath, int fp16)
  110. {
  111. FILE* paramfp = fopen(parampath.c_str(), "wb");
  112. if (!paramfp)
  113. {
  114. fprintf(stderr, "fopen %s failed\n", parampath.c_str());
  115. return -1;
  116. }
  117. FILE* binfp = fopen(binpath.c_str(), "wb");
  118. if (!binfp)
  119. {
  120. fprintf(stderr, "fopen %s failed\n", binpath.c_str());
  121. fclose(paramfp);
  122. return -1;
  123. }
  124. // magic
  125. fprintf(paramfp, "7767517\n");
  126. // op count and oprand count
  127. fprintf(paramfp, "%d %d\n", (int)g.ops.size(), (int)g.operands.size());
  128. for (const Operator* op : g.ops)
  129. {
  130. fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size());
  131. for (const Operand* oprand : op->inputs)
  132. {
  133. fprintf(paramfp, " %s", oprand->name.c_str());
  134. }
  135. for (const Operand* oprand : op->outputs)
  136. {
  137. fprintf(paramfp, " %s", oprand->name.c_str());
  138. }
  139. for (const auto& it : op->params)
  140. {
  141. const Parameter& param = it.second;
  142. if (!string_is_positive_integer(it.first))
  143. {
  144. fprintf(stderr, "ignore %s %s param %s=", op->type.c_str(), op->name.c_str(), it.first.c_str());
  145. if (param.type == 0)
  146. {
  147. fprintf(stderr, "None");
  148. }
  149. if (param.type == 1)
  150. {
  151. if (param.b)
  152. fprintf(stderr, "True");
  153. else
  154. fprintf(stderr, "False");
  155. }
  156. if (param.type == 2)
  157. {
  158. fprintf(stderr, "%d", param.i);
  159. }
  160. if (param.type == 3)
  161. {
  162. fprintf(stderr, "%e", param.f);
  163. }
  164. if (param.type == 4)
  165. {
  166. fprintf(stderr, "%s", param.s.c_str());
  167. }
  168. if (param.type == 5)
  169. {
  170. fprintf(stderr, "(");
  171. for (size_t i = 0; i < param.ai.size(); i++)
  172. {
  173. fprintf(stderr, "%d", param.ai[i]);
  174. if (i + 1 != param.ai.size())
  175. fprintf(stderr, ",");
  176. }
  177. fprintf(stderr, ")");
  178. }
  179. if (param.type == 6)
  180. {
  181. fprintf(stderr, "(");
  182. for (size_t i = 0; i < param.af.size(); i++)
  183. {
  184. fprintf(stderr, "%e", param.af[i]);
  185. if (i + 1 != param.af.size())
  186. fprintf(stderr, ",");
  187. }
  188. fprintf(stderr, ")");
  189. }
  190. if (param.type == 7)
  191. {
  192. fprintf(stderr, "(");
  193. for (size_t i = 0; i < param.as.size(); i++)
  194. {
  195. fprintf(stderr, "%s", param.as[i].c_str());
  196. if (i + 1 != param.as.size())
  197. fprintf(stderr, ",");
  198. }
  199. fprintf(stderr, ")");
  200. }
  201. fprintf(stderr, "\n");
  202. continue;
  203. }
  204. const int idkey = std::stoi(it.first);
  205. if (param.type == 2)
  206. {
  207. fprintf(paramfp, " %d=%d", idkey, param.i);
  208. }
  209. if (param.type == 3)
  210. {
  211. fprintf(paramfp, " %d=%e", idkey, param.f);
  212. }
  213. if (param.type == 5)
  214. {
  215. const int array_size = (int)param.ai.size();
  216. fprintf(paramfp, " %d=%d", -23300 - idkey, array_size);
  217. for (size_t i = 0; i < param.ai.size(); i++)
  218. {
  219. fprintf(paramfp, ",%d", param.ai[i]);
  220. }
  221. }
  222. if (param.type == 6)
  223. {
  224. const int array_size = (int)param.af.size();
  225. fprintf(paramfp, " %d=%d", -23300 - idkey, array_size);
  226. for (size_t i = 0; i < param.af.size(); i++)
  227. {
  228. fprintf(paramfp, ",%e", param.af[i]);
  229. }
  230. }
  231. }
  232. bool is_type_flag_fp32 = false;
  233. for (const auto& it : op->attrs)
  234. {
  235. // fprintf(paramfp, " @%s=", it.first.c_str());
  236. const Attribute& attr = it.second;
  237. if (fp16 && is_type_flag_fp32)
  238. {
  239. // fp32 -> fp16
  240. const float* p = (const float*)attr.data.data();
  241. int len = attr.data.size() / 4;
  242. std::vector<char> data_fp16(alignSize(len * 2, 4));
  243. unsigned short* p_fp16 = (unsigned short*)data_fp16.data();
  244. for (int i = 0; i < len; i++)
  245. {
  246. p_fp16[i] = float32_to_float16(p[i]);
  247. }
  248. // pad size to 4bytes
  249. if (len % 2 == 1)
  250. {
  251. // pad with fixed value for model hash consistency
  252. p_fp16[len] = 0x2283;
  253. }
  254. fwrite(data_fp16.data(), data_fp16.size(), 1, binfp);
  255. is_type_flag_fp32 = false;
  256. continue;
  257. }
  258. if (fp16 && attr.type == 0 && attr.data == std::vector<char> {0, 0, 0, 0})
  259. {
  260. // write fp16 flag
  261. unsigned int fp16_flag = 0x01306B47;
  262. fwrite((const char*)&fp16_flag, sizeof(fp16_flag), 1, binfp);
  263. is_type_flag_fp32 = true;
  264. continue;
  265. }
  266. fwrite(attr.data.data(), attr.data.size(), 1, binfp);
  267. }
  268. // if (op->inputnames.size() == op->inputs.size())
  269. // {
  270. // for (size_t i = 0; i < op->inputs.size(); i++)
  271. // {
  272. // const Operand* oprand = op->inputs[i];
  273. // fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str());
  274. // }
  275. // }
  276. // for (const Operand* oprand : op->outputs)
  277. // {
  278. // if (oprand->params.find("__batch_index") == oprand->params.end())
  279. // continue;
  280. //
  281. // const int batch_index = oprand->params.at("__batch_index").i;
  282. //
  283. // fprintf(paramfp, " #%s=%d", oprand->name.c_str(), batch_index);
  284. // }
  285. // for (const Operand* oprand : op->outputs)
  286. // {
  287. // if (oprand->shape.empty())
  288. // continue;
  289. //
  290. // fprintf(paramfp, " #%s=", oprand->name.c_str());
  291. //
  292. // fprintf(paramfp, "(");
  293. // for (int64_t i = 0; i < oprand->shape.size() - 1; i++)
  294. // {
  295. // fprintf(paramfp, "%d,", oprand->shape[i]);
  296. // }
  297. // if (oprand->shape.size() > 0)
  298. // fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]);
  299. // fprintf(paramfp, ")");
  300. //
  301. // fprintf(paramfp, type_to_string(oprand->type));
  302. // }
  303. fprintf(paramfp, "\n");
  304. }
  305. fclose(paramfp);
  306. fclose(binfp);
  307. FILE* pyfp = fopen(pypath.c_str(), "wb");
  308. if (!pyfp)
  309. {
  310. fprintf(stderr, "fopen %s failed\n", pypath.c_str());
  311. return -1;
  312. }
  313. fprintf(pyfp, "import numpy as np\n");
  314. fprintf(pyfp, "import ncnn\n");
  315. fprintf(pyfp, "import torch\n");
  316. fprintf(pyfp, "\n");
  317. // test inference
  318. {
  319. fprintf(pyfp, "def test_inference():\n");
  320. fprintf(pyfp, " torch.manual_seed(0)\n");
  321. for (int input_index = 0;; input_index++)
  322. {
  323. std::string input_name = std::string("in") + std::to_string(input_index);
  324. const Operand* r = g.get_operand(input_name);
  325. if (!r)
  326. break;
  327. if (type_is_integer(r->type))
  328. {
  329. fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str());
  330. for (size_t i = 0; i < r->shape.size(); i++)
  331. {
  332. fprintf(pyfp, "%d", r->shape[i]);
  333. if (i + 1 != r->shape.size() || r->shape.size() == 1)
  334. fprintf(pyfp, ", ");
  335. }
  336. fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
  337. }
  338. else
  339. {
  340. fprintf(pyfp, " %s = torch.rand(", input_name.c_str());
  341. for (size_t i = 0; i < r->shape.size(); i++)
  342. {
  343. fprintf(pyfp, "%d, ", r->shape[i]);
  344. }
  345. fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
  346. }
  347. }
  348. fprintf(pyfp, " out = []\n");
  349. fprintf(pyfp, "\n");
  350. fprintf(pyfp, " with ncnn.Net() as net:\n");
  351. fprintf(pyfp, " net.load_param(\"%s\")\n", parampath.c_str());
  352. fprintf(pyfp, " net.load_model(\"%s\")\n", binpath.c_str());
  353. fprintf(pyfp, "\n");
  354. fprintf(pyfp, " with net.create_extractor() as ex:\n");
  355. for (int input_index = 0;; input_index++)
  356. {
  357. std::string input_name = std::string("in") + std::to_string(input_index);
  358. const Operand* r = g.get_operand(input_name);
  359. if (!r)
  360. break;
  361. const int batch_index = r->params.at("__batch_index").i;
  362. if (batch_index != 233)
  363. {
  364. fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.squeeze(%d).numpy()).clone())\n", input_name.c_str(), input_name.c_str(), batch_index);
  365. }
  366. else
  367. {
  368. fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.numpy()).clone())\n", input_name.c_str(), input_name.c_str());
  369. }
  370. }
  371. fprintf(pyfp, "\n");
  372. for (int output_index = 0;; output_index++)
  373. {
  374. std::string output_name = std::string("out") + std::to_string(output_index);
  375. const Operand* r = g.get_operand(output_name);
  376. if (!r)
  377. break;
  378. fprintf(pyfp, " _, %s = ex.extract(\"%s\")\n", output_name.c_str(), output_name.c_str());
  379. const int batch_index = r->params.at("__batch_index").i;
  380. if (batch_index != 233)
  381. {
  382. fprintf(pyfp, " out.append(torch.from_numpy(np.array(%s)).unsqueeze(%d))\n", output_name.c_str(), batch_index);
  383. }
  384. else
  385. {
  386. fprintf(pyfp, " out.append(torch.from_numpy(np.array(%s)))\n", output_name.c_str());
  387. }
  388. }
  389. fprintf(pyfp, "\n");
  390. fprintf(pyfp, " if len(out) == 1:\n");
  391. fprintf(pyfp, " return out[0]\n");
  392. fprintf(pyfp, " else:\n");
  393. fprintf(pyfp, " return tuple(out)\n");
  394. }
  395. fclose(pyfp);
  396. return 0;
  397. }
  398. } // namespace pnnx