You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

onnx2ncnn.cpp 52 kB

8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include <float.h>
  15. #include <stdio.h>
  16. #include <limits.h>
  17. #include <iostream>
  18. #include <fstream>
  19. #include <set>
  20. #include <limits>
  21. #include <algorithm>
  22. #include <google/protobuf/io/coded_stream.h>
  23. #include <google/protobuf/io/zero_copy_stream_impl.h>
  24. #include <google/protobuf/text_format.h>
  25. #include <google/protobuf/message.h>
  26. #include "onnx.pb.h"
  27. static bool read_proto_from_binary(const char* filepath, google::protobuf::Message* message)
  28. {
  29. std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
  30. if (!fs.is_open())
  31. {
  32. fprintf(stderr, "open failed %s\n", filepath);
  33. return false;
  34. }
  35. google::protobuf::io::IstreamInputStream input(&fs);
  36. google::protobuf::io::CodedInputStream codedstr(&input);
  37. codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
  38. bool success = message->ParseFromCodedStream(&codedstr);
  39. fs.close();
  40. return success;
  41. }
  42. static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char* key)
  43. {
  44. std::vector<int> v;
  45. for (int i=0; i<node.attribute_size(); i++)
  46. {
  47. const onnx::AttributeProto& attr = node.attribute(i);
  48. if (attr.name() == key)
  49. {
  50. v.resize(attr.ints_size());
  51. for (int j=0; j<attr.ints_size(); j++)
  52. {
  53. v[j] = attr.ints(j);
  54. }
  55. break;
  56. }
  57. }
  58. return v;
  59. }
  60. static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key)
  61. {
  62. std::vector<float> v;
  63. for (int i=0; i<node.attribute_size(); i++)
  64. {
  65. const onnx::AttributeProto& attr = node.attribute(i);
  66. if (attr.name() == key)
  67. {
  68. v.resize(attr.floats_size());
  69. for (int j=0; j<attr.floats_size(); j++)
  70. {
  71. v[j] = attr.floats(j);
  72. }
  73. break;
  74. }
  75. }
  76. return v;
  77. }
  78. static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0)
  79. {
  80. for (int i=0; i<node.attribute_size(); i++)
  81. {
  82. const onnx::AttributeProto& attr = node.attribute(i);
  83. if (attr.name() == key)
  84. {
  85. return attr.i();
  86. }
  87. }
  88. return def;
  89. }
  90. static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f)
  91. {
  92. for (int i=0; i<node.attribute_size(); i++)
  93. {
  94. const onnx::AttributeProto& attr = node.attribute(i);
  95. if (attr.name() == key)
  96. {
  97. return attr.f();
  98. }
  99. }
  100. return def;
  101. }
  102. static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key, const std::string& def = std::string())
  103. {
  104. for (int i=0; i<node.attribute_size(); i++)
  105. {
  106. const onnx::AttributeProto& attr = node.attribute(i);
  107. if (attr.name() == key)
  108. {
  109. return attr.s();
  110. }
  111. }
  112. return def;
  113. }
  114. static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key)
  115. {
  116. for (int i=0; i<node.attribute_size(); i++)
  117. {
  118. const onnx::AttributeProto& attr = node.attribute(i);
  119. if (attr.name() == key)
  120. {
  121. return attr.t();
  122. }
  123. }
  124. return onnx::TensorProto();
  125. }
  126. static int get_tensor_proto_data_size(const onnx::TensorProto& tp)
  127. {
  128. if (tp.has_raw_data())
  129. {
  130. const std::string& raw_data = tp.raw_data();
  131. int size = (int)raw_data.size() / 4;
  132. return size;
  133. }
  134. else if (tp.data_type() == 1)
  135. {
  136. return tp.float_data_size();
  137. }
  138. return 0;
  139. }
  140. static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp)
  141. {
  142. int size = get_tensor_proto_data_size(tp);
  143. if (tp.has_raw_data())
  144. {
  145. const std::string& raw_data = tp.raw_data();
  146. fwrite(raw_data.data(), sizeof(float), size, bp);
  147. }
  148. else if (tp.data_type() == 1)
  149. {
  150. fwrite(tp.float_data().data(), sizeof(float), size, bp);
  151. }
  152. }
  153. int main(int argc, char** argv)
  154. {
  155. const char* onnxpb = argv[1];
  156. const char* ncnn_prototxt = argc >= 4 ? argv[2] : "ncnn.param";
  157. const char* ncnn_modelbin = argc >= 4 ? argv[3] : "ncnn.bin";
  158. onnx::ModelProto model;
  159. // load
  160. bool s1 = read_proto_from_binary(onnxpb, &model);
  161. if (!s1)
  162. {
  163. fprintf(stderr, "read_proto_from_binary failed\n");
  164. return -1;
  165. }
  166. FILE* pp = fopen(ncnn_prototxt, "wb");
  167. FILE* bp = fopen(ncnn_modelbin, "wb");
  168. // magic
  169. fprintf(pp, "7767517\n");
  170. const onnx::GraphProto& graph = model.graph();
  171. onnx::GraphProto* mutable_graph = model.mutable_graph();
  172. int node_count = graph.node_size();
  173. // node reference
  174. std::map<std::string, int> node_reference;
  175. // weight node and weight reshape node
  176. std::map<std::string, onnx::TensorProto> weights;
  177. // weight node before BinaryOp
  178. std::map<std::string, onnx::TensorProto> binaryop_weights;
  179. for (int j=0; j<graph.initializer_size(); j++)
  180. {
  181. const onnx::TensorProto& initializer = graph.initializer(j);
  182. // fprintf(stderr, "weight = %s\n", initializer.name().c_str());
  183. weights[initializer.name()] = initializer;
  184. }
  185. // global definition line
  186. // [layer count] [blob count]
  187. std::set<std::string> blob_names;
  188. for (int i=0; i<node_count; i++)
  189. {
  190. const onnx::NodeProto& node = graph.node(i);
  191. const std::string& op = node.op_type();
  192. std::string name = node.name();
  193. if (name.empty())
  194. {
  195. name = node.output(0);
  196. }
  197. if (op == "Constant")
  198. {
  199. onnx::TensorProto tensor = get_node_attr_tensor(node, "value");
  200. weights[node.output(0)] = tensor;
  201. continue;
  202. }
  203. else if (op == "Reshape")
  204. {
  205. if (node.input_size() == 1)
  206. {
  207. const std::string& input_name = node.input(0);
  208. // check weight
  209. if (weights.find(input_name) != weights.end())
  210. {
  211. weights[node.output(0)] = weights[input_name];
  212. continue;
  213. }
  214. }
  215. else if (node.input_size() == 2)
  216. {
  217. // opset 5
  218. const std::string& input_name = node.input(0);
  219. // check weight
  220. if (weights.find(input_name) != weights.end())
  221. {
  222. weights[node.output(0)] = weights[input_name];
  223. // set weight shape directly
  224. const onnx::TensorProto& shape_tp = weights[node.input(1)];
  225. const int64_t* shape_data = shape_tp.int64_data().data();
  226. weights[node.output(0)].clear_dims();
  227. for (int j=0; j<shape_tp.int64_data_size(); j++)
  228. {
  229. weights[node.output(0)].add_dims(shape_data[j]);
  230. }
  231. continue;
  232. }
  233. }
  234. }
  235. else
  236. {
  237. bool isBinaryOp = false;
  238. if (op == "Add" || op == "Mul")
  239. {
  240. isBinaryOp = true;
  241. }
  242. if (isBinaryOp)
  243. {
  244. // check weights
  245. for (int j=0; j<node.input_size(); j++)
  246. {
  247. const std::string& input_name = node.input(j);
  248. std::map<std::string, onnx::TensorProto>::iterator it = weights.find(input_name);
  249. if (it != weights.end())
  250. {
  251. // binary op with weight, insert MemoryData layer and const blob
  252. binaryop_weights[input_name] = it->second;
  253. weights.erase(it);
  254. }
  255. }
  256. }
  257. }
  258. for (int j=0; j<(int)node.input_size(); j++)
  259. {
  260. const std::string& input_name = node.input(j);
  261. // check weight
  262. if (weights.find(input_name) != weights.end())
  263. {
  264. continue;
  265. }
  266. blob_names.insert(input_name);
  267. if (node_reference.find(input_name) == node_reference.end())
  268. {
  269. node_reference[input_name] = 1;
  270. }
  271. else
  272. {
  273. node_reference[input_name] = node_reference[input_name] + 1;
  274. }
  275. }
  276. if (op == "Dropout")
  277. {
  278. const std::string& output_name = node.output(0);
  279. blob_names.insert(output_name);
  280. continue;
  281. }
  282. for (int j=0; j<(int)node.output_size(); j++)
  283. {
  284. const std::string& output_name = node.output(j);
  285. blob_names.insert(output_name);
  286. }
  287. }
  288. // include Input node
  289. int input_node_count = 0;
  290. for (int j=0; j<graph.input_size(); j++)
  291. {
  292. const std::string& input_name = graph.input(j).name();
  293. // check weight
  294. if (weights.find(input_name) != weights.end())
  295. continue;
  296. // check weight before BinaryOp
  297. if (binaryop_weights.find(input_name) != binaryop_weights.end())
  298. continue;
  299. blob_names.insert(input_name);
  300. input_node_count++;
  301. }
  302. // op chain fusion
  303. int reduced_node_count = 0;
  304. for (int i=0; i<node_count; i++)
  305. {
  306. onnx::NodeProto* node = mutable_graph->mutable_node(i);
  307. // MatMul <= Transpose(weight) - MatMul
  308. if (node->op_type() == "Transpose")
  309. {
  310. // check weight
  311. if (weights.find(node->input(0)) == weights.end())
  312. continue;
  313. onnx::TensorProto& B = weights[node->input(0)];
  314. if (B.dims_size() != 2)
  315. continue;
  316. if (node_reference[node->output(0)] != 1)
  317. continue;
  318. // perm = (1, 0)
  319. std::vector<int> perm = get_node_attr_ai(*node, "perm");
  320. if (perm.size() != 2)
  321. continue;
  322. if (perm[0] != 1 || perm[1] != 0)
  323. continue;
  324. if (i+1 >= node_count)
  325. continue;
  326. onnx::NodeProto* node2 = mutable_graph->mutable_node(i+1);
  327. if (node2->op_type() != "MatMul")
  328. continue;
  329. // reduce
  330. node->set_op_type("noop_reducedncnn");
  331. node_reference.erase(node_reference.find(node->output(0)));
  332. blob_names.erase(node->output(0));
  333. node2->set_input(1, node->input(0));
  334. // permute weight
  335. {
  336. const int h = B.dims(0);
  337. const int w = B.dims(1);
  338. std::vector<float> permuted_data;
  339. permuted_data.reserve(h * w);
  340. const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
  341. for (int j=0; j<w; j++)
  342. {
  343. for (int k=0; k<h; k++)
  344. {
  345. float vb = bptr[ k*w + j ];
  346. permuted_data.push_back(vb);
  347. }
  348. }
  349. B.set_dims(0, w);
  350. B.set_dims(1, h);
  351. if (B.has_raw_data())
  352. {
  353. B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float));
  354. }
  355. else
  356. {
  357. for (int j=0; j<(int)permuted_data.size(); j++)
  358. B.set_float_data(j, permuted_data[j]);
  359. }
  360. }
  361. reduced_node_count += 1;
  362. i += 1;
  363. }
  364. }
  365. // remove node_reference entry with reference equals to one
  366. int splitncnn_blob_count = 0;
  367. std::map<std::string, int>::iterator it = node_reference.begin();
  368. while (it != node_reference.end())
  369. {
  370. if (it->second == 1)
  371. {
  372. node_reference.erase(it++);
  373. }
  374. else
  375. {
  376. splitncnn_blob_count += it->second;
  377. // fprintf(stderr, "%s %d\n", it->first.c_str(), it->second);
  378. ++it;
  379. }
  380. }
  381. fprintf(pp, "%lu %lu\n", node_count - reduced_node_count + input_node_count + node_reference.size() + graph.initializer_size() - weights.size(), blob_names.size() + splitncnn_blob_count);
  382. int internal_split = 0;
  383. // place Input at the beginning
  384. for (int j=0; j<graph.input_size(); j++)
  385. {
  386. const std::string& input_name = graph.input(j).name();
  387. // check weight
  388. if (weights.find(input_name) != weights.end())
  389. continue;
  390. // check weight before BinaryOp
  391. if (binaryop_weights.find(input_name) != binaryop_weights.end())
  392. continue;
  393. fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str());
  394. // split the input
  395. if (node_reference.find(input_name) == node_reference.end()){
  396. continue;
  397. }
  398. int refcount = node_reference[input_name];
  399. if (refcount <= 1){
  400. continue;
  401. }
  402. char splitname[256];
  403. sprintf(splitname, "splitncnn_input%d", j);
  404. fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
  405. fprintf(pp, " %s", input_name.c_str());
  406. for (int k=0; k<refcount; k++){
  407. fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
  408. }
  409. fprintf(pp, "\n");
  410. }
  411. // place MemoryData next
  412. for (int j=0; j<graph.input_size(); j++)
  413. {
  414. const std::string& input_name = graph.input(j).name();
  415. // check weight before BinaryOp
  416. if (binaryop_weights.find(input_name) == binaryop_weights.end())
  417. continue;
  418. fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str());
  419. const onnx::TensorProto& M = binaryop_weights[input_name];
  420. if (M.dims_size() == 1) {
  421. fprintf(pp, " 0=%d", (int)M.dims(0));
  422. } else if (M.dims_size() == 2) {
  423. fprintf(pp, " 0=%d", (int)M.dims(1));
  424. fprintf(pp, " 1=%d", (int)M.dims(0));
  425. } else if (M.dims_size() == 3) {
  426. fprintf(pp, " 0=%d", (int)M.dims(2));
  427. fprintf(pp, " 1=%d", (int)M.dims(1));
  428. fprintf(pp, " 2=%d", (int)M.dims(0));
  429. }
  430. fprintf(pp, "\n");
  431. fwrite_tensor_proto_data(M, bp);
  432. }
  433. for (int i=0; i<node_count; i++)
  434. {
  435. const onnx::NodeProto& node = graph.node(i);
  436. const std::string& op = node.op_type();
  437. // fprintf(stderr, "op = %s\n", op.c_str());
  438. if (op == "noop_reducedncnn")
  439. {
  440. continue;
  441. }
  442. std::string name = node.name();
  443. if (name.empty())
  444. {
  445. name = node.output(0);
  446. }
  447. int input_size = node.input_size();
  448. int output_size = node.output_size();
  449. for (int j=0; j<(int)node.input_size(); j++)
  450. {
  451. const std::string& input_name = node.input(j);
  452. // check weight
  453. if (weights.find(input_name) != weights.end())
  454. {
  455. input_size--;
  456. }
  457. // fprintf(stderr, " input = %s\n", input_name.c_str());
  458. }
  459. for (int j=0; j<(int)node.output_size(); j++)
  460. {
  461. const std::string& output_name = node.output(j);
  462. // fprintf(stderr, " output = %s\n", output_name.c_str());
  463. }
  464. if (op == "Abs")
  465. {
  466. fprintf(pp, "%-16s", "UnaryOp");
  467. }
  468. else if (op == "Acos")
  469. {
  470. fprintf(pp, "%-16s", "UnaryOp");
  471. }
  472. else if (op == "Add")
  473. {
  474. fprintf(pp, "%-16s", "BinaryOp");
  475. }
  476. else if (op == "Asin")
  477. {
  478. fprintf(pp, "%-16s", "UnaryOp");
  479. }
  480. else if (op == "Atan")
  481. {
  482. fprintf(pp, "%-16s", "UnaryOp");
  483. }
  484. else if (op == "AveragePool" || op == "MaxPool")
  485. {
  486. fprintf(pp, "%-16s", "Pooling");
  487. }
  488. else if (op == "BatchNormalization")
  489. {
  490. fprintf(pp, "%-16s", "BatchNorm");
  491. }
  492. else if (op == "Ceil")
  493. {
  494. fprintf(pp, "%-16s", "UnaryOp");
  495. }
  496. else if (op == "Clip")
  497. {
  498. fprintf(pp, "%-16s", "Clip");
  499. }
  500. else if (op == "Concat")
  501. {
  502. fprintf(pp, "%-16s", "Concat");
  503. }
  504. else if (op == "Constant")
  505. {
  506. // check weight before BinaryOp
  507. if (binaryop_weights.find(node.output(0)) != binaryop_weights.end())
  508. {
  509. fprintf(pp, "%-16s", "MemoryData");
  510. }
  511. else
  512. {
  513. continue;
  514. }
  515. }
  516. else if (op == "Conv")
  517. {
  518. int group = get_node_attr_i(node, "group", 1);
  519. if (group > 1) {
  520. fprintf(pp, "%-16s", "ConvolutionDepthWise");
  521. } else {
  522. fprintf(pp, "%-16s", "Convolution");
  523. }
  524. }
  525. else if (op == "ConvTranspose")
  526. {
  527. int group = get_node_attr_i(node, "group", 1);
  528. if (group > 1) {
  529. fprintf(pp, "%-16s", "DeconvolutionDepthWise");
  530. } else {
  531. fprintf(pp, "%-16s", "Deconvolution");
  532. }
  533. }
  534. else if (op == "Cos")
  535. {
  536. fprintf(pp, "%-16s", "UnaryOp");
  537. }
  538. else if (op == "Div")
  539. {
  540. fprintf(pp, "%-16s", "BinaryOp");
  541. }
  542. else if (op == "Dropout")
  543. {
  544. fprintf(pp, "%-16s", "Dropout");
  545. output_size = 1;
  546. }
  547. else if (op == "Elu")
  548. {
  549. fprintf(pp, "%-16s", "ELU");
  550. }
  551. else if (op == "Exp")
  552. {
  553. fprintf(pp, "%-16s", "UnaryOp");
  554. }
  555. else if (op == "Flatten")
  556. {
  557. fprintf(pp, "%-16s", "Flatten");
  558. }
  559. else if (op == "Floor")
  560. {
  561. fprintf(pp, "%-16s", "UnaryOp");
  562. }
  563. else if (op == "Gemm")
  564. {
  565. float alpha = get_node_attr_f(node, "alpha", 1.f);
  566. float beta = get_node_attr_f(node, "beta", 1.f);
  567. int transA = get_node_attr_i(node, "transA", 0);
  568. int transB = get_node_attr_i(node, "transB", 0);
  569. if (alpha == 1.f && beta == 1.f)
  570. {
  571. // InnerProduct-like A * B + C
  572. if (transA == 0 && transB == 1)
  573. {
  574. fprintf(pp, "%-16s", "InnerProduct");
  575. }
  576. }
  577. // TODO
  578. }
  579. else if (op == "GlobalAveragePool")
  580. {
  581. fprintf(pp, "%-16s", "Pooling");
  582. }
  583. else if (op == "GlobalMaxPool")
  584. {
  585. fprintf(pp, "%-16s", "Pooling");
  586. }
  587. else if (op == "ImageScaler")
  588. {
  589. fprintf(pp, "%-16s", "Scale");
  590. }
  591. else if (op == "InstanceNormalization")
  592. {
  593. fprintf(pp, "%-16s", "InstanceNorm");
  594. }
  595. else if (op == "LeakyRelu")
  596. {
  597. fprintf(pp, "%-16s", "ReLU");
  598. }
  599. else if (op == "Log")
  600. {
  601. fprintf(pp, "%-16s", "UnaryOp");
  602. }
  603. else if (op == "LRN")
  604. {
  605. fprintf(pp, "%-16s", "LRN");
  606. }
  607. else if (op == "MatMul")
  608. {
  609. fprintf(pp, "%-16s", "InnerProduct");
  610. }
  611. else if (op == "Max")
  612. {
  613. fprintf(pp, "%-16s", "BinaryOp");
  614. }
  615. else if (op == "Min")
  616. {
  617. fprintf(pp, "%-16s", "BinaryOp");
  618. }
  619. else if (op == "Mul")
  620. {
  621. fprintf(pp, "%-16s", "BinaryOp");
  622. }
  623. else if (op == "Neg")
  624. {
  625. fprintf(pp, "%-16s", "UnaryOp");
  626. }
  627. else if (op == "Pad")
  628. {
  629. fprintf(pp, "%-16s", "Padding");
  630. }
  631. else if (op == "Pow")
  632. {
  633. fprintf(pp, "%-16s", "BinaryOp");
  634. }
  635. else if (op == "PRelu")
  636. {
  637. fprintf(pp, "%-16s", "PReLU");
  638. }
  639. else if (op == "Reciprocal")
  640. {
  641. fprintf(pp, "%-16s", "UnaryOp");
  642. }
  643. else if (op == "Relu")
  644. {
  645. fprintf(pp, "%-16s", "ReLU");
  646. }
  647. else if (op == "Reshape")
  648. {
  649. if (node.input_size() == 1 || node.input_size() == 2)
  650. {
  651. const std::string& input_name = node.input(0);
  652. // skip weight reshape
  653. if (weights.find(input_name) != weights.end())
  654. {
  655. continue;
  656. }
  657. }
  658. fprintf(pp, "%-16s", "Reshape");
  659. }
  660. else if (op == "Sigmoid")
  661. {
  662. fprintf(pp, "%-16s", "Sigmoid");
  663. }
  664. else if (op == "Sin")
  665. {
  666. fprintf(pp, "%-16s", "UnaryOp");
  667. }
  668. else if (op == "Slice")
  669. {
  670. fprintf(pp, "%-16s", "Crop");
  671. }
  672. else if (op == "Softmax")
  673. {
  674. fprintf(pp, "%-16s", "Softmax");
  675. }
  676. else if (op == "Sqrt")
  677. {
  678. fprintf(pp, "%-16s", "UnaryOp");
  679. }
  680. else if (op == "Sub")
  681. {
  682. fprintf(pp, "%-16s", "BinaryOp");
  683. }
  684. else if (op == "Sum")
  685. {
  686. fprintf(pp, "%-16s", "Eltwise");
  687. }
  688. else if (op == "Tan")
  689. {
  690. fprintf(pp, "%-16s", "UnaryOp");
  691. }
  692. else if (op == "Transpose")
  693. {
  694. fprintf(pp, "%-16s", "Permute");
  695. }
  696. else if (op == "Upsample")
  697. {
  698. fprintf(pp, "%-16s", "Interp");
  699. }
  700. else
  701. {
  702. // TODO
  703. fprintf(stderr, "%s not supported yet!\n", op.c_str());
  704. fprintf(pp, "%-16s", op.c_str());
  705. }
  706. fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size);
  707. for (int j=0; j<node.input_size(); j++)
  708. {
  709. std::string input_name = node.input(j);
  710. // check weight
  711. if (weights.find(input_name) != weights.end())
  712. {
  713. continue;
  714. }
  715. if (node_reference.find(input_name) != node_reference.end())
  716. {
  717. int refidx = node_reference[input_name] - 1;
  718. node_reference[input_name] = refidx;
  719. char splitsuffix[256];
  720. sprintf(splitsuffix, "_splitncnn_%d", refidx);
  721. input_name = input_name + splitsuffix;
  722. }
  723. fprintf(pp, " %s", input_name.c_str());
  724. }
  725. for (int j=0; j<output_size; j++)
  726. {
  727. const std::string& output_name = node.output(j);
  728. fprintf(pp, " %s", output_name.c_str());
  729. }
  730. if (op == "Abs")
  731. {
  732. int op_type = 0;
  733. fprintf(pp, " 0=%d", op_type);
  734. }
  735. else if (op == "Acos")
  736. {
  737. int op_type = 13;
  738. fprintf(pp, " 0=%d", op_type);
  739. }
  740. else if (op == "Add")
  741. {
  742. int op_type = 0;
  743. fprintf(pp, " 0=%d", op_type);
  744. }
  745. else if (op == "Asin")
  746. {
  747. int op_type = 12;
  748. fprintf(pp, " 0=%d", op_type);
  749. }
  750. else if (op == "Atan")
  751. {
  752. int op_type = 14;
  753. fprintf(pp, " 0=%d", op_type);
  754. }
  755. else if (op == "AveragePool" || op == "MaxPool")
  756. {
  757. std::string auto_pad = get_node_attr_s(node, "auto_pad");//TODO
  758. std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
  759. std::vector<int> strides = get_node_attr_ai(node, "strides");
  760. std::vector<int> pads = get_node_attr_ai(node, "pads");
  761. int pool = op == "AveragePool" ? 1 : 0;
  762. int pad_mode = 1;
  763. if (auto_pad == "SAME_LOWER" || auto_pad == "SAME_UPPER")
  764. {
  765. // TODO
  766. pad_mode = 2;
  767. }
  768. fprintf(pp, " 0=%d", pool);
  769. if (kernel_shape.size() == 1) {
  770. fprintf(pp, " 1=%d", kernel_shape[0]);
  771. } else if (kernel_shape.size() == 2) {
  772. fprintf(pp, " 1=%d", kernel_shape[1]);
  773. fprintf(pp, " 11=%d", kernel_shape[0]);
  774. }
  775. if (strides.size() == 1) {
  776. fprintf(pp, " 2=%d", strides[0]);
  777. } else if (strides.size() == 2) {
  778. fprintf(pp, " 2=%d", strides[1]);
  779. fprintf(pp, " 12=%d", strides[0]);
  780. }
  781. if (pads.size() == 1) {
  782. fprintf(pp, " 3=%d", pads[0]);
  783. } else if (pads.size() == 2) {
  784. fprintf(pp, " 3=%d", pads[1]);
  785. fprintf(pp, " 13=%d", pads[0]);
  786. } else if (pads.size() == 4) {
  787. fprintf(pp, " 3=%d", pads[1]);
  788. fprintf(pp, " 13=%d", pads[0]);
  789. fprintf(pp, " 14=%d", pads[3]);
  790. fprintf(pp, " 15=%d", pads[2]);
  791. }
  792. fprintf(pp, " 5=%d", pad_mode);
  793. }
  794. else if (op == "BatchNormalization")
  795. {
  796. float epsilon = get_node_attr_f(node, "epsilon", 1e-5f);
  797. const onnx::TensorProto& scale = weights[node.input(1)];
  798. const onnx::TensorProto& B = weights[node.input(2)];
  799. const onnx::TensorProto& mean = weights[node.input(3)];
  800. const onnx::TensorProto& var = weights[node.input(4)];
  801. int channels = get_tensor_proto_data_size(scale);
  802. fprintf(pp, " 0=%d", channels);
  803. fwrite_tensor_proto_data(scale, bp);
  804. fwrite_tensor_proto_data(mean, bp);
  805. // apply epsilon to var
  806. {
  807. const float* v = var.has_raw_data() ? (const float*)var.raw_data().data() : var.float_data().data();
  808. for (int j=0; j<channels; j++)
  809. {
  810. float ve = v[j] + epsilon;
  811. fwrite(&ve, sizeof(float), 1, bp);
  812. }
  813. }
  814. fwrite_tensor_proto_data(B, bp);
  815. }
  816. else if (op == "Ceil")
  817. {
  818. int op_type = 3;
  819. fprintf(pp, " 0=%d", op_type);
  820. }
  821. else if (op == "Clip")
  822. {
  823. float min = get_node_attr_f(node, "min", -FLT_MAX);
  824. float max = get_node_attr_f(node, "max", FLT_MAX);
  825. fprintf(pp, " 0=%f", min);
  826. fprintf(pp, " 1=%f", max);
  827. }
  828. else if (op == "Concat")
  829. {
  830. int axis = get_node_attr_i(node, "axis", 1);
  831. fprintf(pp, " 0=%d", axis-1);
  832. }
  833. else if (op == "Constant")
  834. {
  835. // check weight before BinaryOp
  836. if (binaryop_weights.find(name) != binaryop_weights.end())
  837. {
  838. const onnx::TensorProto& M = binaryop_weights[name];
  839. if (M.dims_size() == 1) {
  840. fprintf(pp, " 0=%d", (int)M.dims(0));
  841. } else if (M.dims_size() == 2) {
  842. fprintf(pp, " 0=%d", (int)M.dims(1));
  843. } else if (M.dims_size() == 3) {
  844. fprintf(pp, " 0=%d", (int)M.dims(2));
  845. fprintf(pp, " 1=%d", (int)M.dims(1));
  846. } else if (M.dims_size() == 4) {
  847. fprintf(pp, " 0=%d", (int)M.dims(3));
  848. fprintf(pp, " 1=%d", (int)M.dims(2));
  849. fprintf(pp, " 2=%d", (int)M.dims(1));
  850. }
  851. fwrite_tensor_proto_data(M, bp);
  852. }
  853. }
  854. else if (op == "Conv")
  855. {
  856. const onnx::TensorProto& W = weights[node.input(1)];
  857. int num_filter = W.dims(0);
  858. int has_bias = node.input_size() == 3 ? 1 : 0;
  859. std::string auto_pad = get_node_attr_s(node, "auto_pad");//TODO
  860. std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
  861. std::vector<int> dilations = get_node_attr_ai(node, "dilations");
  862. std::vector<int> strides = get_node_attr_ai(node, "strides");
  863. std::vector<int> pads = get_node_attr_ai(node, "pads");
  864. int group = get_node_attr_i(node, "group", 1);
  865. fprintf(pp, " 0=%d", num_filter);
  866. if (kernel_shape.size() == 1) {
  867. fprintf(pp, " 1=%d", kernel_shape[0]);
  868. } else if (kernel_shape.size() == 2) {
  869. fprintf(pp, " 1=%d", kernel_shape[1]);
  870. fprintf(pp, " 11=%d", kernel_shape[0]);
  871. }
  872. if (dilations.size() == 1) {
  873. fprintf(pp, " 2=%d", dilations[0]);
  874. } else if (dilations.size() == 2) {
  875. fprintf(pp, " 2=%d", dilations[1]);
  876. fprintf(pp, " 12=%d", dilations[0]);
  877. }
  878. if (strides.size() == 1) {
  879. fprintf(pp, " 3=%d", strides[0]);
  880. } else if (strides.size() == 2) {
  881. fprintf(pp, " 3=%d", strides[1]);
  882. fprintf(pp, " 13=%d", strides[0]);
  883. }
  884. if (auto_pad == "SAME_LOWER" || auto_pad == "SAME_UPPER")
  885. {
  886. // TODO
  887. fprintf(pp, " 4=-233");
  888. }
  889. else
  890. {
  891. if (pads.size() == 1) {
  892. fprintf(pp, " 4=%d", pads[0]);
  893. } else if (pads.size() == 2) {
  894. fprintf(pp, " 4=%d", pads[1]);
  895. fprintf(pp, " 14=%d", pads[0]);
  896. } else if (pads.size() == 4) {
  897. fprintf(pp, " 4=%d", pads[1]);
  898. fprintf(pp, " 14=%d", pads[0]);
  899. // TODO hpad2=pads[2] wpad2=pads[3]
  900. }
  901. }
  902. fprintf(pp, " 5=%d", has_bias);
  903. fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));
  904. if (group > 1) {
  905. fprintf(pp, " 7=%d", group);
  906. }
  907. int quantize_tag = 0;
  908. fwrite(&quantize_tag, sizeof(int), 1, bp);
  909. fwrite_tensor_proto_data(W, bp);
  910. if (has_bias)
  911. {
  912. const onnx::TensorProto& B = weights[node.input(2)];
  913. fwrite_tensor_proto_data(B, bp);
  914. }
  915. }
  916. else if (op == "ConvTranspose")
  917. {
  918. const onnx::TensorProto& W = weights[node.input(1)];
  919. int num_filter = W.dims(1);
  920. int has_bias = node.input_size() == 3 ? 1 : 0;
  921. std::string auto_pad = get_node_attr_s(node, "auto_pad");//TODO
  922. std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
  923. std::vector<int> dilations = get_node_attr_ai(node, "dilations");
  924. std::vector<int> strides = get_node_attr_ai(node, "strides");
  925. std::vector<int> output_padding = get_node_attr_ai(node, "output_padding");//TODO implement adj
  926. std::vector<int> output_shape = get_node_attr_ai(node, "output_shape");//TODO
  927. std::vector<int> pads = get_node_attr_ai(node, "pads");
  928. int group = get_node_attr_i(node, "group", 1);
  929. fprintf(pp, " 0=%d", num_filter);
  930. if (kernel_shape.size() == 1) {
  931. fprintf(pp, " 1=%d", kernel_shape[0]);
  932. } else if (kernel_shape.size() == 2) {
  933. fprintf(pp, " 1=%d", kernel_shape[1]);
  934. fprintf(pp, " 11=%d", kernel_shape[0]);
  935. }
  936. if (dilations.size() == 1) {
  937. fprintf(pp, " 2=%d", dilations[0]);
  938. } else if (dilations.size() == 2) {
  939. fprintf(pp, " 2=%d", dilations[1]);
  940. fprintf(pp, " 12=%d", dilations[0]);
  941. }
  942. if (strides.size() == 1) {
  943. fprintf(pp, " 3=%d", strides[0]);
  944. } else if (strides.size() == 2) {
  945. fprintf(pp, " 3=%d", strides[1]);
  946. fprintf(pp, " 13=%d", strides[0]);
  947. }
  948. if (auto_pad == "SAME_LOWER" || auto_pad == "SAME_UPPER")
  949. {
  950. // TODO
  951. fprintf(pp, " 4=-233");
  952. }
  953. else
  954. {
  955. if (pads.size() == 1) {
  956. fprintf(pp, " 4=%d", pads[0]);
  957. } else if (pads.size() == 2) {
  958. fprintf(pp, " 4=%d", pads[1]);
  959. fprintf(pp, " 14=%d", pads[0]);
  960. } else if (pads.size() == 4) {
  961. fprintf(pp, " 4=%d", pads[1]);
  962. fprintf(pp, " 14=%d", pads[0]);
  963. // TODO hpad2=pads[2] wpad2=pads[3]
  964. }
  965. }
  966. fprintf(pp, " 5=%d", has_bias);
  967. fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));
  968. if (group > 1) {
  969. fprintf(pp, " 7=%d", group);
  970. }
  971. int quantize_tag = 0;
  972. fwrite(&quantize_tag, sizeof(int), 1, bp);
  973. int maxk = 0;
  974. if (kernel_shape.size() == 2)
  975. {
  976. maxk = kernel_shape[1] * kernel_shape[0];
  977. }
  978. else
  979. {
  980. maxk = kernel_shape[0] * kernel_shape[0];
  981. }
  982. int weight_data_size = get_tensor_proto_data_size(W);
  983. const float* weight_data = 0;
  984. if (W.has_raw_data())
  985. {
  986. weight_data = (const float*)W.raw_data().data();
  987. }
  988. else if (W.data_type() == 1)
  989. {
  990. weight_data = W.float_data().data();
  991. }
  992. for (int g=0; g<group; g++)
  993. {
  994. // reorder weight from inch-outch to outch-inch
  995. int num_filter_g = num_filter / group;
  996. int num_input = weight_data_size / maxk / num_filter_g / group;
  997. const float* weight_data_ptr = weight_data + g * maxk * num_filter_g * num_input;
  998. for (int k=0; k<num_filter_g; k++)
  999. {
  1000. for (int j=0; j<num_input; j++)
  1001. {
  1002. fwrite(weight_data_ptr + (j*num_filter_g + k) * maxk, sizeof(float), maxk, bp);
  1003. }
  1004. }
  1005. }
  1006. if (has_bias)
  1007. {
  1008. const onnx::TensorProto& B = weights[node.input(2)];
  1009. fwrite_tensor_proto_data(B, bp);
  1010. }
  1011. }
  1012. else if (op == "Cos")
  1013. {
  1014. int op_type = 10;
  1015. fprintf(pp, " 0=%d", op_type);
  1016. }
  1017. else if (op == "Div")
  1018. {
  1019. int op_type = 3;
  1020. fprintf(pp, " 0=%d", op_type);
  1021. }
  1022. else if (op == "Dropout")
  1023. {
  1024. // no-op
  1025. }
  1026. else if (op == "Elu")
  1027. {
  1028. float alpha = get_node_attr_f(node, "alpha", 1.f);
  1029. fprintf(pp, " 0=%f", alpha);
  1030. }
  1031. else if (op == "Exp")
  1032. {
  1033. int op_type = 7;
  1034. fprintf(pp, " 0=%d", op_type);
  1035. }
  1036. else if (op == "Flatten")
  1037. {
  1038. int axis = get_node_attr_i(node, "axis", 1);
  1039. if (axis != 1)
  1040. {
  1041. fprintf(stderr, "Unsupported Flatten axis %d!\n", axis);
  1042. }
  1043. }
  1044. else if (op == "Floor")
  1045. {
  1046. int op_type = 2;
  1047. fprintf(pp, " 0=%d", op_type);
  1048. }
  1049. else if (op == "Gemm")
  1050. {
  1051. float alpha = get_node_attr_f(node, "alpha", 1.f);
  1052. float beta = get_node_attr_f(node, "beta", 1.f);
  1053. int transA = get_node_attr_i(node, "transA", 0);
  1054. int transB = get_node_attr_i(node, "transB", 0);
  1055. if (alpha == 1.f && beta == 1.f)
  1056. {
  1057. // InnerProduct-like A * B + C
  1058. if (transA == 0 && transB == 1)
  1059. {
  1060. const onnx::TensorProto& B = weights[node.input(1)];
  1061. const onnx::TensorProto& C = weights[node.input(2)];
  1062. fprintf(pp, " 0=%d", get_tensor_proto_data_size(C));
  1063. fprintf(pp, " 1=1");
  1064. fprintf(pp, " 2=%d", get_tensor_proto_data_size(B));
  1065. int quantize_tag = 0;
  1066. fwrite(&quantize_tag, sizeof(int), 1, bp);
  1067. fwrite_tensor_proto_data(B, bp);
  1068. fwrite_tensor_proto_data(C, bp);
  1069. }
  1070. }
  1071. }
  1072. else if (op == "GlobalAveragePool")
  1073. {
  1074. int pool = 1;
  1075. int global_pool = 1;
  1076. fprintf(pp, " 0=%d", pool);
  1077. fprintf(pp, " 4=%d", global_pool);
  1078. }
  1079. else if (op == "GlobalMaxPool")
  1080. {
  1081. int pool = 0;
  1082. int global_pool = 1;
  1083. fprintf(pp, " 0=%d", pool);
  1084. fprintf(pp, " 4=%d", global_pool);
  1085. }
  1086. else if (op == "ImageScaler")
  1087. {
  1088. std::vector<float> bias = get_node_attr_af(node, "bias");
  1089. float scale = get_node_attr_f(node, "scale", 1.f);
  1090. int channels = bias.size();
  1091. fprintf(pp, " 0=%d", channels);
  1092. fprintf(pp, " 1=1");
  1093. for (int j=0; j<channels; j++)
  1094. {
  1095. fwrite(&scale, sizeof(float), 1, bp);
  1096. }
  1097. fwrite(&bias[0], sizeof(float), channels, bp);
  1098. }
  1099. else if (op == "InstanceNormalization")
  1100. {
  1101. float eps = get_node_attr_f(node, "epsilon", 1e-5f);
  1102. const onnx::TensorProto& scale = weights[node.input(1)];
  1103. const onnx::TensorProto& B = weights[node.input(2)];
  1104. int channels = get_tensor_proto_data_size(scale);
  1105. fprintf(pp, " 0=%d", channels);
  1106. fprintf(pp, " 1=%f", eps);
  1107. fwrite_tensor_proto_data(scale, bp);
  1108. fwrite_tensor_proto_data(B, bp);
  1109. }
  1110. else if (op == "LeakyRelu")
  1111. {
  1112. float alpha = get_node_attr_f(node, "alpha", 0.01f);
  1113. fprintf(pp, " 0=%f", alpha);
  1114. }
  1115. else if (op == "Log")
  1116. {
  1117. int op_type = 8;
  1118. fprintf(pp, " 0=%d", op_type);
  1119. }
  1120. else if (op == "LRN")
  1121. {
  1122. float alpha = get_node_attr_f(node, "alpha", 1.f);
  1123. float beta = get_node_attr_f(node, "beta", 0.5f);
  1124. float bias = get_node_attr_f(node, "bias", 1.f);
  1125. int size = get_node_attr_i(node, "size", 1);
  1126. int norm_region = 0;
  1127. fprintf(pp, " 0=%d", norm_region);
  1128. fprintf(pp, " 1=%d", size);
  1129. fprintf(pp, " 2=%f", alpha);
  1130. fprintf(pp, " 3=%f", beta);
  1131. fprintf(pp, " 4=%f", bias);
  1132. }
  1133. else if (op == "MatMul")
  1134. {
  1135. const onnx::TensorProto& B = weights[node.input(1)];
  1136. int weight_data_size = get_tensor_proto_data_size(B);
  1137. int num_output = B.dims(B.dims_size()-1);
  1138. int num_input = weight_data_size / num_output;
  1139. fprintf(pp, " 0=%d", num_output);
  1140. fprintf(pp, " 1=0");
  1141. fprintf(pp, " 2=%d", weight_data_size);
  1142. int quantize_tag = 0;
  1143. fwrite(&quantize_tag, sizeof(int), 1, bp);
  1144. // reorder num_input-num_output to num_output-num_input
  1145. {
  1146. const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
  1147. for (int j=0; j<num_output; j++)
  1148. {
  1149. for (int k=0; k<num_input; k++)
  1150. {
  1151. float vb = bptr[ k*num_output + j ];
  1152. fwrite(&vb, sizeof(float), 1, bp);
  1153. }
  1154. }
  1155. }
  1156. // fwrite_tensor_proto_data(B, bp)
  1157. }
  1158. else if (op == "Max")
  1159. {
  1160. int op_type = 4;
  1161. fprintf(pp, " 0=%d", op_type);
  1162. }
  1163. else if (op == "Min")
  1164. {
  1165. int op_type = 5;
  1166. fprintf(pp, " 0=%d", op_type);
  1167. }
  1168. else if (op == "Mul")
  1169. {
  1170. int op_type = 2;
  1171. fprintf(pp, " 0=%d", op_type);
  1172. }
  1173. else if (op == "Neg")
  1174. {
  1175. int op_type = 1;
  1176. fprintf(pp, " 0=%d", op_type);
  1177. }
  1178. else if (op == "Pad")
  1179. {
  1180. std::string mode = get_node_attr_s(node, "mode");
  1181. std::vector<int> pads = get_node_attr_ai(node, "pads");
  1182. float value = get_node_attr_f(node, "value", 0.f);
  1183. int type = 0;
  1184. if (mode == "constant")
  1185. {
  1186. type = 0;
  1187. }
  1188. else if (mode == "edge")
  1189. {
  1190. type = 1;
  1191. }
  1192. else if (mode == "reflect")
  1193. {
  1194. // FIXME
  1195. }
  1196. int top = pads[0];
  1197. int bottom = pads[2];
  1198. int left = pads[1];
  1199. int right = pads[3];
  1200. fprintf(pp, " 0=%d", top);
  1201. fprintf(pp, " 1=%d", bottom);
  1202. fprintf(pp, " 2=%d", left);
  1203. fprintf(pp, " 3=%d", right);
  1204. fprintf(pp, " 4=%d", type);
  1205. fprintf(pp, " 5=%f", value);
  1206. }
  1207. else if (op == "Pow")
  1208. {
  1209. int op_type = 6;
  1210. fprintf(pp, " 0=%d", op_type);
  1211. }
  1212. else if (op == "PRelu")
  1213. {
  1214. const onnx::TensorProto& slope = weights[node.input(1)];
  1215. int num_slope = get_tensor_proto_data_size(slope);
  1216. fprintf(pp, " 0=%d", num_slope);
  1217. fwrite_tensor_proto_data(slope, bp);
  1218. }
  1219. else if (op == "Reciprocal")
  1220. {
  1221. int op_type = 15;
  1222. fprintf(pp, " 0=%d", op_type);
  1223. }
  1224. else if (op == "Reshape")
  1225. {
  1226. std::vector<int> shape;
  1227. if (node.input_size() == 1)
  1228. {
  1229. shape = get_node_attr_ai(node, "shape");
  1230. }
  1231. else
  1232. {
  1233. const onnx::TensorProto& shape_tp = weights[node.input(1)];
  1234. const int64_t* shape_data = shape_tp.int64_data().data();
  1235. for (int j=0; j<shape_tp.int64_data_size(); j++)
  1236. {
  1237. shape.push_back(shape_data[j]);
  1238. }
  1239. }
  1240. if (shape.size() == 1) {
  1241. fprintf(pp, " 0=%d", shape[0]);// should never reach here
  1242. } else if (shape.size() == 2) {
  1243. fprintf(pp, " 0=%d", shape[1]);
  1244. } else if (shape.size() == 3) {
  1245. fprintf(pp, " 0=%d", shape[2]);
  1246. fprintf(pp, " 1=%d", shape[1]);
  1247. } else if (shape.size() == 4) {
  1248. fprintf(pp, " 0=%d", shape[3]);
  1249. fprintf(pp, " 1=%d", shape[2]);
  1250. fprintf(pp, " 2=%d", shape[1]);
  1251. } else if (shape.size() == 5) {
  1252. fprintf(pp, " 0=%d", shape[4] * shape[3]);
  1253. fprintf(pp, " 1=%d", shape[2]);
  1254. fprintf(pp, " 2=%d", shape[1]);
  1255. }
  1256. }
  1257. else if (op == "Sigmoid")
  1258. {
  1259. }
  1260. else if (op == "Sin")
  1261. {
  1262. int op_type = 9;
  1263. fprintf(pp, " 0=%d", op_type);
  1264. }
  1265. else if (op == "Slice")
  1266. {
  1267. std::vector<int> starts = get_node_attr_ai(node, "starts");
  1268. std::vector<int> ends = get_node_attr_ai(node, "ends");
  1269. std::vector<int> steps = get_node_attr_ai(node, "steps");// TODO
  1270. // assert step == 1
  1271. for (int i=0; i<(int)steps.size(); i++)
  1272. {
  1273. if (steps[i] != 1)
  1274. fprintf(stderr, "Unsupported slice step !\n");
  1275. }
  1276. int woffset = 0;
  1277. int hoffset = 0;
  1278. int coffset = 0;
  1279. int outw = -233;
  1280. int outh = -233;
  1281. int outc = -233;
  1282. if (starts.size() == 2)
  1283. {
  1284. woffset = starts[1];
  1285. outw = ends[1] == -1 ? -234 : ends[1] - starts[1];
  1286. }
  1287. else if (starts.size() == 3)
  1288. {
  1289. woffset = starts[2];
  1290. hoffset = starts[1];
  1291. outw = ends[2] == -1 ? -234 : ends[2] - starts[2];
  1292. outh = ends[1] == -1 ? -234 : ends[1] - starts[1];
  1293. }
  1294. else if (starts.size() == 4)
  1295. {
  1296. woffset = starts[3];
  1297. hoffset = starts[2];
  1298. coffset = starts[1];
  1299. outw = ends[3] == -1 ? -234 : ends[3] - starts[3];
  1300. outh = ends[2] == -1 ? -234 : ends[2] - starts[2];
  1301. outc = ends[1] == -1 ? -234 : ends[1] - starts[1];
  1302. }
  1303. fprintf(pp, " 0=%d", woffset);
  1304. fprintf(pp, " 1=%d", hoffset);
  1305. fprintf(pp, " 2=%d", coffset);
  1306. fprintf(pp, " 3=%d", outw);
  1307. fprintf(pp, " 4=%d", outh);
  1308. fprintf(pp, " 5=%d", outc);
  1309. }
  1310. else if (op == "Softmax")
  1311. {
  1312. int axis = get_node_attr_i(node, "axis", 1);
  1313. fprintf(pp, " 0=%d", axis-1);
  1314. fprintf(pp, " 1=1");
  1315. }
  1316. else if (op == "Sqrt")
  1317. {
  1318. int op_type = 5;
  1319. fprintf(pp, " 0=%d", op_type);
  1320. }
  1321. else if (op == "Sub")
  1322. {
  1323. int op_type = 1;
  1324. fprintf(pp, " 0=%d", op_type);
  1325. }
  1326. else if (op == "Sum")
  1327. {
  1328. int op_type = 1;
  1329. fprintf(pp, " 0=%d", op_type);
  1330. }
  1331. else if (op == "Tan")
  1332. {
  1333. int op_type = 11;
  1334. fprintf(pp, " 0=%d", op_type);
  1335. }
  1336. else if (op == "Transpose")
  1337. {
  1338. std::vector<int> perm = get_node_attr_ai(node, "perm");
  1339. if (perm.size() == 4) {
  1340. if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3)
  1341. fprintf(pp, " 0=0");// w h c
  1342. else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2)
  1343. fprintf(pp, " 0=1");// h w c
  1344. else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3)
  1345. fprintf(pp, " 0=2");// w c h
  1346. else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1)
  1347. fprintf(pp, " 0=3");// c w h
  1348. else if (perm[1] == 3 && perm[2] == 1 && perm[3] == 2)
  1349. fprintf(pp, " 0=4");// h c w
  1350. else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1)
  1351. fprintf(pp, " 0=5");// c h w
  1352. } else if (perm.size() == 5) {
  1353. if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4)
  1354. fprintf(pp, " 0=0");// wx h c
  1355. else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 4 && perm[4] == 2)
  1356. fprintf(pp, " 0=1");// h wx c
  1357. else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3 && perm[4] == 4)
  1358. fprintf(pp, " 0=2");// wx c h
  1359. else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 4 && perm[4] == 1)
  1360. fprintf(pp, " 0=3");// c wx h
  1361. else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 1 && perm[4] == 2)
  1362. fprintf(pp, " 0=4");// h c wx
  1363. else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 2 && perm[4] == 1)
  1364. fprintf(pp, " 0=5");// c h wx
  1365. else
  1366. fprintf(stderr, "Unsupported transpose type !\n");
  1367. }
  1368. }
  1369. else if (op == "Upsample")
  1370. {
  1371. std::string mode = get_node_attr_s(node, "mode");
  1372. std::vector<float> scales;
  1373. if (node.input_size() == 1)
  1374. {
  1375. scales = get_node_attr_af(node, "scales");
  1376. }
  1377. else
  1378. {
  1379. const onnx::TensorProto& scales_tp = weights[node.input(1)];
  1380. const float* shape_data = scales_tp.has_raw_data() ? (const float*)scales_tp.raw_data().data() : scales_tp.float_data().data();
  1381. int float_data_size = scales_tp.float_data_size();
  1382. //float data is None, use raw data instead
  1383. if (float_data_size == 0) {
  1384. float_data_size = scales_tp.dims().Get(0);
  1385. }
  1386. for (int j=0; j<float_data_size; j++)
  1387. {
  1388. scales.push_back(shape_data[j]);
  1389. }
  1390. }
  1391. int resize_type = 1;
  1392. if (mode == "nearest")
  1393. {
  1394. resize_type = 1;
  1395. }
  1396. else if (mode == "bilinear" || mode == "linear")
  1397. {
  1398. resize_type = 2;
  1399. }
  1400. else if (mode == "trilinear")
  1401. {
  1402. fprintf(stderr, "Unsupported Upsample mode !\n");
  1403. }
  1404. float h_scale = 1.f;
  1405. float w_scale = 1.f;
  1406. if (scales.size() == 2)
  1407. {
  1408. w_scale = scales[1];
  1409. }
  1410. else if (scales.size() == 3)
  1411. {
  1412. h_scale = scales[1];
  1413. w_scale = scales[2];
  1414. }
  1415. else if (scales.size() == 4)
  1416. {
  1417. h_scale = scales[2];
  1418. w_scale = scales[3];
  1419. if (scales[1] != 1.f)
  1420. fprintf(stderr, "Unsupported Upsample scales !\n");
  1421. }
  1422. else
  1423. {
  1424. fprintf(stderr, "Unsupported Upsample scales !\n");
  1425. }
  1426. fprintf(pp, " 0=%d", resize_type);
  1427. fprintf(pp, " 1=%f", h_scale);
  1428. fprintf(pp, " 2=%f", w_scale);
  1429. }
  1430. else
  1431. {
  1432. // TODO op specific param
  1433. for (int j=0; j<node.attribute_size(); j++)
  1434. {
  1435. const onnx::AttributeProto& attr = node.attribute(j);
  1436. if (attr.type() == 1)
  1437. {
  1438. fprintf(stderr, " # %s=%f\n", attr.name().c_str(), attr.f());
  1439. }
  1440. else if (attr.type() == 2)
  1441. {
  1442. fprintf(stderr, " # %s=%d\n", attr.name().c_str(), attr.i());
  1443. }
  1444. else if (attr.type() == 3)
  1445. {
  1446. fprintf(stderr, " # %s=%s\n", attr.name().c_str(), attr.s().c_str());
  1447. }
  1448. else
  1449. {
  1450. fprintf(stderr, " # %s %d\n", attr.name().c_str(), attr.type());
  1451. }
  1452. }
  1453. }
  1454. fprintf(pp, "\n");
  1455. for (int j=0; j<output_size; j++)
  1456. {
  1457. const std::string& output_name = node.output(j);
  1458. if (node_reference.find(output_name) != node_reference.end())
  1459. {
  1460. int refcount = node_reference[output_name];
  1461. if (refcount > 1)
  1462. {
  1463. char splitname[256];
  1464. sprintf(splitname, "splitncnn_%d", internal_split);
  1465. fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
  1466. fprintf(pp, " %s", output_name.c_str());
  1467. for (int k=0; k<refcount; k++)
  1468. {
  1469. fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
  1470. }
  1471. fprintf(pp, "\n");
  1472. internal_split++;
  1473. }
  1474. }
  1475. }
  1476. }
  1477. fclose(pp);
  1478. fclose(bp);
  1479. return 0;
  1480. }