You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow2ncnn.cpp 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include <stdio.h>
  15. #include <limits.h>
  16. #include <iostream>
  17. #include <fstream>
  18. #include <set>
  19. #include <limits>
  20. #include <algorithm>
  21. #include <google/protobuf/io/coded_stream.h>
  22. #include <google/protobuf/io/zero_copy_stream_impl.h>
  23. #include <google/protobuf/text_format.h>
  24. #include <google/protobuf/message.h>
  25. #include "graph.pb.h"
  26. static bool read_proto_from_binary(const char* filepath, google::protobuf::Message* message)
  27. {
  28. std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
  29. if (!fs.is_open())
  30. {
  31. fprintf(stderr, "open failed %s\n", filepath);
  32. return false;
  33. }
  34. google::protobuf::io::IstreamInputStream input(&fs);
  35. google::protobuf::io::CodedInputStream codedstr(&input);
  36. codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
  37. bool success = message->ParseFromCodedStream(&codedstr);
  38. fs.close();
  39. return success;
  40. }
  41. static bool find_tensor_proto(const std::map<std::string, tensorflow::TensorProto>& weights,
  42. const tensorflow::NodeDef& node, tensorflow::TensorProto& tensor)
  43. {
  44. for (int j=0; j<node.input_size(); j++)
  45. {
  46. const std::string& input_name = node.input(j);
  47. const std::map<std::string, tensorflow::TensorProto>::const_iterator it = weights.find(input_name);
  48. if (it != weights.end())
  49. {
  50. tensor = it->second;
  51. return true;
  52. }
  53. }
  54. return false;
  55. }
  56. static bool find_attr_value(const tensorflow::NodeDef& node, const char* key, tensorflow::AttrValue& value)
  57. {
  58. const google::protobuf::Map<std::string, tensorflow::AttrValue>& attr = node.attr();
  59. const google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.find(key);
  60. if (it != attr.end())
  61. {
  62. value = it->second;
  63. return true;
  64. }
  65. return false;
  66. }
  67. int main(int argc, char** argv)
  68. {
  69. const char* tensorflowpb = argv[1];
  70. const char* ncnn_prototxt = argc >= 4 ? argv[2] : "ncnn.proto";
  71. const char* ncnn_modelbin = argc >= 4 ? argv[3] : "ncnn.bin";
  72. tensorflow::GraphDef graph;
  73. // load
  74. bool s1 = read_proto_from_binary(tensorflowpb, &graph);
  75. if (!s1)
  76. {
  77. fprintf(stderr, "read_proto_from_binary failed\n");
  78. return -1;
  79. }
  80. FILE* pp = fopen(ncnn_prototxt, "wb");
  81. FILE* bp = fopen(ncnn_modelbin, "wb");
  82. int node_count = graph.node_size();
  83. // fprintf(stderr, "node_count = %d\n\n", node_count);
  84. // node reference
  85. std::map<std::string, int> node_reference;
  86. // mapping for Const and Const-Identity
  87. std::map<std::string, tensorflow::TensorProto> weights;
  88. // global definition line
  89. // [layer count] [blob count]
  90. std::set<std::string> blob_names;
  91. for (int i=0; i<node_count; i++)
  92. {
  93. const tensorflow::NodeDef& node = graph.node(i);
  94. const std::string& output_name = node.name();
  95. if (node.op() == "Const")
  96. {
  97. tensorflow::AttrValue value;
  98. if (find_attr_value(node, "value", value))
  99. {
  100. const tensorflow::TensorProto& tensor = value.tensor();
  101. weights[output_name] = tensor;
  102. }
  103. continue;
  104. }
  105. else if (node.op() == "Identity")
  106. {
  107. const std::string& input_name = node.input(0);
  108. weights[output_name] = weights[input_name];
  109. continue;
  110. }
  111. else if (node.op() == "NoOp")
  112. {
  113. weights[output_name] = tensorflow::TensorProto();
  114. continue;
  115. }
  116. // input
  117. for (int j=0; j<node.input_size(); j++)
  118. {
  119. const std::string& input_name = node.input(j);
  120. // fprintf(stderr, "%s\n", input_name.c_str());
  121. if (weights.find(input_name) != weights.end())
  122. {
  123. continue;
  124. }
  125. blob_names.insert(input_name);
  126. if (node_reference.find(input_name) == node_reference.end())
  127. {
  128. node_reference[input_name] = 1;
  129. }
  130. else
  131. {
  132. node_reference[input_name] = node_reference[input_name] + 1;
  133. }
  134. }
  135. // output
  136. // fprintf(stderr, "%s\n", output_name.c_str());
  137. blob_names.insert(output_name);
  138. }
  139. // remove node_reference entry with reference equals to one
  140. int splitncnn_blob_count = 0;
  141. std::map<std::string, int>::iterator it = node_reference.begin();
  142. while (it != node_reference.end())
  143. {
  144. if (it->second == 1)
  145. {
  146. node_reference.erase(it++);
  147. }
  148. else
  149. {
  150. splitncnn_blob_count += it->second;
  151. // fprintf(stderr, "%s %d\n", it->first.c_str(), it->second);
  152. ++it;
  153. }
  154. }
  155. fprintf(pp, "%lu %lu\n", node_count + node_reference.size() - weights.size(), blob_names.size() + splitncnn_blob_count);
  156. int internal_split = 0;
  157. for (int i=0; i<node_count; i++)
  158. {
  159. const tensorflow::NodeDef& node = graph.node(i);
  160. // layer definition line, repeated
  161. // [type] [name] [bottom blob count] [top blob count] [bottom blobs] [top blobs] [layer specific params]
  162. // fprintf(pp, "%-16s %-16s %d %d", layer.type().c_str(), layer.name().c_str(), node.input_size(), layer.top_size());
  163. if (node.op() == "Add" || node.op() == "BiasAdd")
  164. {
  165. // check weights
  166. tensorflow::TensorProto tensor;
  167. if (find_tensor_proto(weights, node, tensor))
  168. {
  169. fprintf(pp, "%-16s", "Bias");
  170. }
  171. else
  172. {
  173. fprintf(pp, "%-16s", "Eltwise");
  174. }
  175. }
  176. else if (node.op() == "AvgPool")
  177. {
  178. fprintf(pp, "%-16s", "Pooling");
  179. }
  180. else if (node.op() == "Const")
  181. {
  182. continue;
  183. }
  184. else if (node.op() == "Conv2D")
  185. {
  186. fprintf(pp, "%-16s", "Convolution");
  187. }
  188. else if (node.op() == "Identity")
  189. {
  190. continue;
  191. }
  192. else if (node.op() == "MatMul")
  193. {
  194. fprintf(pp, "%-16s", "InnerProduct");
  195. }
  196. else if (node.op() == "Max")
  197. {
  198. fprintf(pp, "%-16s", "Eltwise");
  199. }
  200. else if (node.op() == "MaxPool")
  201. {
  202. fprintf(pp, "%-16s", "Pooling");
  203. }
  204. else if (node.op() == "Mul")
  205. {
  206. fprintf(pp, "%-16s", "Eltwise");
  207. }
  208. else if (node.op() == "NoOp")
  209. {
  210. continue;
  211. }
  212. else if (node.op() == "Placeholder")
  213. {
  214. fprintf(pp, "%-16s", "Input");
  215. }
  216. else if (node.op() == "Relu")
  217. {
  218. fprintf(pp, "%-16s", "ReLU");
  219. }
  220. else if (node.op() == "Reshape")
  221. {
  222. fprintf(pp, "%-16s", "Reshape");
  223. }
  224. else if (node.op() == "Softmax")
  225. {
  226. fprintf(pp, "%-16s", "Softmax");
  227. }
  228. else
  229. {
  230. fprintf(pp, "%-16s", node.op().c_str());
  231. }
  232. int input_size = node.input_size();
  233. for (int j=0; j<node.input_size(); j++)
  234. {
  235. const std::string& input_name = node.input(j);
  236. if (weights.find(input_name) != weights.end())
  237. {
  238. input_size--;
  239. }
  240. }
  241. fprintf(pp, " %-16s %d 1", node.name().c_str(), input_size);
  242. for (int j=0; j<node.input_size(); j++)
  243. {
  244. std::string input_name = node.input(j);
  245. if (weights.find(input_name) != weights.end())
  246. {
  247. continue;
  248. }
  249. if (node_reference.find(input_name) != node_reference.end())
  250. {
  251. int refidx = node_reference[input_name] - 1;
  252. node_reference[input_name] = refidx;
  253. char splitsuffix[256];
  254. sprintf(splitsuffix, "_splitncnn_%d", refidx);
  255. input_name = input_name + splitsuffix;
  256. }
  257. fprintf(pp, " %s", input_name.c_str());
  258. }
  259. fprintf(pp, " %s", node.name().c_str());
  260. if (node.op() == "Add" || node.op() == "BiasAdd")
  261. {
  262. // check weights
  263. tensorflow::TensorProto tensor;
  264. if (find_tensor_proto(weights, node, tensor))
  265. {
  266. int weight_data_size = 0;
  267. if (!tensor.tensor_content().empty())
  268. {
  269. if (tensor.dtype() == 1)// float
  270. {
  271. const float* data = reinterpret_cast<const float*>(tensor.tensor_content().c_str());
  272. weight_data_size = tensor.tensor_content().size() / sizeof(float);
  273. fwrite(data, sizeof(float), weight_data_size, bp);
  274. }
  275. else if (tensor.dtype() == 3)// int32
  276. {
  277. const int* data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
  278. weight_data_size = tensor.tensor_content().size() / sizeof(int);
  279. float tmp;
  280. for (int i=0; i<weight_data_size; i++)
  281. {
  282. tmp = data[i];
  283. fwrite(&tmp, sizeof(float), 1, bp);
  284. }
  285. }
  286. }
  287. fprintf(pp, " %d", weight_data_size);
  288. }
  289. else
  290. {
  291. int op_type = 1;
  292. int num_coeff = 0;
  293. fprintf(pp, " %d %d", op_type, num_coeff);
  294. }
  295. }
  296. else if (node.op() == "AvgPool")
  297. {
  298. int pooling_type = 1;
  299. int kernel_size_h = 1;
  300. int kernel_size_w = 1;
  301. int stride_h = 1;
  302. int stride_w = 1;
  303. int pad = 0;
  304. int global_pooling = 0;
  305. tensorflow::AttrValue value_ksize;
  306. if (find_attr_value(node, "ksize", value_ksize))
  307. {
  308. // batch, height, width, channels
  309. kernel_size_h = value_ksize.list().i(1);
  310. kernel_size_w = value_ksize.list().i(2);
  311. }
  312. tensorflow::AttrValue value_strides;
  313. if (find_attr_value(node, "strides", value_strides))
  314. {
  315. // batch, height, width, channels
  316. stride_h = value_strides.list().i(1);
  317. stride_w = value_strides.list().i(2);
  318. }
  319. tensorflow::AttrValue value_padding;
  320. if (find_attr_value(node, "padding", value_padding))
  321. {
  322. if (value_padding.s() == "VALID")
  323. {
  324. pad = 0;
  325. }
  326. else if (value_padding.s() == "SAME")
  327. {
  328. pad = -233;
  329. }
  330. }
  331. fprintf(pp, " %d %d %d %d %d", pooling_type, kernel_size_w, stride_w, pad, global_pooling);
  332. }
  333. else if (node.op() == "Const")
  334. {
  335. }
  336. else if (node.op() == "Conv2D")
  337. {
  338. // weights
  339. tensorflow::TensorProto tensor;
  340. find_tensor_proto(weights, node, tensor);
  341. const tensorflow::TensorShapeProto& shape = tensor.tensor_shape();
  342. int kernel_size_h = shape.dim(0).size();
  343. int kernel_size_w = shape.dim(1).size();
  344. int num_input = shape.dim(2).size();
  345. int num_output = shape.dim(3).size();
  346. int stride_h = 1;
  347. int stride_w = 1;
  348. int dilation = 1;
  349. int pad = 0;
  350. tensorflow::AttrValue value_strides;
  351. if (find_attr_value(node, "strides", value_strides))
  352. {
  353. // batch, height, width, channels
  354. stride_h = value_strides.list().i(1);
  355. stride_w = value_strides.list().i(2);
  356. }
  357. tensorflow::AttrValue value_padding;
  358. if (find_attr_value(node, "padding", value_padding))
  359. {
  360. if (value_padding.s() == "VALID")
  361. {
  362. pad = 0;
  363. }
  364. else if (value_padding.s() == "SAME")
  365. {
  366. pad = -233;
  367. }
  368. }
  369. int bias_term = 0;
  370. int weight_data_size = 0;
  371. // reorder h-w-i-o to o-i-h-w
  372. if (!tensor.tensor_content().empty())
  373. {
  374. int quantize_tag = 0;
  375. fwrite(&quantize_tag, sizeof(int), 1, bp);
  376. if (tensor.dtype() == 1)// float
  377. {
  378. const float* data = reinterpret_cast<const float*>(tensor.tensor_content().c_str());
  379. weight_data_size = tensor.tensor_content().size() / sizeof(float);
  380. float tmp;
  381. for (int p=0; p<num_output; p++)
  382. {
  383. for (int q=0; q<num_input; q++)
  384. {
  385. for (int i=0; i<kernel_size_h; i++)
  386. {
  387. for (int j=0; j<kernel_size_w; j++)
  388. {
  389. tmp = data[i*kernel_size_w*num_input*num_output + j*num_input*num_output + q*num_output + p];
  390. fwrite(&tmp, sizeof(float), 1, bp);
  391. }
  392. }
  393. }
  394. }
  395. }
  396. else if (tensor.dtype() == 3)// int32
  397. {
  398. const int* data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
  399. weight_data_size = tensor.tensor_content().size() / sizeof(int);
  400. float tmp;
  401. for (int p=0; p<num_output; p++)
  402. {
  403. for (int q=0; q<num_input; q++)
  404. {
  405. for (int i=0; i<kernel_size_h; i++)
  406. {
  407. for (int j=0; j<kernel_size_w; j++)
  408. {
  409. tmp = data[i*kernel_size_w*num_input*num_output + j*num_input*num_output + q*num_output + p];
  410. fwrite(&tmp, sizeof(float), 1, bp);
  411. }
  412. }
  413. }
  414. }
  415. }
  416. }
  417. fprintf(pp, " %d %d %d %d %d %d %d", num_output, kernel_size_w, dilation, stride_w, pad, bias_term, weight_data_size);
  418. }
  419. else if (node.op() == "Identity")
  420. {
  421. }
  422. else if (node.op() == "MatMul")
  423. {
  424. // weights
  425. tensorflow::TensorProto tensor;
  426. find_tensor_proto(weights, node, tensor);
  427. const tensorflow::TensorShapeProto& shape = tensor.tensor_shape();
  428. int num_input = shape.dim(0).size();
  429. int num_output = shape.dim(1).size();
  430. int bias_term = 0;
  431. int weight_data_size = 0;
  432. // reorder i-o to o-i
  433. if (!tensor.tensor_content().empty())
  434. {
  435. int quantize_tag = 0;
  436. fwrite(&quantize_tag, sizeof(int), 1, bp);
  437. if (tensor.dtype() == 1)// float
  438. {
  439. const float* data = reinterpret_cast<const float*>(tensor.tensor_content().c_str());
  440. weight_data_size = tensor.tensor_content().size() / sizeof(float);
  441. float tmp;
  442. for (int p=0; p<num_output; p++)
  443. {
  444. for (int q=0; q<num_input; q++)
  445. {
  446. tmp = data[q*num_output + p];
  447. fwrite(&tmp, sizeof(float), 1, bp);
  448. }
  449. }
  450. }
  451. else if (tensor.dtype() == 3)// int32
  452. {
  453. const int* data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
  454. weight_data_size = tensor.tensor_content().size() / sizeof(int);
  455. float tmp;
  456. for (int p=0; p<num_output; p++)
  457. {
  458. for (int q=0; q<num_input; q++)
  459. {
  460. tmp = data[q*num_output + p];
  461. fwrite(&tmp, sizeof(float), 1, bp);
  462. }
  463. }
  464. }
  465. }
  466. fprintf(pp, " %d %d %d", num_output, bias_term, weight_data_size);
  467. }
  468. else if (node.op() == "Max")
  469. {
  470. int op_type = 2;
  471. int num_coeff = 0;
  472. fprintf(pp, " %d %d", op_type, num_coeff);
  473. }
  474. else if (node.op() == "MaxPool")
  475. {
  476. int pooling_type = 0;
  477. int kernel_size_h = 1;
  478. int kernel_size_w = 1;
  479. int stride_h = 1;
  480. int stride_w = 1;
  481. int pad = 0;
  482. int global_pooling = 0;
  483. tensorflow::AttrValue value_ksize;
  484. if (find_attr_value(node, "ksize", value_ksize))
  485. {
  486. // batch, height, width, channels
  487. kernel_size_h = value_ksize.list().i(1);
  488. kernel_size_w = value_ksize.list().i(2);
  489. }
  490. tensorflow::AttrValue value_strides;
  491. if (find_attr_value(node, "strides", value_strides))
  492. {
  493. // batch, height, width, channels
  494. stride_h = value_strides.list().i(1);
  495. stride_w = value_strides.list().i(2);
  496. }
  497. tensorflow::AttrValue value_padding;
  498. if (find_attr_value(node, "padding", value_padding))
  499. {
  500. if (value_padding.s() == "VALID")
  501. {
  502. pad = 0;
  503. }
  504. else if (value_padding.s() == "SAME")
  505. {
  506. pad = -233;
  507. }
  508. }
  509. fprintf(pp, " %d %d %d %d %d", pooling_type, kernel_size_w, stride_w, pad, global_pooling);
  510. }
  511. else if (node.op() == "Mul")
  512. {
  513. int op_type = 0;
  514. int num_coeff = 0;
  515. fprintf(pp, " %d %d", op_type, num_coeff);
  516. }
  517. else if (node.op() == "NoOp")
  518. {
  519. }
  520. else if (node.op() == "Placeholder")
  521. {
  522. // TODO pass through
  523. fprintf(pp, " 0 0 0");
  524. }
  525. else if (node.op() == "Relu")
  526. {
  527. float slope = 0.f;
  528. fprintf(pp, " %f", slope);
  529. }
  530. else if (node.op() == "Reshape")
  531. {
  532. // TODO pass through
  533. fprintf(pp, " 0 0 0");
  534. }
  535. else if (node.op() == "Softmax")
  536. {
  537. }
  538. else
  539. {
  540. const google::protobuf::Map<std::string, tensorflow::AttrValue>& attr = node.attr();
  541. google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.begin();
  542. for (; it != attr.end(); it++)
  543. {
  544. std::cerr << it->first << std::endl;
  545. std::cerr << it->second.type() << std::endl;
  546. }
  547. }
  548. fprintf(pp, "\n");
  549. std::string output_name = node.name();
  550. if (node_reference.find(output_name) != node_reference.end())
  551. {
  552. int refcount = node_reference[output_name];
  553. if (refcount > 1)
  554. {
  555. char splitname[256];
  556. sprintf(splitname, "splitncnn_%d", internal_split);
  557. fprintf(pp, "%-16s %-16s %d %d", "Split", splitname, 1, refcount);
  558. fprintf(pp, " %s", output_name.c_str());
  559. for (int j=0; j<refcount; j++)
  560. {
  561. fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), j);
  562. }
  563. fprintf(pp, "\n");
  564. internal_split++;
  565. }
  566. }
  567. }
  568. fclose(pp);
  569. fclose(bp);
  570. return 0;
  571. }