You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

proto_exporter.cc 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <fstream>
  17. #include <map>
  18. #include <memory>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include <utility>
  22. #include <algorithm>
  23. #include "debug/debugger/debugger.h"
  24. #include "proto/debug_graph.pb.h"
  25. #include "ir/graph_utils.h"
  26. #include "utils/symbolic.h"
  27. namespace mindspore {
  28. class DebuggerProtoExporter {
  29. public:
  30. DebuggerProtoExporter() {}
  31. ~DebuggerProtoExporter() {}
  32. std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
  33. debugger::ModelProto GetFuncGraphProto(const FuncGraphPtr &func_graph);
  34. private:
  35. void InitModelInfo();
  36. void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, debugger::NodeProto *node_proto);
  37. std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  38. const std::map<AnfNodePtr, size_t> &apply_map,
  39. std::map<AnfNodePtr, size_t> *const_map_ptr);
  40. void SetValueToProto(const ValuePtr &attr_value, debugger::ValueProto *value_proto);
  41. void SetScalarToProto(const ScalarPtr &val, debugger::ValueProto *value_proto);
  42. void SetSequenceToProto(const ValueSequeuePtr &val, debugger::ValueProto *value_proto);
  43. void SetDictionaryToProto(const ValueDictionaryPtr &val, debugger::ValueProto *value_proto);
  44. void SetNodeOutputType(const AnfNodePtr &node, debugger::TypeProto *type_proto);
  45. void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, debugger::TypeProto *type_proto);
  46. void ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto);
  47. void ExportParameters(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto);
  48. void ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto,
  49. std::map<AnfNodePtr, size_t> *const_map_ptr);
  50. void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *apply_map_ptr,
  51. std::map<AnfNodePtr, size_t> *const_map_ptr, debugger::GraphProto *graph_proto);
  52. void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
  53. const std::map<AnfNodePtr, size_t> &apply_map, std::map<AnfNodePtr, size_t> *const_map_ptr,
  54. debugger::GraphProto *graph_proto);
  55. void ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, debugger::GraphProto *graph_proto);
  56. static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); }
  57. debugger::ModelProto model_;
  58. };
  59. void DebuggerProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape,
  60. debugger::TypeProto *type_proto) {
  61. if (type_proto == nullptr) {
  62. return;
  63. }
  64. if (type == nullptr) {
  65. type_proto->set_data_type(debugger::DT_UNDEFINED);
  66. } else if (type->isa<Number>()) {
  67. type_proto->set_data_type(GetDebuggerNumberDataType(type));
  68. } else if (type->isa<TensorType>()) {
  69. TypePtr elem_type = dyn_cast<TensorType>(type)->element();
  70. type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type));
  71. type_proto->set_data_type(debugger::DT_TENSOR);
  72. if (shape != nullptr && shape->isa<abstract::Shape>()) {
  73. abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
  74. for (const auto &elem : shape_info->shape()) {
  75. type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
  76. }
  77. }
  78. } else if (type->isa<Tuple>()) {
  79. TuplePtr tuple_type = dyn_cast<Tuple>(type);
  80. type_proto->set_data_type(debugger::DT_TUPLE);
  81. for (const auto &elem_type : tuple_type->elements()) {
  82. SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
  83. }
  84. } else if (type->isa<TypeType>()) {
  85. type_proto->set_data_type(debugger::DT_TYPE);
  86. } else if (type->isa<List>()) {
  87. ListPtr list_type = dyn_cast<List>(type);
  88. type_proto->set_data_type(debugger::DT_LIST);
  89. for (const auto &elem_type : list_type->elements()) {
  90. SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
  91. }
  92. } else if (type->isa<TypeAnything>()) {
  93. type_proto->set_data_type(debugger::DT_ANYTHING);
  94. } else if (type->isa<RefKeyType>()) {
  95. type_proto->set_data_type(debugger::DT_REFKEY);
  96. } else if (type->isa<RefType>()) {
  97. type_proto->set_data_type(debugger::DT_REF);
  98. } else if (type->isa<Function>()) {
  99. type_proto->set_data_type(debugger::DT_GRAPH);
  100. } else if (type->isa<TypeNone>()) {
  101. type_proto->set_data_type(debugger::DT_NONE);
  102. } else if (type->isa<String>()) {
  103. type_proto->set_data_type(debugger::DT_STRING);
  104. } else if (type->isa<SymbolicKeyType>()) {
  105. // Do Nothing.
  106. } else if (type->isa<UMonadType>()) {
  107. type_proto->set_data_type(debugger::DT_UMONAD);
  108. } else if (type->isa<IOMonadType>()) {
  109. type_proto->set_data_type(debugger::DT_IOMONAD);
  110. } else {
  111. MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
  112. }
  113. }
  114. void DebuggerProtoExporter::SetNodeOutputType(const AnfNodePtr &node, debugger::TypeProto *type_proto) {
  115. if (node == nullptr || type_proto == nullptr) {
  116. return;
  117. }
  118. SetNodeOutputType(node->Type(), node->Shape(), type_proto);
  119. }
  120. void DebuggerProtoExporter::SetValueToProto(const ValuePtr &val, debugger::ValueProto *value_proto) {
  121. if (val == nullptr || value_proto == nullptr) {
  122. return;
  123. }
  124. if (val->isa<StringImm>()) {
  125. const StringImmPtr &value = dyn_cast<StringImm>(val);
  126. value_proto->set_dtype(debugger::DT_STRING);
  127. value_proto->set_str_val(value->value());
  128. } else if (val->isa<Scalar>()) {
  129. SetScalarToProto(dyn_cast<Scalar>(val), value_proto);
  130. } else if (val->isa<Bool>()) {
  131. value_proto->set_dtype(debugger::DT_TYPE);
  132. value_proto->mutable_type_val()->set_data_type(debugger::DT_BOOL);
  133. } else if (val->isa<Int>()) {
  134. value_proto->set_dtype(debugger::DT_TYPE);
  135. value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_INT);
  136. } else if (val->isa<Float>()) {
  137. value_proto->set_dtype(debugger::DT_TYPE);
  138. value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_FLOAT);
  139. } else if (val->isa<ValueSequeue>()) {
  140. SetSequenceToProto(dyn_cast<ValueSequeue>(val), value_proto);
  141. } else if (val->isa<None>()) {
  142. value_proto->set_dtype(debugger::DT_NONE);
  143. value_proto->set_str_val("None");
  144. } else if (val->isa<SymbolicKeyInstance>()) {
  145. SymbolicKeyInstancePtr sym_inst = dyn_cast<SymbolicKeyInstance>(val);
  146. ParameterPtr sym_node = dyn_cast<Parameter>(sym_inst->node());
  147. value_proto->set_dtype(debugger::DT_SYM_INST);
  148. value_proto->set_str_val(sym_node == nullptr ? std::string("nullptr") : sym_node->ToString());
  149. } else if (val->isa<ValueDictionary>()) {
  150. SetDictionaryToProto(dyn_cast<ValueDictionary>(val), value_proto);
  151. } else if (val->isa<tensor::Tensor>()) {
  152. tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val);
  153. value_proto->set_dtype(debugger::DT_TENSOR);
  154. debugger::TensorProto *tensor_proto = value_proto->mutable_tensor_val();
  155. tensor_proto->set_data_type(GetDebuggerNumberDataType(tensor_ptr->Dtype()));
  156. for (auto &elem : tensor_ptr->shape()) {
  157. tensor_proto->add_dims(elem);
  158. }
  159. tensor_proto->set_tensor_content(tensor_ptr->data_c(), tensor_ptr->data().nbytes());
  160. } else if (val->isa<TensorType>()) {
  161. value_proto->set_dtype(debugger::DT_TYPE);
  162. debugger::TypeProto *type_proto = value_proto->mutable_type_val();
  163. type_proto->set_data_type(debugger::DT_TENSOR);
  164. TypePtr elem_type = dyn_cast<TensorType>(val)->element();
  165. type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type));
  166. } else {
  167. MS_LOG(WARNING) << "Unsupported type " << val->type_name();
  168. }
  169. }
  170. void DebuggerProtoExporter::SetScalarToProto(const ScalarPtr &val, debugger::ValueProto *value_proto) {
  171. if (val == nullptr || value_proto == nullptr) {
  172. return;
  173. }
  174. if (val->isa<BoolImm>()) {
  175. const BoolImmPtr &value = dyn_cast<BoolImm>(val);
  176. value_proto->set_dtype(debugger::DT_BOOL);
  177. value_proto->set_bool_val(value->value());
  178. } else if (val->isa<Int8Imm>()) {
  179. const Int8ImmPtr &value = dyn_cast<Int8Imm>(val);
  180. value_proto->set_dtype(debugger::DT_INT8);
  181. value_proto->set_int_val(value->value());
  182. } else if (val->isa<Int16Imm>()) {
  183. const Int16ImmPtr &value = dyn_cast<Int16Imm>(val);
  184. value_proto->set_dtype(debugger::DT_INT16);
  185. value_proto->set_int_val(value->value());
  186. } else if (val->isa<Int32Imm>()) {
  187. const Int32ImmPtr &value = dyn_cast<Int32Imm>(val);
  188. value_proto->set_dtype(debugger::DT_INT32);
  189. value_proto->set_int_val(value->value());
  190. } else if (val->isa<Int64Imm>()) {
  191. const Int64ImmPtr &value = dyn_cast<Int64Imm>(val);
  192. value_proto->set_dtype(debugger::DT_INT64);
  193. value_proto->set_int_val(value->value());
  194. } else if (val->isa<UInt8Imm>()) {
  195. const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val);
  196. value_proto->set_dtype(debugger::DT_UINT8);
  197. value_proto->set_uint_val(value->value());
  198. } else if (val->isa<UInt16Imm>()) {
  199. const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val);
  200. value_proto->set_dtype(debugger::DT_UINT16);
  201. value_proto->set_uint_val(value->value());
  202. } else if (val->isa<UInt32Imm>()) {
  203. const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val);
  204. value_proto->set_dtype(debugger::DT_UINT32);
  205. value_proto->set_uint_val(value->value());
  206. } else if (val->isa<UInt64Imm>()) {
  207. const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val);
  208. value_proto->set_dtype(debugger::DT_UINT64);
  209. value_proto->set_uint_val(value->value());
  210. } else if (val->isa<FP32Imm>()) {
  211. const FP32ImmPtr &value = dyn_cast<FP32Imm>(val);
  212. value_proto->set_dtype(debugger::DT_FLOAT32);
  213. value_proto->set_float_val(value->value());
  214. } else if (val->isa<FP64Imm>()) {
  215. const FP64ImmPtr &value = dyn_cast<FP64Imm>(val);
  216. value_proto->set_dtype(debugger::DT_FLOAT64);
  217. value_proto->set_double_val(value->value());
  218. } else {
  219. MS_LOG(EXCEPTION) << "Unknown scalar type " << val->ToString();
  220. }
  221. }
  222. void DebuggerProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, debugger::ValueProto *value_proto) {
  223. if (val == nullptr || value_proto == nullptr) {
  224. return;
  225. }
  226. if (val->isa<ValueTuple>()) {
  227. const ValueTuplePtr &value = dyn_cast<ValueTuple>(val);
  228. value_proto->set_dtype(debugger::DT_TUPLE);
  229. for (const auto &item : value->value()) {
  230. SetValueToProto(item, value_proto->add_values());
  231. }
  232. } else if (val->isa<ValueList>()) {
  233. const ValueListPtr &value = dyn_cast<ValueList>(val);
  234. value_proto->set_dtype(debugger::DT_LIST);
  235. for (const auto &item : value->value()) {
  236. SetValueToProto(item, value_proto->add_values());
  237. }
  238. }
  239. }
  240. void DebuggerProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, debugger::ValueProto *value_proto) {
  241. if (val == nullptr || value_proto == nullptr) {
  242. return;
  243. }
  244. value_proto->set_dtype(debugger::DT_DICT);
  245. for (const auto &item : val->value()) {
  246. debugger::NamedValueProto *named_val = value_proto->add_dict_val();
  247. named_val->set_key(item.first);
  248. SetValueToProto(item.second, named_val->mutable_value());
  249. }
  250. }
  251. void DebuggerProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node,
  252. debugger::NodeProto *node_proto) {
  253. if (node == nullptr || node_proto == nullptr) {
  254. return;
  255. }
  256. if (node->isa<CNode>() || node->isa<Parameter>() || IsValueNode<FuncGraph>(node)) {
  257. MS_LOG(EXCEPTION) << "Op node can not be CNode, Parameter or ValueNode Graph. But got " << node->ToString();
  258. }
  259. if (!IsValueNode<Primitive>(node)) {
  260. MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString();
  261. }
  262. const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(node);
  263. node_proto->set_op_type(prim->name());
  264. for (const auto &attr : prim->attrs()) {
  265. debugger::AttributeProto *attr_proto = node_proto->add_attribute();
  266. attr_proto->set_name(attr.first);
  267. SetValueToProto(attr.second, attr_proto->mutable_value());
  268. }
  269. node_proto->set_scope(node->scope()->name());
  270. }
  271. std::string DebuggerProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node,
  272. const std::map<AnfNodePtr, size_t> &apply_map,
  273. std::map<AnfNodePtr, size_t> *const_map_ptr) {
  274. if (node == nullptr || const_map_ptr == nullptr) {
  275. return "";
  276. }
  277. if (node->isa<CNode>()) {
  278. auto iter = apply_map.find(node);
  279. if (iter == apply_map.end()) {
  280. MS_LOG(EXCEPTION) << "Can not find node '" << node->ToString() << "' in apply_map";
  281. }
  282. return std::to_string(iter->second);
  283. }
  284. if (node->isa<Parameter>()) {
  285. return node->ToString();
  286. }
  287. if (node->isa<ValueNode>()) {
  288. auto iter = const_map_ptr->find(node);
  289. if (iter == const_map_ptr->end()) {
  290. // Start index number from 1
  291. auto const_idx = const_map_ptr->size() + 1;
  292. (*const_map_ptr)[node] = const_idx;
  293. }
  294. return GetConstNodeId((*const_map_ptr)[node]);
  295. }
  296. MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'";
  297. }
  298. std::string DebuggerProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
  299. if (func_graph == nullptr) {
  300. return "";
  301. }
  302. InitModelInfo();
  303. debugger::GraphProto *graph_proto = model_.mutable_graph();
  304. ExportFuncGraph(func_graph, graph_proto);
  305. return model_.SerializeAsString();
  306. }
  307. debugger::ModelProto DebuggerProtoExporter::GetFuncGraphProto(const FuncGraphPtr &func_graph) {
  308. if (func_graph == nullptr) {
  309. return ModelProto();
  310. }
  311. InitModelInfo();
  312. debugger::GraphProto *graph_proto = model_.mutable_graph();
  313. ExportFuncGraph(func_graph, graph_proto);
  314. return model_;
  315. }
  316. void DebuggerProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto) {
  317. if (func_graph == nullptr || graph_proto == nullptr) {
  318. return;
  319. }
  320. // map for store ValueNodes of this graph
  321. std::map<AnfNodePtr, size_t> const_map;
  322. // set graph name
  323. graph_proto->set_name(func_graph->ToString());
  324. MS_LOG(INFO) << "graph names: " << func_graph->ToString();
  325. ExportParameters(func_graph, graph_proto);
  326. ExportCNodes(func_graph, graph_proto, &const_map);
  327. ExportValueNodes(const_map, graph_proto);
  328. }
  329. void DebuggerProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto) {
  330. if (func_graph == nullptr || graph_proto == nullptr) {
  331. return;
  332. }
  333. // cast FuncGraph to KernelGraph to access inputs()
  334. std::vector<AnfNodePtr> parameters = static_cast<session::KernelGraph *>(func_graph.get())->inputs();
  335. for (auto &param : parameters) {
  336. debugger::ParameterProto *param_proto = graph_proto->add_parameters();
  337. param_proto->set_name(param->ToString());
  338. SetNodeOutputType(param, param_proto->mutable_type());
  339. const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
  340. if (param_ptr == nullptr) {
  341. MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter.";
  342. }
  343. }
  344. }
  345. void DebuggerProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto,
  346. std::map<AnfNodePtr, size_t> *const_map_ptr) {
  347. if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
  348. return;
  349. }
  350. // topo sort nodes
  351. std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
  352. std::map<AnfNodePtr, size_t> apply_map;
  353. for (const AnfNodePtr &node : nodes) {
  354. MS_EXCEPTION_IF_NULL(node);
  355. if (!node->isa<CNode>()) {
  356. continue;
  357. }
  358. auto cnode = node->cast<CNodePtr>();
  359. if (cnode != func_graph->get_return()) {
  360. ExportCNode(func_graph, cnode, &apply_map, const_map_ptr, graph_proto);
  361. } else {
  362. ExportFuncGraphOutput(func_graph, cnode, apply_map, const_map_ptr, graph_proto);
  363. }
  364. }
  365. }
  366. void DebuggerProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
  367. std::map<AnfNodePtr, size_t> *apply_map_ptr,
  368. std::map<AnfNodePtr, size_t> *const_map_ptr,
  369. debugger::GraphProto *graph_proto) {
  370. if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
  371. graph_proto == nullptr) {
  372. return;
  373. }
  374. auto apply_idx = apply_map_ptr->size() + 1;
  375. (*apply_map_ptr)[node] = apply_idx;
  376. auto &inputs = node->inputs();
  377. if (inputs.size() < 1) {
  378. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  379. }
  380. AnfNodePtr op = inputs[0];
  381. debugger::NodeProto *node_proto = graph_proto->add_node();
  382. // CNode/ConstGraph/Const/Parameter
  383. if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) {
  384. MS_LOG(WARNING) << "Operator must be a primitive";
  385. } else {
  386. GetOpNodeTypeAndAttrs(func_graph, op, node_proto);
  387. node_proto->set_name(std::to_string(apply_idx));
  388. node_proto->set_scope(node->scope()->name());
  389. // add full_name for debugger
  390. node_proto->set_full_name(node->fullname_with_scope());
  391. MS_LOG(INFO) << "full_name: " << node->fullname_with_scope();
  392. // process OP inputs
  393. for (size_t i = 1; i < inputs.size(); ++i) {
  394. debugger::InputProto *input_proto = node_proto->add_input();
  395. input_proto->set_type(debugger::InputProto_EdgeType_DATA_EDGE);
  396. std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr);
  397. input_proto->set_name(id);
  398. }
  399. // set node output type
  400. SetNodeOutputType(node, node_proto->mutable_output_type());
  401. }
  402. }
  403. void DebuggerProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
  404. const std::map<AnfNodePtr, size_t> &apply_map,
  405. std::map<AnfNodePtr, size_t> *const_map_ptr,
  406. debugger::GraphProto *graph_proto) {
  407. if (ret_node == nullptr || !ret_node->isa<CNode>()) {
  408. MS_LOG(EXCEPTION) << "Graph return node is illegal";
  409. }
  410. AnfNodePtr arg = ret_node->input(1);
  411. if (graph_proto == nullptr) {
  412. MS_LOG(EXCEPTION) << "graph_proto is nullptr";
  413. }
  414. debugger::OutputProto *output_proto = graph_proto->add_outputs();
  415. if (output_proto == nullptr) {
  416. MS_LOG(EXCEPTION) << "output_proto is nullptr";
  417. }
  418. std::string id = GetOpNodeInputId(func_graph, arg, apply_map, const_map_ptr);
  419. output_proto->set_name(id);
  420. SetNodeOutputType(arg, output_proto->mutable_type());
  421. }
  422. static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) {
  423. return x.second < y.second;
  424. }
  425. void DebuggerProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map,
  426. debugger::GraphProto *graph_proto) {
  427. std::vector<std::pair<AnfNodePtr, size_t>> nodes;
  428. (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes),
  429. [](const std::pair<AnfNodePtr, size_t> &item) { return item; });
  430. sort(nodes.begin(), nodes.end(), CompareValue);
  431. for (auto &item : nodes) {
  432. if (graph_proto == nullptr) {
  433. MS_LOG(EXCEPTION) << "graph_proto is nullptr";
  434. }
  435. debugger::NamedValueProto *named_value = graph_proto->add_const_vals();
  436. MS_EXCEPTION_IF_NULL(named_value);
  437. named_value->set_key(GetConstNodeId(item.second));
  438. SetValueToProto(GetValueNode(item.first), named_value->mutable_value());
  439. }
  440. }
  441. void DebuggerProtoExporter::InitModelInfo() { model_.set_ir_version(debugger::IR_VERSION); }
  442. std::string GetDebuggerFuncGraphProtoString(const FuncGraphPtr &func_graph) {
  443. DebuggerProtoExporter exporter;
  444. return exporter.GetFuncGraphProtoString(func_graph);
  445. }
  446. debugger::ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph) {
  447. DebuggerProtoExporter exporter;
  448. return exporter.GetFuncGraphProto(func_graph);
  449. }
  450. debugger::DataType GetDebuggerNumberDataType(const TypePtr &type) {
  451. switch (type->type_id()) {
  452. case kNumberTypeBool:
  453. return debugger::DT_BOOL;
  454. case kNumberTypeInt8:
  455. return debugger::DT_INT8;
  456. case kNumberTypeInt16:
  457. return debugger::DT_INT16;
  458. case kNumberTypeInt32:
  459. return debugger::DT_INT32;
  460. case kNumberTypeInt64:
  461. return debugger::DT_INT64;
  462. case kNumberTypeUInt8:
  463. return debugger::DT_UINT8;
  464. case kNumberTypeUInt16:
  465. return debugger::DT_UINT16;
  466. case kNumberTypeUInt32:
  467. return debugger::DT_UINT32;
  468. case kNumberTypeUInt64:
  469. return debugger::DT_UINT64;
  470. case kNumberTypeFloat16:
  471. return debugger::DT_FLOAT16;
  472. case kNumberTypeFloat32:
  473. return debugger::DT_FLOAT32;
  474. case kNumberTypeFloat64:
  475. return debugger::DT_FLOAT64;
  476. case kNumberTypeInt:
  477. return debugger::DT_BASE_INT;
  478. case kNumberTypeUInt:
  479. return debugger::DT_BASE_UINT;
  480. case kNumberTypeFloat:
  481. return debugger::DT_BASE_FLOAT;
  482. default:
  483. MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
  484. }
  485. }
  486. } // namespace mindspore