You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow2ncnn.cpp 42 kB

8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
8 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include <stdio.h>
  15. #include <limits.h>
  16. #include <iostream>
  17. #include <fstream>
  18. #include <set>
  19. #include <limits>
  20. #include <algorithm>
  21. #include <google/protobuf/io/coded_stream.h>
  22. #include <google/protobuf/io/zero_copy_stream_impl.h>
  23. #include <google/protobuf/text_format.h>
  24. #include <google/protobuf/message.h>
  25. #include "graph.pb.h"
  26. static bool read_proto_from_binary(const char* filepath, google::protobuf::Message* message)
  27. {
  28. std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
  29. if (!fs.is_open())
  30. {
  31. fprintf(stderr, "open failed %s\n", filepath);
  32. return false;
  33. }
  34. google::protobuf::io::IstreamInputStream input(&fs);
  35. google::protobuf::io::CodedInputStream codedstr(&input);
  36. codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
  37. bool success = message->ParseFromCodedStream(&codedstr);
  38. fs.close();
  39. return success;
  40. }
  41. static bool find_tensor_proto(const std::map<std::string, tensorflow::TensorProto>& weights,
  42. const tensorflow::NodeDef& node, tensorflow::TensorProto& tensor)
  43. {
  44. for (int j=0; j<node.input_size(); j++)
  45. {
  46. const std::string& input_name = node.input(j);
  47. const std::map<std::string, tensorflow::TensorProto>::const_iterator it = weights.find(input_name);
  48. if (it != weights.end())
  49. {
  50. tensor = it->second;
  51. return true;
  52. }
  53. }
  54. return false;
  55. }
  56. static bool get_tensor_proto(const std::map<std::string, tensorflow::TensorProto>& consts,
  57. const tensorflow::NodeDef& node, tensorflow::TensorProto& tensor)
  58. {
  59. const std::string& output_name = node.name();
  60. const std::map<std::string, tensorflow::TensorProto>::const_iterator it = consts.find(output_name);
  61. if (it != consts.end())
  62. {
  63. tensor = it->second;
  64. return true;
  65. }
  66. return false;
  67. }
  68. static bool find_attr_value(const tensorflow::NodeDef& node, const char* key, tensorflow::AttrValue& value)
  69. {
  70. const google::protobuf::Map<std::string, tensorflow::AttrValue>& attr = node.attr();
  71. const google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.find(key);
  72. if (it != attr.end())
  73. {
  74. value = it->second;
  75. return true;
  76. }
  77. return false;
  78. }
  79. static int parse_tensor_reduction_dim(const tensorflow::TensorProto& tensor)
  80. {
  81. int dim = 0;
  82. // dim == 0 // w h c -> X X X
  83. // dim == 1 // w h c -> X X c
  84. // dim == 2 // w h c -> X h c
  85. // dim == -1 // w h c -> w X X
  86. // dim == -2 // w h c -> w h X
  87. if (!tensor.tensor_content().empty() && tensor.dtype() == 3)// int32
  88. {
  89. const int* data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
  90. int size = tensor.tensor_content().size() / sizeof(int);
  91. // n h w c
  92. // n h w
  93. // n w
  94. // TODO investigate two stage / three stage reduction
  95. if (size == 2)
  96. {
  97. if (data[0] == 1 && data[1] == 2)
  98. {
  99. dim = 1;
  100. }
  101. }
  102. }
  103. else
  104. {
  105. int axis = tensor.int_val(0);
  106. if (axis == 1)
  107. dim = 0;
  108. else if (axis == 3)
  109. dim = -2;
  110. }
  111. return dim;
  112. }
  113. int main(int argc, char** argv)
  114. {
  115. const char* tensorflowpb = argv[1];
  116. const char* ncnn_prototxt = argc >= 4 ? argv[2] : "ncnn.proto";
  117. const char* ncnn_modelbin = argc >= 4 ? argv[3] : "ncnn.bin";
  118. tensorflow::GraphDef graph;
  119. // load
  120. bool s1 = read_proto_from_binary(tensorflowpb, &graph);
  121. if (!s1)
  122. {
  123. fprintf(stderr, "read_proto_from_binary failed\n");
  124. return -1;
  125. }
  126. FILE* pp = fopen(ncnn_prototxt, "wb");
  127. FILE* bp = fopen(ncnn_modelbin, "wb");
  128. int node_count = graph.node_size();
  129. // fprintf(stderr, "node_count = %d\n\n", node_count);
  130. // node reference
  131. std::map<std::string, int> node_reference;
  132. // mapping for Const and Const-Identity
  133. std::map<std::string, tensorflow::TensorProto> weights;
  134. // Dropout like Identity
  135. std::set<std::string> dropouts;
  136. // Const before BinaryOp
  137. std::map<std::string, tensorflow::TensorProto> binaryop_consts;
  138. // global definition line
  139. // [layer count] [blob count]
  140. std::set<std::string> blob_names;
  141. for (int i=0; i<node_count; i++)
  142. {
  143. const tensorflow::NodeDef& node = graph.node(i);
  144. const std::string& output_name = node.name();
  145. if (node.op() == "Const")
  146. {
  147. tensorflow::AttrValue value;
  148. if (find_attr_value(node, "value", value))
  149. {
  150. const tensorflow::TensorProto& tensor = value.tensor();
  151. weights[output_name] = tensor;
  152. }
  153. continue;
  154. }
  155. else if (node.op() == "Identity")
  156. {
  157. const std::string& input_name = node.input(0);
  158. if (weights.find(input_name) != weights.end())
  159. {
  160. weights[output_name] = weights[input_name];
  161. continue;
  162. }
  163. else
  164. {
  165. dropouts.insert(output_name);
  166. }
  167. }
  168. else if (node.op() == "NoOp")
  169. {
  170. weights[output_name] = tensorflow::TensorProto();
  171. continue;
  172. }
  173. else
  174. {
  175. bool isBinaryOp = false;
  176. if (node.op() == "Add" || node.op() == "BiasAdd" || node.op() == "Div"
  177. || node.op() == "Mul" || node.op() == "RealDiv" || node.op() == "Sub")
  178. {
  179. isBinaryOp = true;
  180. }
  181. if (node.op() == "Max" || node.op() == "Maximum" || node.op() == "Min" || node.op() == "Minimum")
  182. {
  183. // check weights
  184. tensorflow::TensorProto tensor;
  185. if (!find_tensor_proto(weights, node, tensor))
  186. {
  187. isBinaryOp = true;
  188. }
  189. }
  190. if (isBinaryOp)
  191. {
  192. // check weights
  193. for (int j=0; j<node.input_size(); j++)
  194. {
  195. const std::string& input_name = node.input(j);
  196. std::map<std::string, tensorflow::TensorProto>::iterator it = weights.find(input_name);
  197. if (it != weights.end())
  198. {
  199. // binary op with const, insert MemoryData layer and const blob
  200. binaryop_consts[input_name] = it->second;
  201. weights.erase(it);
  202. }
  203. }
  204. }
  205. }
  206. // input
  207. for (int j=0; j<node.input_size(); j++)
  208. {
  209. const std::string& input_name = node.input(j);
  210. // fprintf(stderr, "input = %s\n", input_name.c_str());
  211. if (weights.find(input_name) != weights.end())
  212. {
  213. continue;
  214. }
  215. blob_names.insert(input_name);
  216. if (node_reference.find(input_name) == node_reference.end())
  217. {
  218. node_reference[input_name] = 1;
  219. }
  220. else
  221. {
  222. node_reference[input_name] = node_reference[input_name] + 1;
  223. }
  224. }
  225. // output
  226. // fprintf(stderr, "output = %s\n", output_name.c_str());
  227. blob_names.insert(output_name);
  228. }
  229. // remove node_reference entry with reference equals to one
  230. int splitncnn_blob_count = 0;
  231. std::map<std::string, int>::iterator it = node_reference.begin();
  232. while (it != node_reference.end())
  233. {
  234. if (it->second == 1)
  235. {
  236. node_reference.erase(it++);
  237. }
  238. else
  239. {
  240. splitncnn_blob_count += it->second;
  241. // fprintf(stderr, "%s %d\n", it->first.c_str(), it->second);
  242. ++it;
  243. }
  244. }
  245. fprintf(pp, "%lu %lu\n", node_count + node_reference.size() - weights.size(), blob_names.size() + splitncnn_blob_count);
  246. int internal_split = 0;
  247. for (int i=0; i<node_count; i++)
  248. {
  249. const tensorflow::NodeDef& node = graph.node(i);
  250. // layer definition line, repeated
  251. // [type] [name] [bottom blob count] [top blob count] [bottom blobs] [top blobs] [layer specific params]
  252. // fprintf(pp, "%-16s %-16s %d %d", layer.type().c_str(), layer.name().c_str(), node.input_size(), layer.top_size());
  253. if (node.op() == "Add" || node.op() == "BiasAdd")
  254. {
  255. fprintf(pp, "%-16s", "BinaryOp");
  256. }
  257. else if (node.op() == "AvgPool")
  258. {
  259. fprintf(pp, "%-16s", "Pooling");
  260. }
  261. else if (node.op() == "Concat" || node.op() == "ConcatV2")
  262. {
  263. fprintf(pp, "%-16s", "Concat");
  264. }
  265. else if (node.op() == "Const")
  266. {
  267. // check before binaryop
  268. tensorflow::TensorProto tensor;
  269. if (get_tensor_proto(binaryop_consts, node, tensor))
  270. {
  271. fprintf(pp, "%-16s", "MemoryData");
  272. }
  273. else
  274. {
  275. continue;
  276. }
  277. }
  278. else if (node.op() == "Conv2D")
  279. {
  280. fprintf(pp, "%-16s", "Convolution");
  281. }
  282. else if (node.op() == "DepthwiseConv2dNative")
  283. {
  284. fprintf(pp, "%-16s", "ConvolutionDepthWise");
  285. }
  286. else if (node.op() == "Div" || node.op() == "RealDiv")
  287. {
  288. fprintf(pp, "%-16s", "BinaryOp");
  289. }
  290. else if (node.op() == "Exp")
  291. {
  292. fprintf(pp, "%-16s", "UnaryOp");
  293. }
  294. else if (node.op() == "ExpandDims")
  295. {
  296. fprintf(pp, "%-16s", "ExpandDims");
  297. }
  298. else if (node.op() == "Floor")
  299. {
  300. fprintf(pp, "%-16s", "UnaryOp");
  301. }
  302. else if (node.op() == "Identity")
  303. {
  304. // check before binaryop
  305. tensorflow::TensorProto tensor;
  306. if (get_tensor_proto(binaryop_consts, node, tensor))
  307. {
  308. fprintf(pp, "%-16s", "MemoryData");
  309. }
  310. else if (dropouts.find(node.name()) != dropouts.end())
  311. {
  312. fprintf(pp, "%-16s", "Dropout");
  313. }
  314. else
  315. {
  316. continue;
  317. }
  318. }
  319. else if (node.op() == "LRN")
  320. {
  321. fprintf(pp, "%-16s", "LRN");
  322. }
  323. else if (node.op() == "MatMul")
  324. {
  325. fprintf(pp, "%-16s", "InnerProduct");
  326. }
  327. else if (node.op() == "Max" || node.op() == "Maximum")
  328. {
  329. // check weights
  330. tensorflow::TensorProto tensor;
  331. if (find_tensor_proto(weights, node, tensor))
  332. {
  333. fprintf(pp, "%-16s", "Reduction");
  334. }
  335. else
  336. {
  337. fprintf(pp, "%-16s", "BinaryOp");
  338. }
  339. }
  340. else if (node.op() == "MaxPool")
  341. {
  342. fprintf(pp, "%-16s", "Pooling");
  343. }
  344. else if (node.op() == "Min" || node.op() == "Minimum")
  345. {
  346. // check weights
  347. tensorflow::TensorProto tensor;
  348. if (find_tensor_proto(weights, node, tensor))
  349. {
  350. fprintf(pp, "%-16s", "Reduction");
  351. }
  352. else
  353. {
  354. fprintf(pp, "%-16s", "BinaryOp");
  355. }
  356. }
  357. else if (node.op() == "Mul")
  358. {
  359. fprintf(pp, "%-16s", "BinaryOp");
  360. }
  361. else if (node.op() == "Neg")
  362. {
  363. fprintf(pp, "%-16s", "UnaryOp");
  364. }
  365. else if (node.op() == "NoOp")
  366. {
  367. continue;
  368. }
  369. else if (node.op() == "Pad")
  370. {
  371. fprintf(pp, "%-16s", "Padding");
  372. }
  373. else if (node.op() == "Placeholder")
  374. {
  375. fprintf(pp, "%-16s", "Input");
  376. }
  377. else if (node.op() == "Prod")
  378. {
  379. fprintf(pp, "%-16s", "Reduction");
  380. }
  381. else if (node.op() == "Reciprocal")
  382. {
  383. fprintf(pp, "%-16s", "UnaryOp");
  384. }
  385. else if (node.op() == "Relu")
  386. {
  387. fprintf(pp, "%-16s", "ReLU");
  388. }
  389. else if (node.op() == "Reshape")
  390. {
  391. fprintf(pp, "%-16s", "Reshape");
  392. }
  393. else if (node.op() == "Rsqrt")
  394. {
  395. fprintf(pp, "%-16s", "UnaryOp");
  396. }
  397. else if (node.op() == "Sigmoid")
  398. {
  399. fprintf(pp, "%-16s", "Sigmoid");
  400. }
  401. else if (node.op() == "Softmax")
  402. {
  403. fprintf(pp, "%-16s", "Softmax");
  404. }
  405. else if (node.op() == "Square")
  406. {
  407. fprintf(pp, "%-16s", "UnaryOp");
  408. }
  409. else if (node.op() == "Squeeze")
  410. {
  411. fprintf(pp, "%-16s", "Squeeze");
  412. }
  413. else if (node.op() == "Sub")
  414. {
  415. fprintf(pp, "%-16s", "BinaryOp");
  416. }
  417. else if (node.op() == "Sum")
  418. {
  419. fprintf(pp, "%-16s", "Reduction");
  420. }
  421. else
  422. {
  423. fprintf(pp, "%-16s", node.op().c_str());
  424. fprintf(stderr, "%s not supported yet !\nn", node.op().c_str());
  425. }
  426. int input_size = node.input_size();
  427. for (int j=0; j<node.input_size(); j++)
  428. {
  429. const std::string& input_name = node.input(j);
  430. if (weights.find(input_name) != weights.end())
  431. {
  432. input_size--;
  433. }
  434. }
  435. fprintf(pp, " %-32s %d 1", node.name().c_str(), input_size);
  436. for (int j=0; j<node.input_size(); j++)
  437. {
  438. std::string input_name = node.input(j);
  439. if (weights.find(input_name) != weights.end())
  440. {
  441. continue;
  442. }
  443. if (node_reference.find(input_name) != node_reference.end())
  444. {
  445. int refidx = node_reference[input_name] - 1;
  446. node_reference[input_name] = refidx;
  447. char splitsuffix[256];
  448. sprintf(splitsuffix, "_splitncnn_%d", refidx);
  449. input_name = input_name + splitsuffix;
  450. }
  451. fprintf(pp, " %s", input_name.c_str());
  452. }
  453. fprintf(pp, " %s", node.name().c_str());
  454. if (node.op() == "Add" || node.op() == "BiasAdd")
  455. {
  456. int op_type = 0;
  457. fprintf(pp, " %d", op_type);
  458. }
  459. else if (node.op() == "AvgPool")
  460. {
  461. int pooling_type = 1;
  462. int kernel_size_h = 1;
  463. int kernel_size_w = 1;
  464. int stride_h = 1;
  465. int stride_w = 1;
  466. int pad = 0;
  467. int global_pooling = 0;
  468. tensorflow::AttrValue value_ksize;
  469. if (find_attr_value(node, "ksize", value_ksize))
  470. {
  471. // batch, height, width, channels
  472. kernel_size_h = value_ksize.list().i(1);
  473. kernel_size_w = value_ksize.list().i(2);
  474. }
  475. tensorflow::AttrValue value_strides;
  476. if (find_attr_value(node, "strides", value_strides))
  477. {
  478. // batch, height, width, channels
  479. stride_h = value_strides.list().i(1);
  480. stride_w = value_strides.list().i(2);
  481. }
  482. tensorflow::AttrValue value_padding;
  483. if (find_attr_value(node, "padding", value_padding))
  484. {
  485. if (value_padding.s() == "VALID")
  486. {
  487. pad = 0;
  488. }
  489. else if (value_padding.s() == "SAME")
  490. {
  491. pad = -233;
  492. }
  493. }
  494. fprintf(pp, " %d %d %d %d %d", pooling_type, kernel_size_w, stride_w, pad, global_pooling);
  495. }
  496. else if (node.op() == "Concat" || node.op() == "ConcatV2")
  497. {
  498. tensorflow::TensorProto tensor;
  499. if (find_tensor_proto(weights, node, tensor))
  500. {
  501. // TODO
  502. int axis = tensor.int_val(0);
  503. }
  504. }
  505. else if (node.op() == "Const" || node.op() == "Identity")
  506. {
  507. // check before binaryop
  508. tensorflow::TensorProto tensor;
  509. if (get_tensor_proto(binaryop_consts, node, tensor))
  510. {
  511. const tensorflow::TensorShapeProto& shape = tensor.tensor_shape();
  512. int w = 0;
  513. int h = 0;
  514. int c = 0;
  515. if (shape.dim_size() == 1)
  516. {
  517. w = shape.dim(0).size();
  518. }
  519. else if (shape.dim_size() == 2)
  520. {
  521. h = shape.dim(0).size();
  522. w = shape.dim(1).size();
  523. }
  524. else if (shape.dim_size() == 3)
  525. {
  526. c = shape.dim(2).size();
  527. h = shape.dim(0).size();
  528. w = shape.dim(1).size();
  529. }
  530. int weight_data_size = 0;
  531. if (!tensor.tensor_content().empty())
  532. {
  533. if (tensor.dtype() == 1)// float
  534. {
  535. const float* data = reinterpret_cast<const float*>(tensor.tensor_content().c_str());
  536. weight_data_size = tensor.tensor_content().size() / sizeof(float);
  537. if (c == 0)
  538. fwrite(data, sizeof(float), weight_data_size, bp);
  539. else
  540. {
  541. float tmp;
  542. // h-w-c to c-h-w
  543. for (int p=0; p<c; p++)
  544. {
  545. for (int i=0; i<h; i++)
  546. {
  547. for (int j=0; j<w; j++)
  548. {
  549. tmp = data[i*w*c + j*c + p];
  550. fwrite(&tmp, sizeof(float), 1, bp);
  551. }
  552. }
  553. }
  554. }
  555. }
  556. else if (tensor.dtype() == 3)// int32
  557. {
  558. const int* data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
  559. weight_data_size = tensor.tensor_content().size() / sizeof(int);
  560. float tmp;
  561. if (c == 0)
  562. {
  563. for (int i=0; i<weight_data_size; i++)
  564. {
  565. tmp = data[i];
  566. fwrite(&tmp, sizeof(float), 1, bp);
  567. }
  568. }
  569. else
  570. {
  571. // h-w-c to c-h-w
  572. for (int p=0; p<c; p++)
  573. {
  574. for (int i=0; i<h; i++)
  575. {
  576. for (int j=0; j<w; j++)
  577. {
  578. tmp = data[i*w*c + j*c + p];
  579. fwrite(&tmp, sizeof(float), 1, bp);
  580. }
  581. }
  582. }
  583. }
  584. }
  585. }
  586. else
  587. {
  588. if (tensor.dtype() == 1)// float
  589. {
  590. float val = tensor.float_val(0);
  591. fwrite(&val, sizeof(float), 1, bp);
  592. }
  593. else if (tensor.dtype() == 3)// int32
  594. {
  595. float val = tensor.int_val(0);
  596. fwrite(&val, sizeof(float), 1, bp);
  597. }
  598. }
  599. fprintf(pp, " %d %d %d", w, h, c);
  600. }
  601. }
  602. else if (node.op() == "Conv2D")
  603. {
  604. // weights
  605. tensorflow::TensorProto tensor;
  606. find_tensor_proto(weights, node, tensor);
  607. const tensorflow::TensorShapeProto& shape = tensor.tensor_shape();
  608. int kernel_size_h = shape.dim(0).size();
  609. int kernel_size_w = shape.dim(1).size();
  610. int num_input = shape.dim(2).size();
  611. int num_output = shape.dim(3).size();
  612. int stride_h = 1;
  613. int stride_w = 1;
  614. int dilation_h = 1;
  615. int dilation_w = 1;
  616. int pad = 0;
  617. tensorflow::AttrValue value_strides;
  618. if (find_attr_value(node, "strides", value_strides))
  619. {
  620. // batch, height, width, channels
  621. stride_h = value_strides.list().i(1);
  622. stride_w = value_strides.list().i(2);
  623. }
  624. tensorflow::AttrValue value_padding;
  625. if (find_attr_value(node, "padding", value_padding))
  626. {
  627. if (value_padding.s() == "VALID")
  628. {
  629. pad = 0;
  630. }
  631. else if (value_padding.s() == "SAME")
  632. {
  633. pad = -233;
  634. }
  635. }
  636. tensorflow::AttrValue value_rate;
  637. if (find_attr_value(node, "rate", value_rate))
  638. {
  639. // height, width
  640. dilation_h = value_rate.list().i(0);
  641. dilation_w = value_rate.list().i(1);
  642. }
  643. int bias_term = 0;
  644. int weight_data_size = 0;
  645. // reorder h-w-i-o to o-i-h-w
  646. if (!tensor.tensor_content().empty())
  647. {
  648. int quantize_tag = 0;
  649. fwrite(&quantize_tag, sizeof(int), 1, bp);
  650. if (tensor.dtype() == 1)// float
  651. {
  652. const float* data = reinterpret_cast<const float*>(tensor.tensor_content().c_str());
  653. weight_data_size = tensor.tensor_content().size() / sizeof(float);
  654. float tmp;
  655. for (int p=0; p<num_output; p++)
  656. {
  657. for (int q=0; q<num_input; q++)
  658. {
  659. for (int i=0; i<kernel_size_h; i++)
  660. {
  661. for (int j=0; j<kernel_size_w; j++)
  662. {
  663. tmp = data[i*kernel_size_w*num_input*num_output + j*num_input*num_output + q*num_output + p];
  664. fwrite(&tmp, sizeof(float), 1, bp);
  665. }
  666. }
  667. }
  668. }
  669. }
  670. else if (tensor.dtype() == 3)// int32
  671. {
  672. const int* data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
  673. weight_data_size = tensor.tensor_content().size() / sizeof(int);
  674. float tmp;
  675. for (int p=0; p<num_output; p++)
  676. {
  677. for (int q=0; q<num_input; q++)
  678. {
  679. for (int i=0; i<kernel_size_h; i++)
  680. {
  681. for (int j=0; j<kernel_size_w; j++)
  682. {
  683. tmp = data[i*kernel_size_w*num_input*num_output + j*num_input*num_output + q*num_output + p];
  684. fwrite(&tmp, sizeof(float), 1, bp);
  685. }
  686. }
  687. }
  688. }
  689. }
  690. }
  691. fprintf(pp, " %d %d %d %d %d %d %d", num_output, kernel_size_w, dilation_w, stride_w, pad, bias_term, weight_data_size);
  692. }
  693. else if (node.op() == "DepthwiseConv2dNative")
  694. {
  695. // weights
  696. tensorflow::TensorProto tensor;
  697. find_tensor_proto(weights, node, tensor);
  698. const tensorflow::TensorShapeProto& shape = tensor.tensor_shape();
  699. int kernel_size_h = shape.dim(0).size();
  700. int kernel_size_w = shape.dim(1).size();
  701. int num_input = shape.dim(2).size();
  702. int channel_multiplier = shape.dim(3).size();
  703. int num_output = num_input * channel_multiplier;
  704. int group = num_input;
  705. int stride_h = 1;
  706. int stride_w = 1;
  707. int dilation_h = 1;
  708. int dilation_w = 1;
  709. int pad = 0;
  710. tensorflow::AttrValue value_strides;
  711. if (find_attr_value(node, "strides", value_strides))
  712. {
  713. // batch, height, width, channels
  714. stride_h = value_strides.list().i(1);
  715. stride_w = value_strides.list().i(2);
  716. }
  717. tensorflow::AttrValue value_padding;
  718. if (find_attr_value(node, "padding", value_padding))
  719. {
  720. if (value_padding.s() == "VALID")
  721. {
  722. pad = 0;
  723. }
  724. else if (value_padding.s() == "SAME")
  725. {
  726. pad = -233;
  727. }
  728. }
  729. tensorflow::AttrValue value_rate;
  730. if (find_attr_value(node, "rate", value_rate))
  731. {
  732. // height, width
  733. dilation_h = value_rate.list().i(0);
  734. dilation_w = value_rate.list().i(1);
  735. }
  736. int bias_term = 0;
  737. int weight_data_size = 0;
  738. // reorder h-w-i-cm to i-cm-h-w
  739. if (!tensor.tensor_content().empty())
  740. {
  741. int quantize_tag = 0;
  742. fwrite(&quantize_tag, sizeof(int), 1, bp);
  743. if (tensor.dtype() == 1)// float
  744. {
  745. const float* data = reinterpret_cast<const float*>(tensor.tensor_content().c_str());
  746. weight_data_size = tensor.tensor_content().size() / sizeof(float);
  747. float tmp;
  748. for (int p=0; p<num_input; p++)
  749. {
  750. for (int q=0; q<channel_multiplier; q++)
  751. {
  752. for (int i=0; i<kernel_size_h; i++)
  753. {
  754. for (int j=0; j<kernel_size_w; j++)
  755. {
  756. tmp = data[i*kernel_size_w*channel_multiplier*num_input + j*channel_multiplier*num_input + p*channel_multiplier + q];
  757. fwrite(&tmp, sizeof(float), 1, bp);
  758. }
  759. }
  760. }
  761. }
  762. }
  763. else if (tensor.dtype() == 3)// int32
  764. {
  765. const int* data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
  766. weight_data_size = tensor.tensor_content().size() / sizeof(int);
  767. float tmp;
  768. for (int p=0; p<num_input; p++)
  769. {
  770. for (int q=0; q<channel_multiplier; q++)
  771. {
  772. for (int i=0; i<kernel_size_h; i++)
  773. {
  774. for (int j=0; j<kernel_size_w; j++)
  775. {
  776. tmp = data[i*kernel_size_w*channel_multiplier*num_input + j*channel_multiplier*num_input + p*channel_multiplier + q];
  777. fwrite(&tmp, sizeof(float), 1, bp);
  778. }
  779. }
  780. }
  781. }
  782. }
  783. }
  784. fprintf(pp, " %d %d %d %d %d %d %d %d", num_output, kernel_size_w, dilation_w, stride_w, pad, bias_term, weight_data_size, group);
  785. }
  786. else if (node.op() == "Div" || node.op() == "RealDiv")
  787. {
  788. int op_type = 3;
  789. fprintf(pp, " %d", op_type);
  790. }
  791. else if (node.op() == "Exp")
  792. {
  793. int op_type = 7;
  794. fprintf(pp, " %d", op_type);
  795. }
  796. else if (node.op() == "ExpandDims")
  797. {
  798. int expand_w = 0;
  799. int expand_h = 0;
  800. int expand_c = 0;
  801. tensorflow::AttrValue value_dim;
  802. if (find_attr_value(node, "Tdim", value_dim))
  803. {
  804. int dim = value_dim.i();
  805. if (dim == 0)
  806. expand_w = 1;
  807. if (dim == 1)
  808. expand_h = 1;
  809. if (dim == 2)
  810. expand_c = 1;
  811. }
  812. fprintf(pp, " %d %d %d", expand_w, expand_h, expand_c);
  813. }
  814. else if (node.op() == "Floor")
  815. {
  816. int op_type = 2;
  817. fprintf(pp, " %d", op_type);
  818. }
  819. else if (node.op() == "LRN")
  820. {
  821. int norm_region = 0;
  822. int local_size = 1;
  823. float alpha = 1.f;
  824. float beta = 0.5f;
  825. tensorflow::AttrValue value_depth_radius;
  826. if (find_attr_value(node, "depth_radius", value_depth_radius))
  827. {
  828. local_size = value_depth_radius.i() * 2 + 1;
  829. }
  830. tensorflow::AttrValue value_alpha;
  831. if (find_attr_value(node, "alpha", value_alpha))
  832. {
  833. alpha = value_alpha.f();
  834. }
  835. tensorflow::AttrValue value_beta;
  836. if (find_attr_value(node, "beta", value_beta))
  837. {
  838. beta = value_beta.f();
  839. }
  840. // TODO
  841. float bias = 1.f;
  842. tensorflow::AttrValue value_bias;
  843. if (find_attr_value(node, "bias", value_bias))
  844. {
  845. bias = value_bias.f();
  846. }
  847. fprintf(pp, " %d %d %f %f", norm_region, local_size, alpha, beta);
  848. }
  849. else if (node.op() == "MatMul")
  850. {
  851. // weights
  852. tensorflow::TensorProto tensor;
  853. find_tensor_proto(weights, node, tensor);
  854. const tensorflow::TensorShapeProto& shape = tensor.tensor_shape();
  855. int num_input = shape.dim(0).size();
  856. int num_output = shape.dim(1).size();
  857. int bias_term = 0;
  858. int weight_data_size = 0;
  859. // reorder i-o to o-i
  860. if (!tensor.tensor_content().empty())
  861. {
  862. int quantize_tag = 0;
  863. fwrite(&quantize_tag, sizeof(int), 1, bp);
  864. if (tensor.dtype() == 1)// float
  865. {
  866. const float* data = reinterpret_cast<const float*>(tensor.tensor_content().c_str());
  867. weight_data_size = tensor.tensor_content().size() / sizeof(float);
  868. float tmp;
  869. for (int p=0; p<num_output; p++)
  870. {
  871. for (int q=0; q<num_input; q++)
  872. {
  873. tmp = data[q*num_output + p];
  874. fwrite(&tmp, sizeof(float), 1, bp);
  875. }
  876. }
  877. }
  878. else if (tensor.dtype() == 3)// int32
  879. {
  880. const int* data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
  881. weight_data_size = tensor.tensor_content().size() / sizeof(int);
  882. float tmp;
  883. for (int p=0; p<num_output; p++)
  884. {
  885. for (int q=0; q<num_input; q++)
  886. {
  887. tmp = data[q*num_output + p];
  888. fwrite(&tmp, sizeof(float), 1, bp);
  889. }
  890. }
  891. }
  892. }
  893. fprintf(pp, " %d %d %d", num_output, bias_term, weight_data_size);
  894. }
  895. else if (node.op() == "Max" || node.op() == "Maximum")
  896. {
  897. // check weights
  898. tensorflow::TensorProto tensor;
  899. if (find_tensor_proto(weights, node, tensor))
  900. {
  901. int operation = 4;
  902. int dim = 0;
  903. float coeff = 1.f;
  904. dim = parse_tensor_reduction_dim(tensor);
  905. fprintf(pp, " %d %d %f", operation, dim, coeff);
  906. }
  907. else
  908. {
  909. int op_type = 4;
  910. fprintf(pp, " %d", op_type);
  911. }
  912. }
  913. else if (node.op() == "MaxPool")
  914. {
  915. int pooling_type = 0;
  916. int kernel_size_h = 1;
  917. int kernel_size_w = 1;
  918. int stride_h = 1;
  919. int stride_w = 1;
  920. int pad = 0;
  921. int global_pooling = 0;
  922. tensorflow::AttrValue value_ksize;
  923. if (find_attr_value(node, "ksize", value_ksize))
  924. {
  925. // batch, height, width, channels
  926. kernel_size_h = value_ksize.list().i(1);
  927. kernel_size_w = value_ksize.list().i(2);
  928. }
  929. tensorflow::AttrValue value_strides;
  930. if (find_attr_value(node, "strides", value_strides))
  931. {
  932. // batch, height, width, channels
  933. stride_h = value_strides.list().i(1);
  934. stride_w = value_strides.list().i(2);
  935. }
  936. tensorflow::AttrValue value_padding;
  937. if (find_attr_value(node, "padding", value_padding))
  938. {
  939. if (value_padding.s() == "VALID")
  940. {
  941. pad = -2333;
  942. }
  943. else if (value_padding.s() == "SAME")
  944. {
  945. pad = -233;
  946. }
  947. }
  948. fprintf(pp, " %d %d %d %d %d", pooling_type, kernel_size_w, stride_w, pad, global_pooling);
  949. }
  950. else if (node.op() == "Min" || node.op() == "Minimum")
  951. {
  952. // check weights
  953. tensorflow::TensorProto tensor;
  954. if (find_tensor_proto(weights, node, tensor))
  955. {
  956. int operation = 5;
  957. int dim = 0;
  958. float coeff = 1.f;
  959. dim = parse_tensor_reduction_dim(tensor);
  960. fprintf(pp, " %d %d %f", operation, dim, coeff);
  961. }
  962. else
  963. {
  964. int op_type = 5;
  965. fprintf(pp, " %d", op_type);
  966. }
  967. }
  968. else if (node.op() == "Mul")
  969. {
  970. int op_type = 2;
  971. fprintf(pp, " %d", op_type);
  972. }
  973. else if (node.op() == "Neg")
  974. {
  975. int op_type = 1;
  976. fprintf(pp, " %d", op_type);
  977. }
  978. else if (node.op() == "NoOp")
  979. {
  980. }
  981. else if (node.op() == "Pad")
  982. {
  983. int top = 0;
  984. int bottom = 0;
  985. int left = 0;
  986. int right = 0;
  987. int type = 0;
  988. float value = 0.f;
  989. // check weights
  990. tensorflow::TensorProto tensor;
  991. if (find_tensor_proto(weights, node, tensor))
  992. {
  993. if (!tensor.tensor_content().empty() && tensor.dtype() == 3)// int32
  994. {
  995. const int *data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
  996. int size = tensor.tensor_content().size() / sizeof(int);
  997. if (size == 8)
  998. {
  999. // n h w c
  1000. top = data[2];
  1001. bottom = data[3];
  1002. left = data[4];
  1003. right = data[5];
  1004. }
  1005. }
  1006. }
  1007. tensorflow::AttrValue value_Tpaddings;
  1008. if (find_attr_value(node, "Tpaddings", value_Tpaddings))
  1009. {
  1010. type = value_Tpaddings.i();
  1011. }
  1012. tensorflow::AttrValue value_T;
  1013. if (find_attr_value(node, "T", value_T))
  1014. {
  1015. value = value_T.f();
  1016. }
  1017. fprintf(pp, " %d %d %d %d %d %f", top, bottom, left, right, type, value);
  1018. }
  1019. else if (node.op() == "Placeholder")
  1020. {
  1021. // TODO pass through
  1022. fprintf(pp, " 0 0 0");
  1023. }
  1024. else if (node.op() == "Prod")
  1025. {
  1026. int operation = 6;
  1027. int dim = 0;
  1028. float coeff = 1.f;
  1029. // check weights
  1030. tensorflow::TensorProto tensor;
  1031. if (find_tensor_proto(weights, node, tensor))
  1032. {
  1033. dim = parse_tensor_reduction_dim(tensor);
  1034. }
  1035. fprintf(pp, " %d %d %f", operation, dim, coeff);
  1036. }
  1037. else if (node.op() == "Reciprocal")
  1038. {
  1039. int op_type = 15;
  1040. fprintf(pp, " %d", op_type);
  1041. }
  1042. else if (node.op() == "Relu")
  1043. {
  1044. float slope = 0.f;
  1045. fprintf(pp, " %f", slope);
  1046. }
  1047. else if (node.op() == "Reshape")
  1048. {
  1049. tensorflow::TensorProto tensor;
  1050. if (find_tensor_proto(weights, node, tensor))
  1051. {
  1052. if (!tensor.tensor_content().empty() && tensor.dtype() == 3)// int32
  1053. {
  1054. const int* data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
  1055. int size = tensor.tensor_content().size() / sizeof(int);
  1056. // n h w c
  1057. // n h w
  1058. // n w
  1059. if (size == 4)
  1060. {
  1061. fprintf(pp, " %d %d %d 0", data[2], data[1], data[3]);
  1062. }
  1063. if (size == 3)
  1064. {
  1065. fprintf(pp, " %d %d -233 1", data[2], data[1]);
  1066. }
  1067. if (size == 2)
  1068. {
  1069. fprintf(pp, " %d -233 -233 1", data[1]);
  1070. }
  1071. }
  1072. }
  1073. else
  1074. {
  1075. // pass through
  1076. fprintf(pp, " 0 0 0");
  1077. }
  1078. }
  1079. else if (node.op() == "Rsqrt")
  1080. {
  1081. int op_type = 6;
  1082. fprintf(pp, " %d", op_type);
  1083. }
  1084. else if (node.op() == "Sigmoid")
  1085. {
  1086. }
  1087. else if (node.op() == "Softmax")
  1088. {
  1089. }
  1090. else if (node.op() == "Square")
  1091. {
  1092. int op_type = 4;
  1093. fprintf(pp, " %d", op_type);
  1094. }
  1095. else if (node.op() == "Squeeze")
  1096. {
  1097. int squeeze_w = 0;
  1098. int squeeze_h = 0;
  1099. int squeeze_c = 0;
  1100. tensorflow::AttrValue value_squeeze_dims;
  1101. if (find_attr_value(node, "squeeze_dims", value_squeeze_dims))
  1102. {
  1103. for (int i = 0; i<value_squeeze_dims.list().i_size(); i++)
  1104. {
  1105. int dim = value_squeeze_dims.list().i(i);
  1106. if (dim == 0)
  1107. squeeze_w = 1;
  1108. if (dim == 1)
  1109. squeeze_h = 1;
  1110. if (dim == 2)
  1111. squeeze_c = 1;
  1112. }
  1113. }
  1114. fprintf(pp, " %d %d %d", squeeze_w, squeeze_h, squeeze_c);
  1115. }
  1116. else if (node.op() == "Sub")
  1117. {
  1118. int op_type = 1;
  1119. fprintf(pp, " %d", op_type);
  1120. }
  1121. else if (node.op() == "Sum")
  1122. {
  1123. int operation = 0;
  1124. int dim = 0;
  1125. float coeff = 1.f;
  1126. // check weights
  1127. tensorflow::TensorProto tensor;
  1128. if (find_tensor_proto(weights, node, tensor))
  1129. {
  1130. dim = parse_tensor_reduction_dim(tensor);
  1131. }
  1132. fprintf(pp, " %d %d %f", operation, dim, coeff);
  1133. }
  1134. else
  1135. {
  1136. const google::protobuf::Map<std::string, tensorflow::AttrValue>& attr = node.attr();
  1137. google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.begin();
  1138. for (; it != attr.end(); it++)
  1139. {
  1140. std::cerr << it->first << " #" << it->second.type() << std::endl;
  1141. }
  1142. }
  1143. fprintf(pp, "\n");
  1144. std::string output_name = node.name();
  1145. if (node_reference.find(output_name) != node_reference.end())
  1146. {
  1147. int refcount = node_reference[output_name];
  1148. if (refcount > 1)
  1149. {
  1150. char splitname[256];
  1151. sprintf(splitname, "splitncnn_%d", internal_split);
  1152. fprintf(pp, "%-16s %-32s %d %d", "Split", splitname, 1, refcount);
  1153. fprintf(pp, " %s", output_name.c_str());
  1154. for (int j=0; j<refcount; j++)
  1155. {
  1156. fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), j);
  1157. }
  1158. fprintf(pp, "\n");
  1159. internal_split++;
  1160. }
  1161. }
  1162. }
  1163. fclose(pp);
  1164. fclose(bp);
  1165. return 0;
  1166. }