You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dump_proto.cc 20 kB

6 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <fstream>
  17. #include <map>
  18. #include <memory>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include <utility>
  22. #include <algorithm>
  23. #include "debug/anf_ir_utils.h"
  24. #include "proto/anf_ir.pb.h"
  25. #include "utils/graph_utils.h"
  26. #include "utils/symbolic.h"
  27. namespace mindspore {
  28. class ProtoExporter {
  29. public:
  30. ProtoExporter() {}
  31. ~ProtoExporter() {}
  32. std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
  33. private:
  34. void InitModelInfo();
  35. void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, irpb::NodeProto *node_proto);
  36. std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  37. const std::map<AnfNodePtr, size_t> &apply_map,
  38. std::map<AnfNodePtr, size_t> *const_map_ptr);
  39. void SetValueToProto(const ValuePtr &attr_value, irpb::ValueProto *value_proto);
  40. void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto);
  41. void SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto);
  42. void SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto);
  43. void SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto);
  44. void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, irpb::TypeProto *type_proto);
  45. void ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto);
  46. void ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto);
  47. void ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto,
  48. std::map<AnfNodePtr, size_t> *const_map_ptr);
  49. void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *apply_map_ptr,
  50. std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto);
  51. void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
  52. const std::map<AnfNodePtr, size_t> &apply_map, std::map<AnfNodePtr, size_t> *const_map_ptr,
  53. irpb::GraphProto *graph_proto);
  54. void ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto);
  55. static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); }
  56. irpb::ModelProto model_;
  57. };
  58. static irpb::DataType GetNumberDataType(const TypePtr &type) {
  59. switch (type->type_id()) {
  60. case kNumberTypeBool:
  61. return irpb::DT_BOOL;
  62. case kNumberTypeInt8:
  63. return irpb::DT_INT8;
  64. case kNumberTypeInt16:
  65. return irpb::DT_INT16;
  66. case kNumberTypeInt32:
  67. return irpb::DT_INT32;
  68. case kNumberTypeInt64:
  69. return irpb::DT_INT64;
  70. case kNumberTypeUInt8:
  71. return irpb::DT_UINT8;
  72. case kNumberTypeUInt16:
  73. return irpb::DT_UINT16;
  74. case kNumberTypeUInt32:
  75. return irpb::DT_UINT32;
  76. case kNumberTypeUInt64:
  77. return irpb::DT_UINT64;
  78. case kNumberTypeFloat16:
  79. return irpb::DT_FLOAT16;
  80. case kNumberTypeFloat32:
  81. return irpb::DT_FLOAT32;
  82. case kNumberTypeFloat64:
  83. return irpb::DT_FLOAT64;
  84. case kNumberTypeInt:
  85. return irpb::DT_BASE_INT;
  86. case kNumberTypeUInt:
  87. return irpb::DT_BASE_UINT;
  88. case kNumberTypeFloat:
  89. return irpb::DT_BASE_FLOAT;
  90. default:
  91. MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
  92. }
  93. }
  94. void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) {
  95. if (type_proto == nullptr) {
  96. return;
  97. }
  98. if (type == nullptr) {
  99. type_proto->set_data_type(irpb::DT_UNDEFINED);
  100. } else if (type->isa<Number>()) {
  101. type_proto->set_data_type(GetNumberDataType(type));
  102. } else if (type->isa<TensorType>()) {
  103. TypePtr elem_type = dyn_cast<TensorType>(type)->element();
  104. type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type));
  105. type_proto->set_data_type(irpb::DT_TENSOR);
  106. if (shape != nullptr && shape->isa<abstract::Shape>()) {
  107. abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
  108. for (const auto &elem : shape_info->shape()) {
  109. type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
  110. }
  111. }
  112. } else if (type->isa<IndexedSlicesType>()) {
  113. // Do Nothing
  114. } else if (type->isa<UndeterminedType>()) {
  115. // Do Nothing
  116. } else if (type->isa<SparseTensorType>()) {
  117. // Do Nothing
  118. } else if (type->isa<Tuple>()) {
  119. TuplePtr tuple_type = dyn_cast<Tuple>(type);
  120. type_proto->set_data_type(irpb::DT_TUPLE);
  121. for (const auto &elem_type : tuple_type->elements()) {
  122. SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
  123. }
  124. } else if (type->isa<TypeType>()) {
  125. type_proto->set_data_type(irpb::DT_TYPE);
  126. } else if (type->isa<List>()) {
  127. ListPtr list_type = dyn_cast<List>(type);
  128. type_proto->set_data_type(irpb::DT_LIST);
  129. for (const auto &elem_type : list_type->elements()) {
  130. SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
  131. }
  132. } else if (type->isa<TypeAnything>()) {
  133. type_proto->set_data_type(irpb::DT_ANYTHING);
  134. } else if (type->isa<RefKeyType>()) {
  135. type_proto->set_data_type(irpb::DT_REFKEY);
  136. } else if (type->isa<RefType>()) {
  137. type_proto->set_data_type(irpb::DT_REF);
  138. } else if (type->isa<Function>()) {
  139. type_proto->set_data_type(irpb::DT_GRAPH);
  140. } else if (type->isa<TypeNone>()) {
  141. type_proto->set_data_type(irpb::DT_NONE);
  142. } else if (type->isa<String>()) {
  143. type_proto->set_data_type(irpb::DT_STRING);
  144. } else if (type->isa<SymbolicKeyType>()) {
  145. // Do Nothing.
  146. } else {
  147. MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
  148. }
  149. }
  150. void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto) {
  151. if (node == nullptr || type_proto == nullptr) {
  152. return;
  153. }
  154. SetNodeOutputType(node->Type(), node->Shape(), type_proto);
  155. }
  156. void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) {
  157. if (val == nullptr || value_proto == nullptr) {
  158. return;
  159. }
  160. if (val->isa<StringImm>()) {
  161. const StringImmPtr &value = dyn_cast<StringImm>(val);
  162. value_proto->set_dtype(irpb::DT_STRING);
  163. value_proto->set_str_val(value->value());
  164. } else if (val->isa<Scalar>()) {
  165. SetScalarToProto(dyn_cast<Scalar>(val), value_proto);
  166. } else if (val->isa<Bool>()) {
  167. value_proto->set_dtype(irpb::DT_TYPE);
  168. value_proto->mutable_type_val()->set_data_type(irpb::DT_BOOL);
  169. } else if (val->isa<Int>()) {
  170. value_proto->set_dtype(irpb::DT_TYPE);
  171. value_proto->mutable_type_val()->set_data_type(irpb::DT_BASE_INT);
  172. } else if (val->isa<Float>()) {
  173. value_proto->set_dtype(irpb::DT_TYPE);
  174. value_proto->mutable_type_val()->set_data_type(irpb::DT_BASE_FLOAT);
  175. } else if (val->isa<ValueSequeue>()) {
  176. SetSequenceToProto(dyn_cast<ValueSequeue>(val), value_proto);
  177. } else if (val->isa<None>()) {
  178. value_proto->set_dtype(irpb::DT_NONE);
  179. value_proto->set_str_val("None");
  180. } else if (val->isa<SymbolicKeyInstance>()) {
  181. SymbolicKeyInstancePtr sym_inst = dyn_cast<SymbolicKeyInstance>(val);
  182. ParameterPtr sym_node = dyn_cast<Parameter>(sym_inst->node());
  183. value_proto->set_dtype(irpb::DT_SYM_INST);
  184. value_proto->set_str_val(sym_node == nullptr ? std::string("nullptr") : sym_node->ToString());
  185. } else if (val->isa<ValueDictionary>()) {
  186. SetDictionaryToProto(dyn_cast<ValueDictionary>(val), value_proto);
  187. } else if (val->isa<tensor::Tensor>()) {
  188. tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val);
  189. value_proto->set_dtype(irpb::DT_TENSOR);
  190. irpb::TensorProto *tensor_proto = value_proto->mutable_tensor_val();
  191. tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype()));
  192. for (auto &elem : tensor_ptr->shape()) {
  193. tensor_proto->add_dims(elem);
  194. }
  195. } else if (val->isa<TensorType>()) {
  196. value_proto->set_dtype(irpb::DT_TYPE);
  197. irpb::TypeProto *type_proto = value_proto->mutable_type_val();
  198. type_proto->set_data_type(irpb::DT_TENSOR);
  199. TypePtr elem_type = dyn_cast<TensorType>(val)->element();
  200. type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type));
  201. } else {
  202. MS_LOG(WARNING) << "Unsupported type " << val->type_name();
  203. }
  204. }
  205. void ProtoExporter::SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) {
  206. if (val == nullptr || value_proto == nullptr) {
  207. return;
  208. }
  209. if (val->isa<BoolImm>()) {
  210. const BoolImmPtr &value = dyn_cast<BoolImm>(val);
  211. value_proto->set_dtype(irpb::DT_BOOL);
  212. value_proto->set_bool_val(value->value());
  213. } else if (val->isa<Int8Imm>()) {
  214. const Int8ImmPtr &value = dyn_cast<Int8Imm>(val);
  215. value_proto->set_dtype(irpb::DT_INT8);
  216. value_proto->set_int_val(value->value());
  217. } else if (val->isa<Int16Imm>()) {
  218. const Int16ImmPtr &value = dyn_cast<Int16Imm>(val);
  219. value_proto->set_dtype(irpb::DT_INT16);
  220. value_proto->set_int_val(value->value());
  221. } else if (val->isa<Int32Imm>()) {
  222. const Int32ImmPtr &value = dyn_cast<Int32Imm>(val);
  223. value_proto->set_dtype(irpb::DT_INT32);
  224. value_proto->set_int_val(value->value());
  225. } else if (val->isa<Int64Imm>()) {
  226. const Int64ImmPtr &value = dyn_cast<Int64Imm>(val);
  227. value_proto->set_dtype(irpb::DT_INT64);
  228. value_proto->set_int_val(value->value());
  229. } else if (val->isa<UInt8Imm>()) {
  230. const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val);
  231. value_proto->set_dtype(irpb::DT_UINT8);
  232. value_proto->set_uint_val(value->value());
  233. } else if (val->isa<UInt16Imm>()) {
  234. const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val);
  235. value_proto->set_dtype(irpb::DT_UINT16);
  236. value_proto->set_uint_val(value->value());
  237. } else if (val->isa<UInt32Imm>()) {
  238. const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val);
  239. value_proto->set_dtype(irpb::DT_UINT32);
  240. value_proto->set_uint_val(value->value());
  241. } else if (val->isa<UInt64Imm>()) {
  242. const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val);
  243. value_proto->set_dtype(irpb::DT_UINT64);
  244. value_proto->set_uint_val(value->value());
  245. } else if (val->isa<FP32Imm>()) {
  246. const FP32ImmPtr &value = dyn_cast<FP32Imm>(val);
  247. value_proto->set_dtype(irpb::DT_FLOAT32);
  248. value_proto->set_float_val(value->value());
  249. } else if (val->isa<FP64Imm>()) {
  250. const FP64ImmPtr &value = dyn_cast<FP64Imm>(val);
  251. value_proto->set_dtype(irpb::DT_FLOAT64);
  252. value_proto->set_double_val(value->value());
  253. } else {
  254. MS_LOG(EXCEPTION) << "Unknown scalar type " << val->ToString();
  255. }
  256. }
  257. void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto) {
  258. if (val == nullptr || value_proto == nullptr) {
  259. return;
  260. }
  261. if (val->isa<ValueTuple>()) {
  262. const ValueTuplePtr &value = dyn_cast<ValueTuple>(val);
  263. value_proto->set_dtype(irpb::DT_TUPLE);
  264. for (const auto &item : value->value()) {
  265. SetValueToProto(item, value_proto->add_values());
  266. }
  267. } else if (val->isa<ValueList>()) {
  268. const ValueListPtr &value = dyn_cast<ValueList>(val);
  269. value_proto->set_dtype(irpb::DT_LIST);
  270. for (const auto &item : value->value()) {
  271. SetValueToProto(item, value_proto->add_values());
  272. }
  273. }
  274. }
  275. void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto) {
  276. if (val == nullptr || value_proto == nullptr) {
  277. return;
  278. }
  279. value_proto->set_dtype(irpb::DT_DICT);
  280. for (const auto &item : val->value()) {
  281. irpb::NamedValueProto *named_val = value_proto->add_dict_val();
  282. named_val->set_key(item.first);
  283. SetValueToProto(item.second, named_val->mutable_value());
  284. }
  285. }
  286. void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node, irpb::NodeProto *node_proto) {
  287. if (node == nullptr || node_proto == nullptr) {
  288. return;
  289. }
  290. if (node->isa<CNode>() || node->isa<Parameter>() || IsValueNode<FuncGraph>(node)) {
  291. MS_LOG(EXCEPTION) << "Op node can not be CNode, Parameter or ValueNode Graph. But got " << node->ToString();
  292. }
  293. if (!IsValueNode<Primitive>(node)) {
  294. MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString();
  295. }
  296. const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(node);
  297. node_proto->set_op_type(prim->name());
  298. for (const auto &attr : prim->attrs()) {
  299. irpb::AttributeProto *attr_proto = node_proto->add_attribute();
  300. attr_proto->set_name(attr.first);
  301. SetValueToProto(attr.second, attr_proto->mutable_value());
  302. }
  303. node_proto->set_scope(node->scope()->name());
  304. }
  305. std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node,
  306. const std::map<AnfNodePtr, size_t> &apply_map,
  307. std::map<AnfNodePtr, size_t> *const_map_ptr) {
  308. if (node == nullptr || const_map_ptr == nullptr) {
  309. return "";
  310. }
  311. if (node->isa<CNode>()) {
  312. auto iter = apply_map.find(node);
  313. if (iter == apply_map.end()) {
  314. MS_LOG(EXCEPTION) << "Can not find node '" << node->ToString() << "' in apply_map";
  315. }
  316. return std::to_string(iter->second);
  317. }
  318. if (node->isa<Parameter>()) {
  319. return node->ToString();
  320. }
  321. if (node->isa<ValueNode>()) {
  322. auto iter = const_map_ptr->find(node);
  323. if (iter == const_map_ptr->end()) {
  324. // Start index number from 1
  325. auto const_idx = const_map_ptr->size() + 1;
  326. (*const_map_ptr)[node] = const_idx;
  327. }
  328. return GetConstNodeId((*const_map_ptr)[node]);
  329. }
  330. MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'";
  331. }
  332. std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
  333. if (func_graph == nullptr) {
  334. return "";
  335. }
  336. InitModelInfo();
  337. irpb::GraphProto *graph_proto = model_.mutable_graph();
  338. ExportFuncGraph(func_graph, graph_proto);
  339. return model_.SerializeAsString();
  340. }
  341. void ProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) {
  342. if (func_graph == nullptr || graph_proto == nullptr) {
  343. return;
  344. }
  345. // map for store ValueNodes of this graph
  346. std::map<AnfNodePtr, size_t> const_map;
  347. // set graph name
  348. graph_proto->set_name(func_graph->ToString());
  349. ExportParameters(func_graph, graph_proto);
  350. ExportCNodes(func_graph, graph_proto, &const_map);
  351. ExportValueNodes(const_map, graph_proto);
  352. }
  353. void ProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) {
  354. if (func_graph == nullptr || graph_proto == nullptr) {
  355. return;
  356. }
  357. std::vector<AnfNodePtr> parameters = func_graph->parameters();
  358. for (auto &param : parameters) {
  359. irpb::ParameterProto *param_proto = graph_proto->add_parameters();
  360. param_proto->set_name(param->ToString());
  361. SetNodeOutputType(param, param_proto->mutable_type());
  362. const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
  363. if (param_ptr == nullptr) {
  364. MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter.";
  365. }
  366. }
  367. }
  368. void ProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto,
  369. std::map<AnfNodePtr, size_t> *const_map_ptr) {
  370. if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
  371. return;
  372. }
  373. // topo sort nodes
  374. std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
  375. std::map<AnfNodePtr, size_t> apply_map;
  376. for (const AnfNodePtr &node : nodes) {
  377. MS_EXCEPTION_IF_NULL(node);
  378. if (!node->isa<CNode>()) {
  379. continue;
  380. }
  381. auto cnode = node->cast<CNodePtr>();
  382. if (cnode != func_graph->get_return()) {
  383. ExportCNode(func_graph, cnode, &apply_map, const_map_ptr, graph_proto);
  384. } else {
  385. ExportFuncGraphOutput(func_graph, cnode, apply_map, const_map_ptr, graph_proto);
  386. }
  387. }
  388. }
  389. void ProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
  390. std::map<AnfNodePtr, size_t> *apply_map_ptr,
  391. std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) {
  392. if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
  393. graph_proto == nullptr) {
  394. return;
  395. }
  396. auto apply_idx = apply_map_ptr->size() + 1;
  397. (*apply_map_ptr)[node] = apply_idx;
  398. auto &inputs = node->inputs();
  399. if (inputs.size() < 1) {
  400. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  401. }
  402. AnfNodePtr op = inputs[0];
  403. irpb::NodeProto *node_proto = graph_proto->add_node();
  404. // CNode/ConstGraph/Const/Parameter
  405. if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) {
  406. MS_LOG(WARNING) << "Operator must be a primitive";
  407. } else {
  408. GetOpNodeTypeAndAttrs(func_graph, op, node_proto);
  409. node_proto->set_name(std::to_string(apply_idx));
  410. node_proto->set_scope(node->scope()->name());
  411. node_proto->set_full_name(node->fullname_with_scope());
  412. // process OP inputs
  413. for (size_t i = 1; i < inputs.size(); ++i) {
  414. irpb::InputProto *input_proto = node_proto->add_input();
  415. input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE);
  416. std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr);
  417. input_proto->set_name(id);
  418. }
  419. // set node output type
  420. SetNodeOutputType(node, node_proto->mutable_output_type());
  421. }
  422. }
  423. void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
  424. const std::map<AnfNodePtr, size_t> &apply_map,
  425. std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) {
  426. if (ret_node == nullptr || !ret_node->isa<CNode>()) {
  427. MS_LOG(EXCEPTION) << "Graph return node is illegal";
  428. }
  429. AnfNodePtr arg = ret_node->input(1);
  430. if (graph_proto == nullptr) {
  431. MS_LOG(EXCEPTION) << "graph_proto is nullptr";
  432. }
  433. irpb::OutputProto *output_proto = graph_proto->add_outputs();
  434. if (output_proto == nullptr) {
  435. MS_LOG(EXCEPTION) << "output_proto is nullptr";
  436. }
  437. std::string id = GetOpNodeInputId(func_graph, arg, apply_map, const_map_ptr);
  438. output_proto->set_name(id);
  439. SetNodeOutputType(arg, output_proto->mutable_type());
  440. }
  441. static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) {
  442. return x.second < y.second;
  443. }
  444. void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto) {
  445. std::vector<std::pair<AnfNodePtr, size_t>> nodes;
  446. (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes),
  447. [](const std::pair<AnfNodePtr, size_t> &item) { return item; });
  448. sort(nodes.begin(), nodes.end(), CompareValue);
  449. for (auto &item : nodes) {
  450. if (graph_proto == nullptr) {
  451. MS_LOG(EXCEPTION) << "graph_proto is nullptr";
  452. }
  453. irpb::NamedValueProto *named_value = graph_proto->add_const_vals();
  454. MS_EXCEPTION_IF_NULL(named_value);
  455. named_value->set_key(GetConstNodeId(item.second));
  456. SetValueToProto(GetValueNode(item.first), named_value->mutable_value());
  457. }
  458. }
  459. void ProtoExporter::InitModelInfo() { model_.set_ir_version(irpb::IR_VERSION); }
  460. std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
  461. ProtoExporter exporter;
  462. return exporter.GetFuncGraphProtoString(func_graph);
  463. }
  464. } // namespace mindspore