You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow2ncnn.cpp 22 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#include <algorithm>
#include <fstream>
#include <iostream>
#include <limits>
#include <map>
#include <set>
#include <string>
#include <vector>

#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>

#include "graph.pb.h"
  26. static bool read_proto_from_binary(const char* filepath, google::protobuf::Message* message)
  27. {
  28. std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
  29. if (!fs.is_open())
  30. {
  31. fprintf(stderr, "open failed %s\n", filepath);
  32. return false;
  33. }
  34. google::protobuf::io::IstreamInputStream input(&fs);
  35. google::protobuf::io::CodedInputStream codedstr(&input);
  36. codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
  37. bool success = message->ParseFromCodedStream(&codedstr);
  38. fs.close();
  39. return success;
  40. }
  41. static bool find_tensor_proto(const std::map<std::string, tensorflow::TensorProto>& weights,
  42. const tensorflow::NodeDef& node, tensorflow::TensorProto& tensor)
  43. {
  44. for (int j=0; j<node.input_size(); j++)
  45. {
  46. const std::string& input_name = node.input(j);
  47. const std::map<std::string, tensorflow::TensorProto>::const_iterator it = weights.find(input_name);
  48. if (it != weights.end())
  49. {
  50. tensor = it->second;
  51. return true;
  52. }
  53. }
  54. return false;
  55. }
  56. static bool find_attr_value(const tensorflow::NodeDef& node, const char* key, tensorflow::AttrValue& value)
  57. {
  58. const google::protobuf::Map<std::string, tensorflow::AttrValue>& attr = node.attr();
  59. const google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.find(key);
  60. if (it != attr.end())
  61. {
  62. value = it->second;
  63. return true;
  64. }
  65. return false;
  66. }
  67. int main(int argc, char** argv)
  68. {
  69. const char* tensorflowpb = argv[1];
  70. const char* ncnn_prototxt = argc >= 4 ? argv[2] : "ncnn.proto";
  71. const char* ncnn_modelbin = argc >= 4 ? argv[3] : "ncnn.bin";
  72. tensorflow::GraphDef graph;
  73. // load
  74. bool s1 = read_proto_from_binary(tensorflowpb, &graph);
  75. if (!s1)
  76. {
  77. fprintf(stderr, "read_proto_from_binary failed\n");
  78. return -1;
  79. }
  80. FILE* pp = fopen(ncnn_prototxt, "wb");
  81. FILE* bp = fopen(ncnn_modelbin, "wb");
  82. int node_count = graph.node_size();
  83. // fprintf(stderr, "node_count = %d\n\n", node_count);
  84. // node reference
  85. std::map<std::string, int> node_reference;
  86. // mapping for Const and Const-Identity
  87. std::map<std::string, tensorflow::TensorProto> weights;
  88. // Dropout like Identity
  89. std::set<std::string> dropouts;
  90. // global definition line
  91. // [layer count] [blob count]
  92. std::set<std::string> blob_names;
  93. for (int i=0; i<node_count; i++)
  94. {
  95. const tensorflow::NodeDef& node = graph.node(i);
  96. const std::string& output_name = node.name();
  97. if (node.op() == "Const")
  98. {
  99. tensorflow::AttrValue value;
  100. if (find_attr_value(node, "value", value))
  101. {
  102. const tensorflow::TensorProto& tensor = value.tensor();
  103. weights[output_name] = tensor;
  104. }
  105. continue;
  106. }
  107. else if (node.op() == "Identity")
  108. {
  109. const std::string& input_name = node.input(0);
  110. if (weights.find(input_name) != weights.end())
  111. {
  112. weights[output_name] = weights[input_name];
  113. continue;
  114. }
  115. else
  116. {
  117. dropouts.insert(output_name);
  118. }
  119. }
  120. else if (node.op() == "NoOp")
  121. {
  122. weights[output_name] = tensorflow::TensorProto();
  123. continue;
  124. }
  125. // input
  126. for (int j=0; j<node.input_size(); j++)
  127. {
  128. const std::string& input_name = node.input(j);
  129. // fprintf(stderr, "%s\n", input_name.c_str());
  130. if (weights.find(input_name) != weights.end())
  131. {
  132. continue;
  133. }
  134. blob_names.insert(input_name);
  135. if (node_reference.find(input_name) == node_reference.end())
  136. {
  137. node_reference[input_name] = 1;
  138. }
  139. else
  140. {
  141. node_reference[input_name] = node_reference[input_name] + 1;
  142. }
  143. }
  144. // output
  145. // fprintf(stderr, "%s\n", output_name.c_str());
  146. blob_names.insert(output_name);
  147. }
  148. // remove node_reference entry with reference equals to one
  149. int splitncnn_blob_count = 0;
  150. std::map<std::string, int>::iterator it = node_reference.begin();
  151. while (it != node_reference.end())
  152. {
  153. if (it->second == 1)
  154. {
  155. node_reference.erase(it++);
  156. }
  157. else
  158. {
  159. splitncnn_blob_count += it->second;
  160. // fprintf(stderr, "%s %d\n", it->first.c_str(), it->second);
  161. ++it;
  162. }
  163. }
  164. fprintf(pp, "%lu %lu\n", node_count + node_reference.size() - weights.size(), blob_names.size() + splitncnn_blob_count);
  165. int internal_split = 0;
  166. for (int i=0; i<node_count; i++)
  167. {
  168. const tensorflow::NodeDef& node = graph.node(i);
  169. // layer definition line, repeated
  170. // [type] [name] [bottom blob count] [top blob count] [bottom blobs] [top blobs] [layer specific params]
  171. // fprintf(pp, "%-16s %-16s %d %d", layer.type().c_str(), layer.name().c_str(), node.input_size(), layer.top_size());
  172. if (node.op() == "Add" || node.op() == "BiasAdd")
  173. {
  174. // check weights
  175. tensorflow::TensorProto tensor;
  176. if (find_tensor_proto(weights, node, tensor))
  177. {
  178. fprintf(pp, "%-16s", "Bias");
  179. }
  180. else
  181. {
  182. fprintf(pp, "%-16s", "Eltwise");
  183. }
  184. }
  185. else if (node.op() == "AvgPool")
  186. {
  187. fprintf(pp, "%-16s", "Pooling");
  188. }
  189. else if (node.op() == "Const")
  190. {
  191. continue;
  192. }
  193. else if (node.op() == "Conv2D")
  194. {
  195. fprintf(pp, "%-16s", "Convolution");
  196. }
  197. else if (node.op() == "Identity")
  198. {
  199. if (dropouts.find(node.name()) != dropouts.end())
  200. {
  201. fprintf(pp, "%-16s", "Dropout");
  202. }
  203. else
  204. {
  205. continue;
  206. }
  207. }
  208. else if (node.op() == "MatMul")
  209. {
  210. fprintf(pp, "%-16s", "InnerProduct");
  211. }
  212. else if (node.op() == "Max")
  213. {
  214. fprintf(pp, "%-16s", "Eltwise");
  215. }
  216. else if (node.op() == "MaxPool")
  217. {
  218. fprintf(pp, "%-16s", "Pooling");
  219. }
  220. else if (node.op() == "Mul")
  221. {
  222. fprintf(pp, "%-16s", "Eltwise");
  223. }
  224. else if (node.op() == "NoOp")
  225. {
  226. continue;
  227. }
  228. else if (node.op() == "Placeholder")
  229. {
  230. fprintf(pp, "%-16s", "Input");
  231. }
  232. else if (node.op() == "Relu")
  233. {
  234. fprintf(pp, "%-16s", "ReLU");
  235. }
  236. else if (node.op() == "Reshape")
  237. {
  238. fprintf(pp, "%-16s", "Reshape");
  239. }
  240. else if (node.op() == "Softmax")
  241. {
  242. fprintf(pp, "%-16s", "Softmax");
  243. }
  244. else
  245. {
  246. fprintf(pp, "%-16s", node.op().c_str());
  247. }
  248. int input_size = node.input_size();
  249. for (int j=0; j<node.input_size(); j++)
  250. {
  251. const std::string& input_name = node.input(j);
  252. if (weights.find(input_name) != weights.end())
  253. {
  254. input_size--;
  255. }
  256. }
  257. fprintf(pp, " %-16s %d 1", node.name().c_str(), input_size);
  258. for (int j=0; j<node.input_size(); j++)
  259. {
  260. std::string input_name = node.input(j);
  261. if (weights.find(input_name) != weights.end())
  262. {
  263. continue;
  264. }
  265. if (node_reference.find(input_name) != node_reference.end())
  266. {
  267. int refidx = node_reference[input_name] - 1;
  268. node_reference[input_name] = refidx;
  269. char splitsuffix[256];
  270. sprintf(splitsuffix, "_splitncnn_%d", refidx);
  271. input_name = input_name + splitsuffix;
  272. }
  273. fprintf(pp, " %s", input_name.c_str());
  274. }
  275. fprintf(pp, " %s", node.name().c_str());
  276. if (node.op() == "Add" || node.op() == "BiasAdd")
  277. {
  278. // check weights
  279. tensorflow::TensorProto tensor;
  280. if (find_tensor_proto(weights, node, tensor))
  281. {
  282. int weight_data_size = 0;
  283. if (!tensor.tensor_content().empty())
  284. {
  285. if (tensor.dtype() == 1)// float
  286. {
  287. const float* data = reinterpret_cast<const float*>(tensor.tensor_content().c_str());
  288. weight_data_size = tensor.tensor_content().size() / sizeof(float);
  289. fwrite(data, sizeof(float), weight_data_size, bp);
  290. }
  291. else if (tensor.dtype() == 3)// int32
  292. {
  293. const int* data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
  294. weight_data_size = tensor.tensor_content().size() / sizeof(int);
  295. float tmp;
  296. for (int i=0; i<weight_data_size; i++)
  297. {
  298. tmp = data[i];
  299. fwrite(&tmp, sizeof(float), 1, bp);
  300. }
  301. }
  302. }
  303. fprintf(pp, " %d", weight_data_size);
  304. }
  305. else
  306. {
  307. int op_type = 1;
  308. int num_coeff = 0;
  309. fprintf(pp, " %d %d", op_type, num_coeff);
  310. }
  311. }
  312. else if (node.op() == "AvgPool")
  313. {
  314. int pooling_type = 1;
  315. int kernel_size_h = 1;
  316. int kernel_size_w = 1;
  317. int stride_h = 1;
  318. int stride_w = 1;
  319. int pad = 0;
  320. int global_pooling = 0;
  321. tensorflow::AttrValue value_ksize;
  322. if (find_attr_value(node, "ksize", value_ksize))
  323. {
  324. // batch, height, width, channels
  325. kernel_size_h = value_ksize.list().i(1);
  326. kernel_size_w = value_ksize.list().i(2);
  327. }
  328. tensorflow::AttrValue value_strides;
  329. if (find_attr_value(node, "strides", value_strides))
  330. {
  331. // batch, height, width, channels
  332. stride_h = value_strides.list().i(1);
  333. stride_w = value_strides.list().i(2);
  334. }
  335. tensorflow::AttrValue value_padding;
  336. if (find_attr_value(node, "padding", value_padding))
  337. {
  338. if (value_padding.s() == "VALID")
  339. {
  340. pad = 0;
  341. }
  342. else if (value_padding.s() == "SAME")
  343. {
  344. pad = -233;
  345. }
  346. }
  347. fprintf(pp, " %d %d %d %d %d", pooling_type, kernel_size_w, stride_w, pad, global_pooling);
  348. }
  349. else if (node.op() == "Const")
  350. {
  351. }
  352. else if (node.op() == "Conv2D")
  353. {
  354. // weights
  355. tensorflow::TensorProto tensor;
  356. find_tensor_proto(weights, node, tensor);
  357. const tensorflow::TensorShapeProto& shape = tensor.tensor_shape();
  358. int kernel_size_h = shape.dim(0).size();
  359. int kernel_size_w = shape.dim(1).size();
  360. int num_input = shape.dim(2).size();
  361. int num_output = shape.dim(3).size();
  362. int stride_h = 1;
  363. int stride_w = 1;
  364. int dilation = 1;
  365. int pad = 0;
  366. tensorflow::AttrValue value_strides;
  367. if (find_attr_value(node, "strides", value_strides))
  368. {
  369. // batch, height, width, channels
  370. stride_h = value_strides.list().i(1);
  371. stride_w = value_strides.list().i(2);
  372. }
  373. tensorflow::AttrValue value_padding;
  374. if (find_attr_value(node, "padding", value_padding))
  375. {
  376. if (value_padding.s() == "VALID")
  377. {
  378. pad = 0;
  379. }
  380. else if (value_padding.s() == "SAME")
  381. {
  382. pad = -233;
  383. }
  384. }
  385. int bias_term = 0;
  386. int weight_data_size = 0;
  387. // reorder h-w-i-o to o-i-h-w
  388. if (!tensor.tensor_content().empty())
  389. {
  390. int quantize_tag = 0;
  391. fwrite(&quantize_tag, sizeof(int), 1, bp);
  392. if (tensor.dtype() == 1)// float
  393. {
  394. const float* data = reinterpret_cast<const float*>(tensor.tensor_content().c_str());
  395. weight_data_size = tensor.tensor_content().size() / sizeof(float);
  396. float tmp;
  397. for (int p=0; p<num_output; p++)
  398. {
  399. for (int q=0; q<num_input; q++)
  400. {
  401. for (int i=0; i<kernel_size_h; i++)
  402. {
  403. for (int j=0; j<kernel_size_w; j++)
  404. {
  405. tmp = data[i*kernel_size_w*num_input*num_output + j*num_input*num_output + q*num_output + p];
  406. fwrite(&tmp, sizeof(float), 1, bp);
  407. }
  408. }
  409. }
  410. }
  411. }
  412. else if (tensor.dtype() == 3)// int32
  413. {
  414. const int* data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
  415. weight_data_size = tensor.tensor_content().size() / sizeof(int);
  416. float tmp;
  417. for (int p=0; p<num_output; p++)
  418. {
  419. for (int q=0; q<num_input; q++)
  420. {
  421. for (int i=0; i<kernel_size_h; i++)
  422. {
  423. for (int j=0; j<kernel_size_w; j++)
  424. {
  425. tmp = data[i*kernel_size_w*num_input*num_output + j*num_input*num_output + q*num_output + p];
  426. fwrite(&tmp, sizeof(float), 1, bp);
  427. }
  428. }
  429. }
  430. }
  431. }
  432. }
  433. fprintf(pp, " %d %d %d %d %d %d %d", num_output, kernel_size_w, dilation, stride_w, pad, bias_term, weight_data_size);
  434. }
  435. else if (node.op() == "Identity")
  436. {
  437. }
  438. else if (node.op() == "MatMul")
  439. {
  440. // weights
  441. tensorflow::TensorProto tensor;
  442. find_tensor_proto(weights, node, tensor);
  443. const tensorflow::TensorShapeProto& shape = tensor.tensor_shape();
  444. int num_input = shape.dim(0).size();
  445. int num_output = shape.dim(1).size();
  446. int bias_term = 0;
  447. int weight_data_size = 0;
  448. // reorder i-o to o-i
  449. if (!tensor.tensor_content().empty())
  450. {
  451. int quantize_tag = 0;
  452. fwrite(&quantize_tag, sizeof(int), 1, bp);
  453. if (tensor.dtype() == 1)// float
  454. {
  455. const float* data = reinterpret_cast<const float*>(tensor.tensor_content().c_str());
  456. weight_data_size = tensor.tensor_content().size() / sizeof(float);
  457. float tmp;
  458. for (int p=0; p<num_output; p++)
  459. {
  460. for (int q=0; q<num_input; q++)
  461. {
  462. tmp = data[q*num_output + p];
  463. fwrite(&tmp, sizeof(float), 1, bp);
  464. }
  465. }
  466. }
  467. else if (tensor.dtype() == 3)// int32
  468. {
  469. const int* data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
  470. weight_data_size = tensor.tensor_content().size() / sizeof(int);
  471. float tmp;
  472. for (int p=0; p<num_output; p++)
  473. {
  474. for (int q=0; q<num_input; q++)
  475. {
  476. tmp = data[q*num_output + p];
  477. fwrite(&tmp, sizeof(float), 1, bp);
  478. }
  479. }
  480. }
  481. }
  482. fprintf(pp, " %d %d %d", num_output, bias_term, weight_data_size);
  483. }
  484. else if (node.op() == "Max")
  485. {
  486. int op_type = 2;
  487. int num_coeff = 0;
  488. fprintf(pp, " %d %d", op_type, num_coeff);
  489. }
  490. else if (node.op() == "MaxPool")
  491. {
  492. int pooling_type = 0;
  493. int kernel_size_h = 1;
  494. int kernel_size_w = 1;
  495. int stride_h = 1;
  496. int stride_w = 1;
  497. int pad = 0;
  498. int global_pooling = 0;
  499. tensorflow::AttrValue value_ksize;
  500. if (find_attr_value(node, "ksize", value_ksize))
  501. {
  502. // batch, height, width, channels
  503. kernel_size_h = value_ksize.list().i(1);
  504. kernel_size_w = value_ksize.list().i(2);
  505. }
  506. tensorflow::AttrValue value_strides;
  507. if (find_attr_value(node, "strides", value_strides))
  508. {
  509. // batch, height, width, channels
  510. stride_h = value_strides.list().i(1);
  511. stride_w = value_strides.list().i(2);
  512. }
  513. tensorflow::AttrValue value_padding;
  514. if (find_attr_value(node, "padding", value_padding))
  515. {
  516. if (value_padding.s() == "VALID")
  517. {
  518. pad = 0;
  519. }
  520. else if (value_padding.s() == "SAME")
  521. {
  522. pad = -233;
  523. }
  524. }
  525. fprintf(pp, " %d %d %d %d %d", pooling_type, kernel_size_w, stride_w, pad, global_pooling);
  526. }
  527. else if (node.op() == "Mul")
  528. {
  529. int op_type = 0;
  530. int num_coeff = 0;
  531. fprintf(pp, " %d %d", op_type, num_coeff);
  532. }
  533. else if (node.op() == "NoOp")
  534. {
  535. }
  536. else if (node.op() == "Placeholder")
  537. {
  538. // TODO pass through
  539. fprintf(pp, " 0 0 0");
  540. }
  541. else if (node.op() == "Relu")
  542. {
  543. float slope = 0.f;
  544. fprintf(pp, " %f", slope);
  545. }
  546. else if (node.op() == "Reshape")
  547. {
  548. tensorflow::TensorProto tensor;
  549. if (find_tensor_proto(weights, node, tensor))
  550. {
  551. if (!tensor.tensor_content().empty() && tensor.dtype() == 3)// int32
  552. {
  553. const int* data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
  554. int size = tensor.tensor_content().size() / sizeof(int);
  555. // n h w c
  556. // n h w
  557. // n w
  558. if (size == 4)
  559. {
  560. fprintf(pp, " %d %d %d 0", data[2], data[1], data[3]);
  561. }
  562. if (size == 3)
  563. {
  564. fprintf(pp, " %d %d -233 1", data[2], data[1]);
  565. }
  566. if (size == 2)
  567. {
  568. fprintf(pp, " %d -233 -233 1", data[1]);
  569. }
  570. }
  571. }
  572. else
  573. {
  574. // pass through
  575. fprintf(pp, " 0 0 0");
  576. }
  577. }
  578. else if (node.op() == "Softmax")
  579. {
  580. }
  581. else
  582. {
  583. const google::protobuf::Map<std::string, tensorflow::AttrValue>& attr = node.attr();
  584. google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.begin();
  585. for (; it != attr.end(); it++)
  586. {
  587. std::cerr << it->first << std::endl;
  588. std::cerr << it->second.type() << std::endl;
  589. }
  590. }
  591. fprintf(pp, "\n");
  592. std::string output_name = node.name();
  593. if (node_reference.find(output_name) != node_reference.end())
  594. {
  595. int refcount = node_reference[output_name];
  596. if (refcount > 1)
  597. {
  598. char splitname[256];
  599. sprintf(splitname, "splitncnn_%d", internal_split);
  600. fprintf(pp, "%-16s %-16s %d %d", "Split", splitname, 1, refcount);
  601. fprintf(pp, " %s", output_name.c_str());
  602. for (int j=0; j<refcount; j++)
  603. {
  604. fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), j);
  605. }
  606. fprintf(pp, "\n");
  607. internal_split++;
  608. }
  609. }
  610. }
  611. fclose(pp);
  612. fclose(bp);
  613. return 0;
  614. }