You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

proto_exporter.cc 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <fstream>
  17. #include <map>
  18. #include <memory>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include <utility>
  22. #include <algorithm>
  23. #include "debug/debugger/debugger.h"
  24. #include "proto/debug_graph.pb.h"
  25. #include "ir/graph_utils.h"
  26. #include "utils/symbolic.h"
  27. namespace mindspore {
  28. class DebuggerProtoExporter {
  29. public:
  30. DebuggerProtoExporter() {}
  31. ~DebuggerProtoExporter() {}
  32. std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
  33. debugger::ModelProto GetFuncGraphProto(const FuncGraphPtr &func_graph);
  34. private:
  35. void InitModelInfo();
  36. void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, debugger::NodeProto *node_proto);
  37. std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  38. const std::map<AnfNodePtr, size_t> &apply_map,
  39. std::map<AnfNodePtr, size_t> *const_map_ptr);
  40. void SetValueToProto(const ValuePtr &attr_value, debugger::ValueProto *value_proto);
  41. void SetScalarToProto(const ScalarPtr &val, debugger::ValueProto *value_proto);
  42. void SetSequenceToProto(const ValueSequeuePtr &val, debugger::ValueProto *value_proto);
  43. void SetDictionaryToProto(const ValueDictionaryPtr &val, debugger::ValueProto *value_proto);
  44. void SetNodeOutputType(const AnfNodePtr &node, debugger::TypeProto *type_proto);
  45. void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, debugger::TypeProto *type_proto);
  46. void ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto);
  47. void ExportParameters(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto);
  48. void ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto,
  49. std::map<AnfNodePtr, size_t> *const_map_ptr);
  50. void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *apply_map_ptr,
  51. std::map<AnfNodePtr, size_t> *const_map_ptr, debugger::GraphProto *graph_proto);
  52. void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
  53. const std::map<AnfNodePtr, size_t> &apply_map, std::map<AnfNodePtr, size_t> *const_map_ptr,
  54. debugger::GraphProto *graph_proto);
  55. void ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, debugger::GraphProto *graph_proto);
  56. static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); }
  57. debugger::ModelProto model_;
  58. };
  59. void DebuggerProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape,
  60. debugger::TypeProto *type_proto) {
  61. if (type_proto == nullptr) {
  62. return;
  63. }
  64. if (type == nullptr) {
  65. type_proto->set_data_type(debugger::DT_UNDEFINED);
  66. } else if (type->isa<Number>()) {
  67. type_proto->set_data_type(GetDebuggerNumberDataType(type));
  68. } else if (type->isa<TensorType>()) {
  69. TypePtr elem_type = dyn_cast<TensorType>(type)->element();
  70. type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type));
  71. type_proto->set_data_type(debugger::DT_TENSOR);
  72. if (shape != nullptr && shape->isa<abstract::Shape>()) {
  73. abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
  74. for (const auto &elem : shape_info->shape()) {
  75. type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
  76. }
  77. }
  78. } else if (type->isa<Tuple>()) {
  79. TuplePtr tuple_type = dyn_cast<Tuple>(type);
  80. type_proto->set_data_type(debugger::DT_TUPLE);
  81. for (const auto &elem_type : tuple_type->elements()) {
  82. SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
  83. }
  84. } else if (type->isa<TypeType>()) {
  85. type_proto->set_data_type(debugger::DT_TYPE);
  86. } else if (type->isa<List>()) {
  87. ListPtr list_type = dyn_cast<List>(type);
  88. type_proto->set_data_type(debugger::DT_LIST);
  89. for (const auto &elem_type : list_type->elements()) {
  90. SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
  91. }
  92. } else if (type->isa<TypeAnything>()) {
  93. type_proto->set_data_type(debugger::DT_ANYTHING);
  94. } else if (type->isa<RefKeyType>()) {
  95. type_proto->set_data_type(debugger::DT_REFKEY);
  96. } else if (type->isa<RefType>()) {
  97. type_proto->set_data_type(debugger::DT_REF);
  98. } else if (type->isa<Function>()) {
  99. type_proto->set_data_type(debugger::DT_GRAPH);
  100. } else if (type->isa<TypeNone>()) {
  101. type_proto->set_data_type(debugger::DT_NONE);
  102. } else if (type->isa<String>()) {
  103. type_proto->set_data_type(debugger::DT_STRING);
  104. } else if (type->isa<SymbolicKeyType>()) {
  105. // Do Nothing.
  106. } else {
  107. MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
  108. }
  109. }
  110. void DebuggerProtoExporter::SetNodeOutputType(const AnfNodePtr &node, debugger::TypeProto *type_proto) {
  111. if (node == nullptr || type_proto == nullptr) {
  112. return;
  113. }
  114. SetNodeOutputType(node->Type(), node->Shape(), type_proto);
  115. }
  116. void DebuggerProtoExporter::SetValueToProto(const ValuePtr &val, debugger::ValueProto *value_proto) {
  117. if (val == nullptr || value_proto == nullptr) {
  118. return;
  119. }
  120. if (val->isa<StringImm>()) {
  121. const StringImmPtr &value = dyn_cast<StringImm>(val);
  122. value_proto->set_dtype(debugger::DT_STRING);
  123. value_proto->set_str_val(value->value());
  124. } else if (val->isa<Scalar>()) {
  125. SetScalarToProto(dyn_cast<Scalar>(val), value_proto);
  126. } else if (val->isa<Bool>()) {
  127. value_proto->set_dtype(debugger::DT_TYPE);
  128. value_proto->mutable_type_val()->set_data_type(debugger::DT_BOOL);
  129. } else if (val->isa<Int>()) {
  130. value_proto->set_dtype(debugger::DT_TYPE);
  131. value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_INT);
  132. } else if (val->isa<Float>()) {
  133. value_proto->set_dtype(debugger::DT_TYPE);
  134. value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_FLOAT);
  135. } else if (val->isa<ValueSequeue>()) {
  136. SetSequenceToProto(dyn_cast<ValueSequeue>(val), value_proto);
  137. } else if (val->isa<None>()) {
  138. value_proto->set_dtype(debugger::DT_NONE);
  139. value_proto->set_str_val("None");
  140. } else if (val->isa<SymbolicKeyInstance>()) {
  141. SymbolicKeyInstancePtr sym_inst = dyn_cast<SymbolicKeyInstance>(val);
  142. ParameterPtr sym_node = dyn_cast<Parameter>(sym_inst->node());
  143. value_proto->set_dtype(debugger::DT_SYM_INST);
  144. value_proto->set_str_val(sym_node == nullptr ? std::string("nullptr") : sym_node->ToString());
  145. } else if (val->isa<ValueDictionary>()) {
  146. SetDictionaryToProto(dyn_cast<ValueDictionary>(val), value_proto);
  147. } else if (val->isa<tensor::Tensor>()) {
  148. tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val);
  149. value_proto->set_dtype(debugger::DT_TENSOR);
  150. debugger::TensorProto *tensor_proto = value_proto->mutable_tensor_val();
  151. tensor_proto->set_data_type(GetDebuggerNumberDataType(tensor_ptr->Dtype()));
  152. for (auto &elem : tensor_ptr->shape()) {
  153. tensor_proto->add_dims(elem);
  154. }
  155. tensor_proto->set_tensor_content(tensor_ptr->data_c(), tensor_ptr->data().nbytes());
  156. } else if (val->isa<TensorType>()) {
  157. value_proto->set_dtype(debugger::DT_TYPE);
  158. debugger::TypeProto *type_proto = value_proto->mutable_type_val();
  159. type_proto->set_data_type(debugger::DT_TENSOR);
  160. TypePtr elem_type = dyn_cast<TensorType>(val)->element();
  161. type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type));
  162. } else {
  163. MS_LOG(WARNING) << "Unsupported type " << val->type_name();
  164. }
  165. }
  166. void DebuggerProtoExporter::SetScalarToProto(const ScalarPtr &val, debugger::ValueProto *value_proto) {
  167. if (val == nullptr || value_proto == nullptr) {
  168. return;
  169. }
  170. if (val->isa<BoolImm>()) {
  171. const BoolImmPtr &value = dyn_cast<BoolImm>(val);
  172. value_proto->set_dtype(debugger::DT_BOOL);
  173. value_proto->set_bool_val(value->value());
  174. } else if (val->isa<Int8Imm>()) {
  175. const Int8ImmPtr &value = dyn_cast<Int8Imm>(val);
  176. value_proto->set_dtype(debugger::DT_INT8);
  177. value_proto->set_int_val(value->value());
  178. } else if (val->isa<Int16Imm>()) {
  179. const Int16ImmPtr &value = dyn_cast<Int16Imm>(val);
  180. value_proto->set_dtype(debugger::DT_INT16);
  181. value_proto->set_int_val(value->value());
  182. } else if (val->isa<Int32Imm>()) {
  183. const Int32ImmPtr &value = dyn_cast<Int32Imm>(val);
  184. value_proto->set_dtype(debugger::DT_INT32);
  185. value_proto->set_int_val(value->value());
  186. } else if (val->isa<Int64Imm>()) {
  187. const Int64ImmPtr &value = dyn_cast<Int64Imm>(val);
  188. value_proto->set_dtype(debugger::DT_INT64);
  189. value_proto->set_int_val(value->value());
  190. } else if (val->isa<UInt8Imm>()) {
  191. const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val);
  192. value_proto->set_dtype(debugger::DT_UINT8);
  193. value_proto->set_uint_val(value->value());
  194. } else if (val->isa<UInt16Imm>()) {
  195. const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val);
  196. value_proto->set_dtype(debugger::DT_UINT16);
  197. value_proto->set_uint_val(value->value());
  198. } else if (val->isa<UInt32Imm>()) {
  199. const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val);
  200. value_proto->set_dtype(debugger::DT_UINT32);
  201. value_proto->set_uint_val(value->value());
  202. } else if (val->isa<UInt64Imm>()) {
  203. const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val);
  204. value_proto->set_dtype(debugger::DT_UINT64);
  205. value_proto->set_uint_val(value->value());
  206. } else if (val->isa<FP32Imm>()) {
  207. const FP32ImmPtr &value = dyn_cast<FP32Imm>(val);
  208. value_proto->set_dtype(debugger::DT_FLOAT32);
  209. value_proto->set_float_val(value->value());
  210. } else if (val->isa<FP64Imm>()) {
  211. const FP64ImmPtr &value = dyn_cast<FP64Imm>(val);
  212. value_proto->set_dtype(debugger::DT_FLOAT64);
  213. value_proto->set_double_val(value->value());
  214. } else {
  215. MS_LOG(EXCEPTION) << "Unknown scalar type " << val->ToString();
  216. }
  217. }
  218. void DebuggerProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, debugger::ValueProto *value_proto) {
  219. if (val == nullptr || value_proto == nullptr) {
  220. return;
  221. }
  222. if (val->isa<ValueTuple>()) {
  223. const ValueTuplePtr &value = dyn_cast<ValueTuple>(val);
  224. value_proto->set_dtype(debugger::DT_TUPLE);
  225. for (const auto &item : value->value()) {
  226. SetValueToProto(item, value_proto->add_values());
  227. }
  228. } else if (val->isa<ValueList>()) {
  229. const ValueListPtr &value = dyn_cast<ValueList>(val);
  230. value_proto->set_dtype(debugger::DT_LIST);
  231. for (const auto &item : value->value()) {
  232. SetValueToProto(item, value_proto->add_values());
  233. }
  234. }
  235. }
  236. void DebuggerProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, debugger::ValueProto *value_proto) {
  237. if (val == nullptr || value_proto == nullptr) {
  238. return;
  239. }
  240. value_proto->set_dtype(debugger::DT_DICT);
  241. for (const auto &item : val->value()) {
  242. debugger::NamedValueProto *named_val = value_proto->add_dict_val();
  243. named_val->set_key(item.first);
  244. SetValueToProto(item.second, named_val->mutable_value());
  245. }
  246. }
  247. void DebuggerProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node,
  248. debugger::NodeProto *node_proto) {
  249. if (node == nullptr || node_proto == nullptr) {
  250. return;
  251. }
  252. if (node->isa<CNode>() || node->isa<Parameter>() || IsValueNode<FuncGraph>(node)) {
  253. MS_LOG(EXCEPTION) << "Op node can not be CNode, Parameter or ValueNode Graph. But got " << node->ToString();
  254. }
  255. if (!IsValueNode<Primitive>(node)) {
  256. MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString();
  257. }
  258. const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(node);
  259. node_proto->set_op_type(prim->name());
  260. for (const auto &attr : prim->attrs()) {
  261. debugger::AttributeProto *attr_proto = node_proto->add_attribute();
  262. attr_proto->set_name(attr.first);
  263. SetValueToProto(attr.second, attr_proto->mutable_value());
  264. }
  265. node_proto->set_scope(node->scope()->name());
  266. }
  267. std::string DebuggerProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node,
  268. const std::map<AnfNodePtr, size_t> &apply_map,
  269. std::map<AnfNodePtr, size_t> *const_map_ptr) {
  270. if (node == nullptr || const_map_ptr == nullptr) {
  271. return "";
  272. }
  273. if (node->isa<CNode>()) {
  274. auto iter = apply_map.find(node);
  275. if (iter == apply_map.end()) {
  276. MS_LOG(EXCEPTION) << "Can not find node '" << node->ToString() << "' in apply_map";
  277. }
  278. return std::to_string(iter->second);
  279. }
  280. if (node->isa<Parameter>()) {
  281. return node->ToString();
  282. }
  283. if (node->isa<ValueNode>()) {
  284. auto iter = const_map_ptr->find(node);
  285. if (iter == const_map_ptr->end()) {
  286. // Start index number from 1
  287. auto const_idx = const_map_ptr->size() + 1;
  288. (*const_map_ptr)[node] = const_idx;
  289. }
  290. return GetConstNodeId((*const_map_ptr)[node]);
  291. }
  292. MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'";
  293. }
  294. std::string DebuggerProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
  295. if (func_graph == nullptr) {
  296. return "";
  297. }
  298. InitModelInfo();
  299. debugger::GraphProto *graph_proto = model_.mutable_graph();
  300. ExportFuncGraph(func_graph, graph_proto);
  301. return model_.SerializeAsString();
  302. }
  303. debugger::ModelProto DebuggerProtoExporter::GetFuncGraphProto(const FuncGraphPtr &func_graph) {
  304. if (func_graph == nullptr) {
  305. return ModelProto();
  306. }
  307. InitModelInfo();
  308. debugger::GraphProto *graph_proto = model_.mutable_graph();
  309. ExportFuncGraph(func_graph, graph_proto);
  310. return model_;
  311. }
  312. void DebuggerProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto) {
  313. if (func_graph == nullptr || graph_proto == nullptr) {
  314. return;
  315. }
  316. // map for store ValueNodes of this graph
  317. std::map<AnfNodePtr, size_t> const_map;
  318. // set graph name
  319. graph_proto->set_name(func_graph->ToString());
  320. ExportParameters(func_graph, graph_proto);
  321. ExportCNodes(func_graph, graph_proto, &const_map);
  322. ExportValueNodes(const_map, graph_proto);
  323. }
  324. void DebuggerProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto) {
  325. if (func_graph == nullptr || graph_proto == nullptr) {
  326. return;
  327. }
  328. // cast FuncGraph to KernelGraph to access inputs()
  329. std::vector<AnfNodePtr> parameters = static_cast<session::KernelGraph *>(func_graph.get())->inputs();
  330. for (auto &param : parameters) {
  331. debugger::ParameterProto *param_proto = graph_proto->add_parameters();
  332. param_proto->set_name(param->ToString());
  333. SetNodeOutputType(param, param_proto->mutable_type());
  334. const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
  335. if (param_ptr == nullptr) {
  336. MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter.";
  337. }
  338. }
  339. }
  340. void DebuggerProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto,
  341. std::map<AnfNodePtr, size_t> *const_map_ptr) {
  342. if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
  343. return;
  344. }
  345. // topo sort nodes
  346. std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
  347. std::map<AnfNodePtr, size_t> apply_map;
  348. for (const AnfNodePtr &node : nodes) {
  349. MS_EXCEPTION_IF_NULL(node);
  350. if (!node->isa<CNode>()) {
  351. continue;
  352. }
  353. auto cnode = node->cast<CNodePtr>();
  354. if (cnode != func_graph->get_return()) {
  355. ExportCNode(func_graph, cnode, &apply_map, const_map_ptr, graph_proto);
  356. } else {
  357. ExportFuncGraphOutput(func_graph, cnode, apply_map, const_map_ptr, graph_proto);
  358. }
  359. }
  360. }
  361. void DebuggerProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
  362. std::map<AnfNodePtr, size_t> *apply_map_ptr,
  363. std::map<AnfNodePtr, size_t> *const_map_ptr,
  364. debugger::GraphProto *graph_proto) {
  365. if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
  366. graph_proto == nullptr) {
  367. return;
  368. }
  369. auto apply_idx = apply_map_ptr->size() + 1;
  370. (*apply_map_ptr)[node] = apply_idx;
  371. auto &inputs = node->inputs();
  372. if (inputs.size() < 1) {
  373. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  374. }
  375. AnfNodePtr op = inputs[0];
  376. debugger::NodeProto *node_proto = graph_proto->add_node();
  377. // CNode/ConstGraph/Const/Parameter
  378. if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) {
  379. MS_LOG(WARNING) << "Operator must be a primitive";
  380. } else {
  381. GetOpNodeTypeAndAttrs(func_graph, op, node_proto);
  382. node_proto->set_name(std::to_string(apply_idx));
  383. node_proto->set_scope(node->scope()->name());
  384. // add full_name for debugger
  385. node_proto->set_full_name(node->fullname_with_scope());
  386. // process OP inputs
  387. for (size_t i = 1; i < inputs.size(); ++i) {
  388. debugger::InputProto *input_proto = node_proto->add_input();
  389. input_proto->set_type(debugger::InputProto_EdgeType_DATA_EDGE);
  390. std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr);
  391. input_proto->set_name(id);
  392. }
  393. // set node output type
  394. SetNodeOutputType(node, node_proto->mutable_output_type());
  395. }
  396. }
  397. void DebuggerProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
  398. const std::map<AnfNodePtr, size_t> &apply_map,
  399. std::map<AnfNodePtr, size_t> *const_map_ptr,
  400. debugger::GraphProto *graph_proto) {
  401. if (ret_node == nullptr || !ret_node->isa<CNode>()) {
  402. MS_LOG(EXCEPTION) << "Graph return node is illegal";
  403. }
  404. AnfNodePtr arg = ret_node->input(1);
  405. if (graph_proto == nullptr) {
  406. MS_LOG(EXCEPTION) << "graph_proto is nullptr";
  407. }
  408. debugger::OutputProto *output_proto = graph_proto->add_outputs();
  409. if (output_proto == nullptr) {
  410. MS_LOG(EXCEPTION) << "output_proto is nullptr";
  411. }
  412. std::string id = GetOpNodeInputId(func_graph, arg, apply_map, const_map_ptr);
  413. output_proto->set_name(id);
  414. SetNodeOutputType(arg, output_proto->mutable_type());
  415. }
  416. static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) {
  417. return x.second < y.second;
  418. }
  419. void DebuggerProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map,
  420. debugger::GraphProto *graph_proto) {
  421. std::vector<std::pair<AnfNodePtr, size_t>> nodes;
  422. (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes),
  423. [](const std::pair<AnfNodePtr, size_t> &item) { return item; });
  424. sort(nodes.begin(), nodes.end(), CompareValue);
  425. for (auto &item : nodes) {
  426. if (graph_proto == nullptr) {
  427. MS_LOG(EXCEPTION) << "graph_proto is nullptr";
  428. }
  429. debugger::NamedValueProto *named_value = graph_proto->add_const_vals();
  430. MS_EXCEPTION_IF_NULL(named_value);
  431. named_value->set_key(GetConstNodeId(item.second));
  432. SetValueToProto(GetValueNode(item.first), named_value->mutable_value());
  433. }
  434. }
  435. void DebuggerProtoExporter::InitModelInfo() { model_.set_ir_version(debugger::IR_VERSION); }
  436. std::string GetDebuggerFuncGraphProtoString(const FuncGraphPtr &func_graph) {
  437. DebuggerProtoExporter exporter;
  438. return exporter.GetFuncGraphProtoString(func_graph);
  439. }
  440. debugger::ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph) {
  441. DebuggerProtoExporter exporter;
  442. return exporter.GetFuncGraphProto(func_graph);
  443. }
  444. debugger::DataType GetDebuggerNumberDataType(const TypePtr &type) {
  445. switch (type->type_id()) {
  446. case kNumberTypeBool:
  447. return debugger::DT_BOOL;
  448. case kNumberTypeInt8:
  449. return debugger::DT_INT8;
  450. case kNumberTypeInt16:
  451. return debugger::DT_INT16;
  452. case kNumberTypeInt32:
  453. return debugger::DT_INT32;
  454. case kNumberTypeInt64:
  455. return debugger::DT_INT64;
  456. case kNumberTypeUInt8:
  457. return debugger::DT_UINT8;
  458. case kNumberTypeUInt16:
  459. return debugger::DT_UINT16;
  460. case kNumberTypeUInt32:
  461. return debugger::DT_UINT32;
  462. case kNumberTypeUInt64:
  463. return debugger::DT_UINT64;
  464. case kNumberTypeFloat16:
  465. return debugger::DT_FLOAT16;
  466. case kNumberTypeFloat32:
  467. return debugger::DT_FLOAT32;
  468. case kNumberTypeFloat64:
  469. return debugger::DT_FLOAT64;
  470. case kNumberTypeInt:
  471. return debugger::DT_BASE_INT;
  472. case kNumberTypeUInt:
  473. return debugger::DT_BASE_UINT;
  474. case kNumberTypeFloat:
  475. return debugger::DT_BASE_FLOAT;
  476. default:
  477. MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
  478. }
  479. }
  480. } // namespace mindspore