You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

save_onnx.cpp 9.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "save_onnx.h"
  15. #include "onnx.pb.h"
  16. #include <string.h>
  17. #include <fstream>
  18. #include <iostream>
  19. namespace pnnx {
  20. // from cxxabi bridge
  // The functions below are only declared in this file; their definitions are not
  // visible here. Strings cross this boundary as const char* (and vectors of
  // const char*) instead of std::string -- presumably to stay independent of the
  // std::string ABI of the other side; TODO confirm against the bridge implementation.
  //
  // Identity accessors: save_onnx() uses these for ONNX value / node names and op types.
  21. extern const char* get_operand_name(const Operand* x);
  22. extern const char* get_operator_type(const Operator* op);
  23. extern const char* get_operator_name(const Operator* op);
  // Key enumeration for an operator's params and attrs; each returned key is then
  // passed to get_operator_param() / get_operator_attr() respectively.
  // NOTE(review): lifetime of the returned C strings is not visible here -- presumably
  // they point into storage owned by the Operator; verify before caching them.
  24. extern std::vector<const char*> get_operator_params_keys(const Operator* op);
  25. extern std::vector<const char*> get_operator_attrs_keys(const Operator* op);
  // Lookup of a single param / attr by key, returned by const reference.
  26. extern const Parameter& get_operator_param(const Operator* op, const char* key);
  27. extern const Attribute& get_operator_attr(const Operator* op, const char* key);
  // String payload accessors for a Parameter; save_onnx() calls get_param_s() when
  // param.type == 4 and get_param_as() when param.type == 7.
  28. extern const char* get_param_s(const Parameter& p);
  29. extern std::vector<const char*> get_param_as(const Parameter& p);
  // Convert an IEEE 754 binary32 value into its binary16 bit pattern.
  //
  // value   the fp32 number to convert
  // return  16-bit pattern laid out as 1 sign : 5 exponent : 10 significand
  //
  // Behaviour:
  //  - rounding is round-to-nearest, ties-to-even (a carry out of the significand
  //    correctly bumps the exponent, and the largest finite value rounds up to infinity)
  //  - magnitudes below the smallest fp16 normal become fp16 subnormals, and only
  //    values that round to nothing become signed zero
  //  - magnitudes too large for fp16 become signed infinity
  //  - NaN maps to a quiet NaN (significand 0x200); the fp32 payload is not preserved
  static unsigned short float32_to_float16(float value)
  {
      // memcpy is the well-defined way to reinterpret the bits; reading the inactive
      // member of a union is undefined behaviour in C++.
      unsigned int bits;
      memcpy(&bits, &value, sizeof(bits));

      // fp32 layout is 1 : 8 : 23
      const unsigned int sign = (bits >> 16) & 0x8000; // already in fp16 position
      const unsigned int exponent = (bits >> 23) & 0xFF;
      const unsigned int significand = bits & 0x7FFFFF;

      if (exponent == 0xFF)
      {
          // infinity or NaN
          return (unsigned short)(sign | 0x7C00 | (significand ? 0x200 : 0x00));
      }

      // re-bias the exponent from fp32 (127) to fp16 (15)
      const int newexp = (int)exponent - 127 + 15;

      if (newexp >= 31)
      {
          // overflow, return infinity
          return (unsigned short)(sign | 0x7C00);
      }

      if (newexp <= 0)
      {
          // below the fp16 normal range: subnormal result or zero
          if (newexp < -10)
          {
              // magnitude < 2^-25 (this also covers fp32 zero and fp32 denormals),
              // which is less than half of the smallest fp16 subnormal
              return (unsigned short)sign;
          }

          // restore the implicit leading one, then shift into the subnormal position
          const unsigned int mantissa = significand | 0x800000;
          const int shift = 14 - newexp; // 14 .. 24
          unsigned int frac = mantissa >> shift;
          const unsigned int rem = mantissa & ((1u << shift) - 1);
          const unsigned int half = 1u << (shift - 1);
          if (rem > half || (rem == half && (frac & 1)))
          {
              // may carry into 0x400, which is exactly the smallest normal encoding
              frac++;
          }
          return (unsigned short)(sign | frac);
      }

      // normal fp16
      unsigned int h = ((unsigned int)newexp << 10) | (significand >> 13);
      const unsigned int rem = significand & 0x1FFF;
      if (rem > 0x1000 || (rem == 0x1000 && (h & 1)))
      {
          // a carry may ripple into the exponent; 0x7BFF + 1 == 0x7C00 is infinity
          h++;
      }
      return (unsigned short)(sign | h);
  }
  78. int save_onnx(const Graph& g, const char* onnxpath, int fp16)
  79. {
  80. onnx::ModelProto model;
  81. onnx::GraphProto* gp = model.mutable_graph();
  82. for (const Operand* x : g.operands)
  83. {
  84. onnx::ValueInfoProto* vip = gp->add_value_info();
  85. vip->set_name(get_operand_name(x));
  86. onnx::TypeProto* tp = vip->mutable_type();
  87. onnx::TypeProto_Tensor* tpt = tp->mutable_tensor_type();
  88. switch (x->type)
  89. {
  90. case 1: // f32
  91. tpt->set_elem_type(fp16 ? 10 : 1);
  92. break;
  93. case 2: // f64
  94. tpt->set_elem_type(fp16 ? 10 : 11);
  95. break;
  96. case 3: // f16
  97. tpt->set_elem_type(10);
  98. break;
  99. case 4: // i32
  100. tpt->set_elem_type(6);
  101. break;
  102. case 5: // i64
  103. tpt->set_elem_type(7);
  104. break;
  105. case 6: // i16
  106. tpt->set_elem_type(5);
  107. break;
  108. case 7: // i8
  109. tpt->set_elem_type(3);
  110. break;
  111. case 8: // u8
  112. tpt->set_elem_type(2);
  113. break;
  114. case 9: // bool
  115. tpt->set_elem_type(9);
  116. break;
  117. case 10: // cp64
  118. tpt->set_elem_type(14);
  119. break;
  120. case 11: // cp128
  121. tpt->set_elem_type(15);
  122. break;
  123. case 12: // cp32
  124. tpt->set_elem_type(0);
  125. break;
  126. default: // null
  127. tpt->set_elem_type(0);
  128. break;
  129. }
  130. onnx::TensorShapeProto* tsp = tpt->mutable_shape();
  131. for (auto s : x->shape)
  132. {
  133. onnx::TensorShapeProto_Dimension* tspd = tsp->add_dim();
  134. tspd->set_dim_value(s);
  135. }
  136. }
  137. for (const Operator* op : g.ops)
  138. {
  139. onnx::NodeProto* np = gp->add_node();
  140. np->set_op_type(get_operator_type(op));
  141. np->set_name(get_operator_name(op));
  142. for (const Operand* oprand : op->inputs)
  143. {
  144. np->add_input(get_operand_name(oprand));
  145. }
  146. for (const Operand* oprand : op->outputs)
  147. {
  148. np->add_output(get_operand_name(oprand));
  149. }
  150. std::vector<const char*> params_keys = get_operator_params_keys(op);
  151. for (const char* param_name : params_keys)
  152. {
  153. const Parameter& param = get_operator_param(op, param_name);
  154. onnx::AttributeProto* ap = np->add_attribute();
  155. ap->set_name(param_name);
  156. if (param.type == 0)
  157. {
  158. ap->set_s("None");
  159. }
  160. if (param.type == 1)
  161. {
  162. if (param.b)
  163. ap->set_i(1);
  164. else
  165. ap->set_i(0);
  166. }
  167. if (param.type == 2)
  168. {
  169. ap->set_i(param.i);
  170. }
  171. if (param.type == 3)
  172. {
  173. ap->set_f(param.f);
  174. }
  175. if (param.type == 4)
  176. {
  177. ap->set_s(get_param_s(param));
  178. }
  179. if (param.type == 5)
  180. {
  181. for (auto i : param.ai)
  182. {
  183. ap->add_ints(i);
  184. }
  185. }
  186. if (param.type == 6)
  187. {
  188. for (auto f : param.af)
  189. {
  190. ap->add_floats(f);
  191. }
  192. }
  193. if (param.type == 7)
  194. {
  195. std::vector<const char*> as = get_param_as(param);
  196. for (auto s : as)
  197. {
  198. ap->add_strings(s);
  199. }
  200. }
  201. }
  202. std::vector<const char*> attrs_keys = get_operator_attrs_keys(op);
  203. for (const char* attr_name : attrs_keys)
  204. {
  205. onnx::TensorProto* tp = gp->add_initializer();
  206. tp->set_name(std::string(get_operator_name(op)) + "." + attr_name);
  207. np->add_input(std::string(get_operator_name(op)) + "." + attr_name);
  208. const Attribute& attr = get_operator_attr(op, attr_name);
  209. for (auto s : attr.shape)
  210. {
  211. tp->add_dims(s);
  212. }
  213. switch (attr.type)
  214. {
  215. case 1: // f32
  216. tp->set_data_type(fp16 ? 10 : 1);
  217. break;
  218. case 2: // f64
  219. tp->set_data_type(fp16 ? 10 : 11);
  220. break;
  221. case 3: // f16
  222. tp->set_data_type(10);
  223. break;
  224. case 4: // i32
  225. tp->set_data_type(6);
  226. break;
  227. case 5: // i64
  228. tp->set_data_type(7);
  229. break;
  230. case 6: // i16
  231. tp->set_data_type(5);
  232. break;
  233. case 7: // i8
  234. tp->set_data_type(3);
  235. break;
  236. case 8: // u8
  237. tp->set_data_type(2);
  238. break;
  239. case 9: // bool
  240. tp->set_data_type(9);
  241. break;
  242. case 10: // cp64
  243. tp->set_data_type(14);
  244. break;
  245. case 11: // cp128
  246. tp->set_data_type(15);
  247. break;
  248. case 12: // cp32
  249. tp->set_data_type(0);
  250. break;
  251. default: // null
  252. tp->set_data_type(0);
  253. break;
  254. }
  255. std::string* d = tp->mutable_raw_data();
  256. if (fp16 && attr.type == 1)
  257. {
  258. // fp32 to fp16
  259. const float* p = (const float*)attr.data.data();
  260. int len = attr.data.size() / 4;
  261. d->resize(len * 2);
  262. unsigned short* p_fp16 = (unsigned short*)d->data();
  263. for (int i = 0; i < len; i++)
  264. {
  265. p_fp16[i] = float32_to_float16(p[i]);
  266. }
  267. }
  268. else if (fp16 && attr.type == 2)
  269. {
  270. // fp64 to fp16
  271. const double* p = (const double*)attr.data.data();
  272. int len = attr.data.size() / 4;
  273. d->resize(len);
  274. unsigned short* p_fp16 = (unsigned short*)d->data();
  275. for (int i = 0; i < len; i++)
  276. {
  277. p_fp16[i] = float32_to_float16((float)p[i]);
  278. }
  279. }
  280. else
  281. {
  282. d->resize(attr.data.size());
  283. memcpy((void*)d->data(), attr.data.data(), attr.data.size());
  284. }
  285. }
  286. }
  287. std::fstream output(onnxpath, std::ios::out | std::ios::trunc | std::ios::binary);
  288. if (!model.SerializeToOstream(&output))
  289. {
  290. fprintf(stderr, "write onnx failed\n");
  291. return -1;
  292. }
  293. return 0;
  294. }
  295. } // namespace pnnx