You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

proto_exporter.cc 24 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "debug/debugger/proto_exporter.h"
  17. #include <fstream>
  18. #include <map>
  19. #include <memory>
  20. #include <utility>
  21. #include <algorithm>
  22. #include "utils/hash_map.h"
  23. #include "utils/hash_set.h"
  24. #include "debug/anf_ir_utils.h"
  25. #include "debug/data_dump/dump_utils.h"
  26. #include "debug/common.h"
  27. #include "debug/debugger/debugger.h"
  28. #include "debug/data_dump/dump_json_parser.h"
  29. #include "proto/debug_graph.pb.h"
  30. #include "ir/graph_utils.h"
  31. #include "utils/symbolic.h"
  32. #include "utils/trace_base.h"
  33. #include "debug/data_dump/e2e_dump.h"
  34. namespace mindspore {
  35. using TypeInfoToProtoTypeMap = std::vector<std::pair<uint32_t, debugger::DataType>>;
  36. void SetOutputType(const TypePtr &node, const BaseShapePtr &shape, debugger::TypeProto *type_proto);
  37. void CheckIfValidType(const TypePtr &type, debugger::TypeProto *const type_proto) {
  38. if (!(type->isa<Number>() || type->isa<TensorType>() || type->isa<Tuple>() || type->isa<TypeType>() ||
  39. type->isa<List>() || type->isa<TypeAnything>() || type->isa<RefKeyType>() || type->isa<RefType>() ||
  40. type->isa<Function>() || type->isa<TypeNone>() || type->isa<String>() || type->isa<SymbolicKeyType>() ||
  41. type->isa<UMonadType>() || type->isa<IOMonadType>())) {
  42. MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
  43. }
  44. if (type->isa<Number>()) {
  45. type_proto->set_data_type(GetDebuggerNumberDataType(type));
  46. }
  47. }
  48. void SetTensorTypeProto(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto) {
  49. TypePtr elem_type = dyn_cast<TensorType>(type)->element();
  50. type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type));
  51. if (shape != nullptr && shape->isa<abstract::Shape>()) {
  52. abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
  53. for (const auto &elem : shape_info->shape()) {
  54. type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
  55. }
  56. }
  57. }
  58. void SetTupleTypeProto(const TypePtr &type, debugger::TypeProto *type_proto) {
  59. TuplePtr tuple_type = dyn_cast<Tuple>(type);
  60. for (const auto &elem_type : tuple_type->elements()) {
  61. SetOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
  62. }
  63. }
  64. void SetListTypeProto(const TypePtr &type, debugger::TypeProto *type_proto) {
  65. ListPtr list_type = dyn_cast<List>(type);
  66. for (const auto &elem_type : list_type->elements()) {
  67. SetOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
  68. }
  69. }
  70. static TypeInfoToProtoTypeMap type_info_to_proto_type = {
  71. {TensorType::kTypeId, debugger::DT_TENSOR}, {Tuple::kTypeId, debugger::DT_TUPLE},
  72. {TypeType::kTypeId, debugger::DT_TYPE}, {List::kTypeId, debugger::DT_LIST},
  73. {TypeAnything::kTypeId, debugger::DT_ANYTHING}, {RefKeyType::kTypeId, debugger::DT_REFKEY},
  74. {RefType::kTypeId, debugger::DT_REF}, {Function::kTypeId, debugger::DT_GRAPH},
  75. {TypeNone::kTypeId, debugger::DT_NONE}, {String::kTypeId, debugger::DT_STRING},
  76. {UMonadType::kTypeId, debugger::DT_UMONAD}, {IOMonadType::kTypeId, debugger::DT_IOMONAD}};
  77. void SetOutputType(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto) {
  78. if (type_proto == nullptr) {
  79. return;
  80. }
  81. if (type == nullptr) {
  82. type_proto->set_data_type(debugger::DT_UNDEFINED);
  83. return;
  84. }
  85. CheckIfValidType(type, type_proto);
  86. for (auto &it : type_info_to_proto_type) {
  87. if (type->IsFromTypeId(it.first)) {
  88. type_proto->set_data_type(it.second);
  89. break;
  90. }
  91. }
  92. if (type->isa<TensorType>()) {
  93. SetTensorTypeProto(type, shape, type_proto);
  94. return;
  95. }
  96. if (type->isa<Tuple>()) {
  97. SetTupleTypeProto(type, type_proto);
  98. return;
  99. }
  100. if (type->isa<List>()) {
  101. SetListTypeProto(type, type_proto);
  102. }
  103. }
  104. void DebuggerProtoExporter::SetNodeOutputType(const AnfNodePtr &node, debugger::TypeProto *type_proto) {
  105. if (node == nullptr || type_proto == nullptr) {
  106. return;
  107. }
  108. SetOutputType(node->Type(), node->Shape(), type_proto);
  109. }
  110. void DebuggerProtoExporter::SetValueToProto(const ValuePtr &val, debugger::ValueProto *value_proto) {
  111. if (val == nullptr || value_proto == nullptr) {
  112. return;
  113. }
  114. if (val->isa<StringImm>()) {
  115. const StringImmPtr &value = dyn_cast<StringImm>(val);
  116. value_proto->set_dtype(debugger::DT_STRING);
  117. value_proto->set_str_val(value->value());
  118. } else if (val->isa<Scalar>()) {
  119. SetScalarToProto(dyn_cast<Scalar>(val), value_proto);
  120. } else if (val->isa<Bool>()) {
  121. value_proto->set_dtype(debugger::DT_TYPE);
  122. value_proto->mutable_type_val()->set_data_type(debugger::DT_BOOL);
  123. } else if (val->isa<Int>()) {
  124. value_proto->set_dtype(debugger::DT_TYPE);
  125. value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_INT);
  126. } else if (val->isa<Float>()) {
  127. value_proto->set_dtype(debugger::DT_TYPE);
  128. value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_FLOAT);
  129. } else if (val->isa<ValueSequence>()) {
  130. SetSequenceToProto(dyn_cast<ValueSequence>(val), value_proto);
  131. } else if (val->isa<None>()) {
  132. value_proto->set_dtype(debugger::DT_NONE);
  133. value_proto->set_str_val("None");
  134. } else if (val->isa<SymbolicKeyInstance>()) {
  135. SymbolicKeyInstancePtr sym_inst = dyn_cast<SymbolicKeyInstance>(val);
  136. ParameterPtr sym_node = dyn_cast<Parameter>(sym_inst->node());
  137. value_proto->set_dtype(debugger::DT_SYM_INST);
  138. value_proto->set_str_val(sym_node == nullptr ? std::string("nullptr") : sym_node->ToString());
  139. } else if (val->isa<ValueDictionary>()) {
  140. SetDictionaryToProto(dyn_cast<ValueDictionary>(val), value_proto);
  141. } else if (val->isa<tensor::Tensor>()) {
  142. tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val);
  143. value_proto->set_dtype(debugger::DT_TENSOR);
  144. debugger::TensorProto *tensor_proto = value_proto->mutable_tensor_val();
  145. tensor_proto->set_data_type(GetDebuggerNumberDataType(tensor_ptr->Dtype()));
  146. for (auto &elem : tensor_ptr->shape()) {
  147. tensor_proto->add_dims(elem);
  148. }
  149. tensor_proto->set_tensor_content(tensor_ptr->data_c(), tensor_ptr->data().nbytes());
  150. } else if (val->isa<TensorType>()) {
  151. value_proto->set_dtype(debugger::DT_TYPE);
  152. debugger::TypeProto *type_proto = value_proto->mutable_type_val();
  153. type_proto->set_data_type(debugger::DT_TENSOR);
  154. TypePtr elem_type = dyn_cast<TensorType>(val)->element();
  155. type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type));
  156. } else {
  157. MS_LOG(INFO) << "Unsupported type " << val->type_name();
  158. }
  159. }
  160. void DebuggerProtoExporter::SetScalarToProto(const ScalarPtr &val, debugger::ValueProto *value_proto) {
  161. if (val == nullptr || value_proto == nullptr) {
  162. return;
  163. }
  164. if (val->isa<BoolImm>()) {
  165. const BoolImmPtr &value = dyn_cast<BoolImm>(val);
  166. value_proto->set_dtype(debugger::DT_BOOL);
  167. value_proto->set_bool_val(value->value());
  168. } else if (val->isa<Int8Imm>()) {
  169. const Int8ImmPtr &value = dyn_cast<Int8Imm>(val);
  170. value_proto->set_dtype(debugger::DT_INT8);
  171. value_proto->set_int_val(value->value());
  172. } else if (val->isa<Int16Imm>()) {
  173. const Int16ImmPtr &value = dyn_cast<Int16Imm>(val);
  174. value_proto->set_dtype(debugger::DT_INT16);
  175. value_proto->set_int_val(value->value());
  176. } else if (val->isa<Int32Imm>()) {
  177. const Int32ImmPtr &value = dyn_cast<Int32Imm>(val);
  178. value_proto->set_dtype(debugger::DT_INT32);
  179. value_proto->set_int_val(value->value());
  180. } else if (val->isa<Int64Imm>()) {
  181. const Int64ImmPtr &value = dyn_cast<Int64Imm>(val);
  182. value_proto->set_dtype(debugger::DT_INT64);
  183. value_proto->set_int_val(value->value());
  184. } else if (val->isa<UInt8Imm>()) {
  185. const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val);
  186. value_proto->set_dtype(debugger::DT_UINT8);
  187. value_proto->set_uint_val(value->value());
  188. } else if (val->isa<UInt16Imm>()) {
  189. const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val);
  190. value_proto->set_dtype(debugger::DT_UINT16);
  191. value_proto->set_uint_val(value->value());
  192. } else if (val->isa<UInt32Imm>()) {
  193. const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val);
  194. value_proto->set_dtype(debugger::DT_UINT32);
  195. value_proto->set_uint_val(value->value());
  196. } else if (val->isa<UInt64Imm>()) {
  197. const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val);
  198. value_proto->set_dtype(debugger::DT_UINT64);
  199. value_proto->set_uint_val(value->value());
  200. } else if (val->isa<FP32Imm>()) {
  201. const FP32ImmPtr &value = dyn_cast<FP32Imm>(val);
  202. value_proto->set_dtype(debugger::DT_FLOAT32);
  203. value_proto->set_float_val(value->value());
  204. } else if (val->isa<FP64Imm>()) {
  205. const FP64ImmPtr &value = dyn_cast<FP64Imm>(val);
  206. value_proto->set_dtype(debugger::DT_FLOAT64);
  207. value_proto->set_double_val(value->value());
  208. } else {
  209. MS_LOG(EXCEPTION) << "Unknown scalar type " << val->ToString();
  210. }
  211. }
  212. void DebuggerProtoExporter::SetSequenceToProto(const ValueSequencePtr &val, debugger::ValueProto *value_proto) {
  213. if (val == nullptr || value_proto == nullptr) {
  214. return;
  215. }
  216. if (val->isa<ValueTuple>()) {
  217. const ValueTuplePtr &value = dyn_cast<ValueTuple>(val);
  218. value_proto->set_dtype(debugger::DT_TUPLE);
  219. for (const auto &item : value->value()) {
  220. SetValueToProto(item, value_proto->add_values());
  221. }
  222. } else if (val->isa<ValueList>()) {
  223. const ValueListPtr &value = dyn_cast<ValueList>(val);
  224. value_proto->set_dtype(debugger::DT_LIST);
  225. for (const auto &item : value->value()) {
  226. SetValueToProto(item, value_proto->add_values());
  227. }
  228. }
  229. }
  230. void DebuggerProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, debugger::ValueProto *value_proto) {
  231. if (val == nullptr || value_proto == nullptr) {
  232. return;
  233. }
  234. value_proto->set_dtype(debugger::DT_DICT);
  235. for (const auto &item : val->value()) {
  236. debugger::NamedValueProto *named_val = value_proto->add_dict_val();
  237. named_val->set_key(item.first);
  238. SetValueToProto(item.second, named_val->mutable_value());
  239. }
  240. }
  241. void DebuggerProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node,
  242. debugger::NodeProto *node_proto) {
  243. if (node == nullptr || node_proto == nullptr) {
  244. return;
  245. }
  246. if (node->isa<CNode>() || node->isa<Parameter>() || IsValueNode<FuncGraph>(node)) {
  247. MS_LOG(EXCEPTION) << "Op node can not be CNode, Parameter or ValueNode Graph. But got " << node->ToString();
  248. }
  249. if (!IsValueNode<Primitive>(node)) {
  250. MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString();
  251. }
  252. const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(node);
  253. node_proto->set_op_type(prim->name());
  254. for (const auto &attr : prim->attrs()) {
  255. debugger::AttributeProto *attr_proto = node_proto->add_attribute();
  256. attr_proto->set_name(attr.first);
  257. SetValueToProto(attr.second, attr_proto->mutable_value());
  258. }
  259. node_proto->set_scope(node->scope()->name());
  260. }
  261. std::string DebuggerProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node,
  262. const std::map<AnfNodePtr, size_t> &apply_map,
  263. std::map<AnfNodePtr, size_t> *const_map_ptr) {
  264. if (node == nullptr || const_map_ptr == nullptr) {
  265. return "";
  266. }
  267. if (node->isa<CNode>()) {
  268. auto iter = apply_map.find(node);
  269. if (iter == apply_map.end()) {
  270. MS_LOG(EXCEPTION) << "Can not find node '" << node->ToString() << "' in apply_map";
  271. }
  272. return std::to_string(iter->second);
  273. }
  274. if (node->isa<Parameter>()) {
  275. return node->ToString();
  276. }
  277. if (node->isa<ValueNode>()) {
  278. auto iter = const_map_ptr->find(node);
  279. if (iter == const_map_ptr->end()) {
  280. // Start index number from 1
  281. auto const_idx = const_map_ptr->size() + 1;
  282. (*const_map_ptr)[node] = const_idx;
  283. }
  284. return GetConstNodeId((*const_map_ptr)[node]);
  285. }
  286. MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'";
  287. }
  288. std::string DebuggerProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph,
  289. LocDebugDumpMode dump_location) {
  290. if (func_graph == nullptr) {
  291. return "";
  292. }
  293. InitModelInfo();
  294. debugger::GraphProto *graph_proto = model_.mutable_graph();
  295. ExportFuncGraph(func_graph, graph_proto, dump_location);
  296. return model_.SerializeAsString();
  297. }
  298. debugger::ModelProto DebuggerProtoExporter::GetFuncGraphProto(const FuncGraphPtr &func_graph) {
  299. if (func_graph == nullptr) {
  300. return ModelProto();
  301. }
  302. InitModelInfo();
  303. debugger::GraphProto *graph_proto = model_.mutable_graph();
  304. ExportFuncGraph(func_graph, graph_proto);
  305. return model_;
  306. }
  307. void DebuggerProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *const graph_proto,
  308. LocDebugDumpMode dump_location) {
  309. if (func_graph == nullptr || graph_proto == nullptr) {
  310. return;
  311. }
  312. // map for store ValueNodes of this graph
  313. std::map<AnfNodePtr, size_t> const_map;
  314. // set graph name
  315. graph_proto->set_name(func_graph->ToString());
  316. MS_LOG(INFO) << "graph names: " << func_graph->ToString();
  317. // cast FuncGraph to KernelGraph to access root_graph_id()
  318. uint32_t root_graph_id = static_cast<session::KernelGraph *>(func_graph.get())->root_graph_id();
  319. MS_LOG(INFO) << "root graph id: " << root_graph_id;
  320. // set root graph id
  321. graph_proto->set_root_name(std::to_string(root_graph_id));
  322. ExportParameters(func_graph, graph_proto);
  323. ExportCNodes(func_graph, graph_proto, &const_map, dump_location);
  324. ExportValueNodes(const_map, graph_proto);
  325. }
  326. void DebuggerProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto) {
  327. if (func_graph == nullptr || graph_proto == nullptr) {
  328. return;
  329. }
  330. // cast FuncGraph to KernelGraph to access inputs()
  331. std::vector<AnfNodePtr> parameters = static_cast<session::KernelGraph *>(func_graph.get())->inputs();
  332. for (auto &param : parameters) {
  333. debugger::ParameterProto *param_proto = graph_proto->add_parameters();
  334. param_proto->set_name(param->ToString());
  335. SetNodeOutputType(param, param_proto->mutable_type());
  336. const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
  337. if (param_ptr == nullptr) {
  338. MS_LOG(INFO) << "Parameter '" << param->ToString() << "' could not cast to parameter.";
  339. }
  340. }
  341. }
  342. void DebuggerProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *const graph_proto,
  343. std::map<AnfNodePtr, size_t> *const_map_ptr, LocDebugDumpMode dump_location) {
  344. if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
  345. return;
  346. }
  347. // topo sort nodes
  348. std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
  349. std::map<AnfNodePtr, size_t> apply_map;
  350. for (const AnfNodePtr &node : nodes) {
  351. MS_EXCEPTION_IF_NULL(node);
  352. if (!node->isa<CNode>()) {
  353. continue;
  354. }
  355. auto cnode = node->cast<CNodePtr>();
  356. if (cnode != func_graph->get_return()) {
  357. ExportCNode(func_graph, cnode, &apply_map, const_map_ptr, graph_proto, dump_location);
  358. } else {
  359. ExportFuncGraphOutput(func_graph, cnode, apply_map, const_map_ptr, graph_proto);
  360. }
  361. }
  362. }
  363. void DebuggerProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
  364. std::map<AnfNodePtr, size_t> *apply_map_ptr,
  365. std::map<AnfNodePtr, size_t> *const_map_ptr,
  366. debugger::GraphProto *const graph_proto, LocDebugDumpMode dump_location) {
  367. if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
  368. graph_proto == nullptr) {
  369. return;
  370. }
  371. auto apply_idx = apply_map_ptr->size() + 1;
  372. (*apply_map_ptr)[node] = apply_idx;
  373. auto &inputs = node->inputs();
  374. if (inputs.size() < 1) {
  375. MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  376. }
  377. AnfNodePtr op = inputs[0];
  378. debugger::NodeProto *node_proto = graph_proto->add_node();
  379. // CNode/ConstGraph/Const/Parameter
  380. if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) {
  381. MS_LOG(WARNING) << "Operator must be a primitive";
  382. } else {
  383. GetOpNodeTypeAndAttrs(func_graph, op, node_proto);
  384. node_proto->set_name(std::to_string(apply_idx));
  385. node_proto->set_scope(node->scope()->name());
  386. // add full_name for debugger
  387. std::string full_name = GetKernelNodeName(node);
  388. node_proto->set_full_name(full_name);
  389. MS_LOG(INFO) << "full_name: " << full_name;
  390. if (dump_location == kDebugWholeStack) {
  391. std::ostringstream buffer;
  392. auto traces = mindspore::trace::GetSourceLineList(node);
  393. for (auto &trace : traces) {
  394. buffer << " # " << trace;
  395. }
  396. node_proto->set_source_address(buffer.str());
  397. }
  398. // process OP inputs
  399. for (size_t i = 1; i < inputs.size(); ++i) {
  400. debugger::InputProto *input_proto = node_proto->add_input();
  401. input_proto->set_type(debugger::InputProto_EdgeType_DATA_EDGE);
  402. std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr);
  403. input_proto->set_name(id);
  404. }
  405. // set node output type
  406. SetNodeOutputType(node, node_proto->mutable_output_type());
  407. }
  408. }
  409. void DebuggerProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
  410. const std::map<AnfNodePtr, size_t> &apply_map,
  411. std::map<AnfNodePtr, size_t> *const_map_ptr,
  412. debugger::GraphProto *graph_proto) {
  413. if (ret_node == nullptr || !ret_node->isa<CNode>()) {
  414. MS_LOG(EXCEPTION) << "Graph return node is illegal";
  415. }
  416. AnfNodePtr arg = ret_node->input(1);
  417. if (graph_proto == nullptr) {
  418. MS_LOG(EXCEPTION) << "graph_proto is nullptr";
  419. }
  420. debugger::OutputProto *output_proto = graph_proto->add_outputs();
  421. if (output_proto == nullptr) {
  422. MS_LOG(EXCEPTION) << "output_proto is nullptr";
  423. }
  424. std::string id = GetOpNodeInputId(func_graph, arg, apply_map, const_map_ptr);
  425. output_proto->set_name(id);
  426. SetNodeOutputType(arg, output_proto->mutable_type());
  427. }
  428. static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) {
  429. return x.second < y.second;
  430. }
  431. void DebuggerProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map,
  432. debugger::GraphProto *graph_proto) {
  433. std::vector<std::pair<AnfNodePtr, size_t>> nodes;
  434. (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes),
  435. [](const std::pair<AnfNodePtr, size_t> &item) { return item; });
  436. sort(nodes.begin(), nodes.end(), CompareValue);
  437. for (auto &item : nodes) {
  438. if (graph_proto == nullptr) {
  439. MS_LOG(EXCEPTION) << "graph_proto is nullptr";
  440. }
  441. debugger::NamedValueProto *named_value = graph_proto->add_const_vals();
  442. MS_EXCEPTION_IF_NULL(named_value);
  443. named_value->set_key(GetConstNodeId(item.second));
  444. // cst full name: Default--data-x
  445. std::string node_name = GetKernelNodeName(item.first);
  446. GetFileKernelName(NOT_NULL(&node_name));
  447. named_value->set_full_name(node_name);
  448. if (GetValueNode(item.first)->isa<tensor::Tensor>()) {
  449. continue;
  450. }
  451. SetValueToProto(GetValueNode(item.first), named_value->mutable_value());
  452. }
  453. }
  454. void DebuggerProtoExporter::InitModelInfo() { model_.set_ir_version(static_cast<int64_t>(debugger::IR_VERSION)); }
  455. debugger::ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph) {
  456. DebuggerProtoExporter exporter;
  457. return exporter.GetFuncGraphProto(func_graph);
  458. }
  459. debugger::DataType GetDebuggerNumberDataType(const TypePtr &type) {
  460. switch (type->type_id()) {
  461. case kNumberTypeBool:
  462. return debugger::DT_BOOL;
  463. case kNumberTypeInt8:
  464. return debugger::DT_INT8;
  465. case kNumberTypeInt16:
  466. return debugger::DT_INT16;
  467. case kNumberTypeInt32:
  468. return debugger::DT_INT32;
  469. case kNumberTypeInt64:
  470. return debugger::DT_INT64;
  471. case kNumberTypeUInt8:
  472. return debugger::DT_UINT8;
  473. case kNumberTypeUInt16:
  474. return debugger::DT_UINT16;
  475. case kNumberTypeUInt32:
  476. return debugger::DT_UINT32;
  477. case kNumberTypeUInt64:
  478. return debugger::DT_UINT64;
  479. case kNumberTypeFloat16:
  480. return debugger::DT_FLOAT16;
  481. case kNumberTypeFloat32:
  482. return debugger::DT_FLOAT32;
  483. case kNumberTypeFloat64:
  484. return debugger::DT_FLOAT64;
  485. case kNumberTypeInt:
  486. return debugger::DT_BASE_INT;
  487. case kNumberTypeUInt:
  488. return debugger::DT_BASE_UINT;
  489. case kNumberTypeFloat:
  490. return debugger::DT_BASE_FLOAT;
  491. default:
  492. MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
  493. }
  494. }
  495. #ifdef ENABLE_DUMP_IR
  496. void DumpIRProtoWithSrcInfo(const FuncGraphPtr &func_graph, const std::string &suffix, const std::string &target_dir,
  497. LocDebugDumpMode dump_location) {
  498. DebuggerProtoExporter exporter;
  499. std::string graph_proto = exporter.GetFuncGraphProtoString(func_graph, dump_location);
  500. if (func_graph == nullptr) {
  501. MS_LOG(ERROR) << "Func graph is nullptr";
  502. return;
  503. }
  504. std::string file_path = target_dir + "/" + "ms_output_" + suffix + ".pb";
  505. auto realpath = Common::CreatePrefixPath(file_path);
  506. if (!realpath.has_value()) {
  507. MS_LOG(ERROR) << "Get real path failed, path=" << file_path;
  508. return;
  509. }
  510. ChangeFileMode(realpath.value(), S_IWUSR);
  511. // write to pb file
  512. std::ofstream ofs(realpath.value());
  513. if (!ofs.is_open()) {
  514. MS_LOG(ERROR) << "Open file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
  515. return;
  516. }
  517. ofs << graph_proto;
  518. ofs.close();
  519. // set file mode to read only by user
  520. ChangeFileMode(file_path, S_IRUSR);
  521. }
  522. void DumpConstantInfo(const KernelGraphPtr &graph, const std::string &target_dir) {
  523. // Dump constant to npy file
  524. MS_LOG(INFO) << "Start e2e dump Const values";
  525. E2eDump::DumpConstantData(graph.get(), target_dir);
  526. }
  527. #else
  528. void DumpIRProtoWithSrcInfo(const FuncGraphPtr &, const std::string &, const std::string &, LocDebugDumpMode) {
  529. static bool already_printed = false;
  530. if (already_printed) {
  531. return;
  532. }
  533. already_printed = true;
  534. MS_LOG(WARNING) << "The functionality of dumping function graph IR in protobuf format is disabled,"
  535. << "because ENABLE_DEBUGGER option is off"
  536. << "please recompile source to enable it. See help of building script.";
  537. }
  538. void DumpConstantInfo(const KernelGraphPtr &, const std::string &) {
  539. static bool already_printed = false;
  540. if (already_printed) {
  541. return;
  542. }
  543. already_printed = true;
  544. MS_LOG(WARNING) << "The functionality of dumping function graph constant is disabled, "
  545. << "please recompile source to enable it. See help of building script.";
  546. }
  547. #endif
  548. } // namespace mindspore