/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include "debug/debugger/debugger.h" #include "proto/debug_graph.pb.h" #include "ir/graph_utils.h" #include "utils/symbolic.h" namespace mindspore { class DebuggerProtoExporter { public: DebuggerProtoExporter() {} ~DebuggerProtoExporter() {} std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); debugger::ModelProto GetFuncGraphProto(const FuncGraphPtr &func_graph); private: void InitModelInfo(); void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, debugger::NodeProto *node_proto); std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const std::map &apply_map, std::map *const_map_ptr); void SetValueToProto(const ValuePtr &attr_value, debugger::ValueProto *value_proto); void SetScalarToProto(const ScalarPtr &val, debugger::ValueProto *value_proto); void SetSequenceToProto(const ValueSequeuePtr &val, debugger::ValueProto *value_proto); void SetDictionaryToProto(const ValueDictionaryPtr &val, debugger::ValueProto *value_proto); void SetNodeOutputType(const AnfNodePtr &node, debugger::TypeProto *type_proto); void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, debugger::TypeProto *type_proto); void ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto); void ExportParameters(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto); void ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto, std::map *const_map_ptr); void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *apply_map_ptr, std::map *const_map_ptr, debugger::GraphProto *graph_proto); void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, const std::map &apply_map, std::map *const_map_ptr, debugger::GraphProto *graph_proto); void ExportValueNodes(const std::map &const_map, debugger::GraphProto *graph_proto); static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); } debugger::ModelProto model_; }; void DebuggerProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto) { if (type_proto == nullptr) { return; } if (type == nullptr) { type_proto->set_data_type(debugger::DT_UNDEFINED); } else if (type->isa()) { type_proto->set_data_type(GetDebuggerNumberDataType(type)); } else if (type->isa()) { TypePtr elem_type = dyn_cast(type)->element(); type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type)); type_proto->set_data_type(debugger::DT_TENSOR); if (shape != nullptr && shape->isa()) { abstract::ShapePtr shape_info = dyn_cast(shape); for (const auto &elem : shape_info->shape()) { type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); } } } else if (type->isa()) { TuplePtr tuple_type = dyn_cast(type); type_proto->set_data_type(debugger::DT_TUPLE); for (const auto &elem_type : tuple_type->elements()) { SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); } } else if (type->isa()) { type_proto->set_data_type(debugger::DT_TYPE); } else if (type->isa()) { ListPtr list_type = dyn_cast(type); type_proto->set_data_type(debugger::DT_LIST); for (const auto &elem_type : list_type->elements()) { SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); } } else if (type->isa()) { type_proto->set_data_type(debugger::DT_ANYTHING); } else if (type->isa()) { type_proto->set_data_type(debugger::DT_REFKEY); } else if (type->isa()) { type_proto->set_data_type(debugger::DT_REF); } else if (type->isa()) { type_proto->set_data_type(debugger::DT_GRAPH); } else if (type->isa()) { type_proto->set_data_type(debugger::DT_NONE); } else if (type->isa()) { type_proto->set_data_type(debugger::DT_STRING); } else if (type->isa()) { // Do Nothing. } else { MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name(); } } void DebuggerProtoExporter::SetNodeOutputType(const AnfNodePtr &node, debugger::TypeProto *type_proto) { if (node == nullptr || type_proto == nullptr) { return; } SetNodeOutputType(node->Type(), node->Shape(), type_proto); } void DebuggerProtoExporter::SetValueToProto(const ValuePtr &val, debugger::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } if (val->isa()) { const StringImmPtr &value = dyn_cast(val); value_proto->set_dtype(debugger::DT_STRING); value_proto->set_str_val(value->value()); } else if (val->isa()) { SetScalarToProto(dyn_cast(val), value_proto); } else if (val->isa()) { value_proto->set_dtype(debugger::DT_TYPE); value_proto->mutable_type_val()->set_data_type(debugger::DT_BOOL); } else if (val->isa()) { value_proto->set_dtype(debugger::DT_TYPE); value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_INT); } else if (val->isa()) { value_proto->set_dtype(debugger::DT_TYPE); value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_FLOAT); } else if (val->isa()) { SetSequenceToProto(dyn_cast(val), value_proto); } else if (val->isa()) { value_proto->set_dtype(debugger::DT_NONE); value_proto->set_str_val("None"); } else if (val->isa()) { SymbolicKeyInstancePtr sym_inst = dyn_cast(val); ParameterPtr sym_node = dyn_cast(sym_inst->node()); value_proto->set_dtype(debugger::DT_SYM_INST); value_proto->set_str_val(sym_node == nullptr ? std::string("nullptr") : sym_node->ToString()); } else if (val->isa()) { SetDictionaryToProto(dyn_cast(val), value_proto); } else if (val->isa()) { tensor::TensorPtr tensor_ptr = dyn_cast(val); value_proto->set_dtype(debugger::DT_TENSOR); debugger::TensorProto *tensor_proto = value_proto->mutable_tensor_val(); tensor_proto->set_data_type(GetDebuggerNumberDataType(tensor_ptr->Dtype())); for (auto &elem : tensor_ptr->shape()) { tensor_proto->add_dims(elem); } tensor_proto->set_tensor_content(tensor_ptr->data_c(), tensor_ptr->data().nbytes()); } else if (val->isa()) { value_proto->set_dtype(debugger::DT_TYPE); debugger::TypeProto *type_proto = value_proto->mutable_type_val(); type_proto->set_data_type(debugger::DT_TENSOR); TypePtr elem_type = dyn_cast(val)->element(); type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type)); } else { MS_LOG(WARNING) << "Unsupported type " << val->type_name(); } } void DebuggerProtoExporter::SetScalarToProto(const ScalarPtr &val, debugger::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } if (val->isa()) { const BoolImmPtr &value = dyn_cast(val); value_proto->set_dtype(debugger::DT_BOOL); value_proto->set_bool_val(value->value()); } else if (val->isa()) { const Int8ImmPtr &value = dyn_cast(val); value_proto->set_dtype(debugger::DT_INT8); value_proto->set_int_val(value->value()); } else if (val->isa()) { const Int16ImmPtr &value = dyn_cast(val); value_proto->set_dtype(debugger::DT_INT16); value_proto->set_int_val(value->value()); } else if (val->isa()) { const Int32ImmPtr &value = dyn_cast(val); value_proto->set_dtype(debugger::DT_INT32); value_proto->set_int_val(value->value()); } else if (val->isa()) { const Int64ImmPtr &value = dyn_cast(val); value_proto->set_dtype(debugger::DT_INT64); value_proto->set_int_val(value->value()); } else if (val->isa()) { const UInt8ImmPtr &value = dyn_cast(val); value_proto->set_dtype(debugger::DT_UINT8); value_proto->set_uint_val(value->value()); } else if (val->isa()) { const UInt16ImmPtr &value = dyn_cast(val); value_proto->set_dtype(debugger::DT_UINT16); value_proto->set_uint_val(value->value()); } else if (val->isa()) { const UInt32ImmPtr &value = dyn_cast(val); value_proto->set_dtype(debugger::DT_UINT32); value_proto->set_uint_val(value->value()); } else if (val->isa()) { const UInt64ImmPtr &value = dyn_cast(val); value_proto->set_dtype(debugger::DT_UINT64); value_proto->set_uint_val(value->value()); } else if (val->isa()) { const FP32ImmPtr &value = dyn_cast(val); value_proto->set_dtype(debugger::DT_FLOAT32); value_proto->set_float_val(value->value()); } else if (val->isa()) { const FP64ImmPtr &value = dyn_cast(val); value_proto->set_dtype(debugger::DT_FLOAT64); value_proto->set_double_val(value->value()); } else { MS_LOG(EXCEPTION) << "Unknown scalar type " << val->ToString(); } } void DebuggerProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, debugger::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } if (val->isa()) { const ValueTuplePtr &value = dyn_cast(val); value_proto->set_dtype(debugger::DT_TUPLE); for (const auto &item : value->value()) { SetValueToProto(item, value_proto->add_values()); } } else if (val->isa()) { const ValueListPtr &value = dyn_cast(val); value_proto->set_dtype(debugger::DT_LIST); for (const auto &item : value->value()) { SetValueToProto(item, value_proto->add_values()); } } } void DebuggerProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, debugger::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } value_proto->set_dtype(debugger::DT_DICT); for (const auto &item : val->value()) { debugger::NamedValueProto *named_val = value_proto->add_dict_val(); named_val->set_key(item.first); SetValueToProto(item.second, named_val->mutable_value()); } } void DebuggerProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node, debugger::NodeProto *node_proto) { if (node == nullptr || node_proto == nullptr) { return; } if (node->isa() || node->isa() || IsValueNode(node)) { MS_LOG(EXCEPTION) << "Op node can not be CNode, Parameter or ValueNode Graph. But got " << node->ToString(); } if (!IsValueNode(node)) { MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString(); } const PrimitivePtr &prim = GetValueNode(node); node_proto->set_op_type(prim->name()); for (const auto &attr : prim->attrs()) { debugger::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name(attr.first); SetValueToProto(attr.second, attr_proto->mutable_value()); } node_proto->set_scope(node->scope()->name()); } std::string DebuggerProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node, const std::map &apply_map, std::map *const_map_ptr) { if (node == nullptr || const_map_ptr == nullptr) { return ""; } if (node->isa()) { auto iter = apply_map.find(node); if (iter == apply_map.end()) { MS_LOG(EXCEPTION) << "Can not find node '" << node->ToString() << "' in apply_map"; } return std::to_string(iter->second); } if (node->isa()) { return node->ToString(); } if (node->isa()) { auto iter = const_map_ptr->find(node); if (iter == const_map_ptr->end()) { // Start index number from 1 auto const_idx = const_map_ptr->size() + 1; (*const_map_ptr)[node] = const_idx; } return GetConstNodeId((*const_map_ptr)[node]); } MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'"; } std::string DebuggerProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return ""; } InitModelInfo(); debugger::GraphProto *graph_proto = model_.mutable_graph(); ExportFuncGraph(func_graph, graph_proto); return model_.SerializeAsString(); } debugger::ModelProto DebuggerProtoExporter::GetFuncGraphProto(const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return ModelProto(); } InitModelInfo(); debugger::GraphProto *graph_proto = model_.mutable_graph(); ExportFuncGraph(func_graph, graph_proto); return model_; } void DebuggerProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto) { if (func_graph == nullptr || graph_proto == nullptr) { return; } // map for store ValueNodes of this graph std::map const_map; // set graph name graph_proto->set_name(func_graph->ToString()); ExportParameters(func_graph, graph_proto); ExportCNodes(func_graph, graph_proto, &const_map); ExportValueNodes(const_map, graph_proto); } void DebuggerProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto) { if (func_graph == nullptr || graph_proto == nullptr) { return; } // cast FuncGraph to KernelGraph to access inputs() std::vector parameters = static_cast(func_graph.get())->inputs(); for (auto ¶m : parameters) { debugger::ParameterProto *param_proto = graph_proto->add_parameters(); param_proto->set_name(param->ToString()); SetNodeOutputType(param, param_proto->mutable_type()); const ParameterPtr param_ptr = dyn_cast(param); if (param_ptr == nullptr) { MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter."; } } } void DebuggerProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto, std::map *const_map_ptr) { if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) { return; } // topo sort nodes std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); std::map apply_map; for (const AnfNodePtr &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; } auto cnode = node->cast(); if (cnode != func_graph->get_return()) { ExportCNode(func_graph, cnode, &apply_map, const_map_ptr, graph_proto); } else { ExportFuncGraphOutput(func_graph, cnode, apply_map, const_map_ptr, graph_proto); } } } void DebuggerProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *apply_map_ptr, std::map *const_map_ptr, debugger::GraphProto *graph_proto) { if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr || graph_proto == nullptr) { return; } auto apply_idx = apply_map_ptr->size() + 1; (*apply_map_ptr)[node] = apply_idx; auto &inputs = node->inputs(); if (inputs.size() < 1) { MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; } AnfNodePtr op = inputs[0]; debugger::NodeProto *node_proto = graph_proto->add_node(); // CNode/ConstGraph/Const/Parameter if (op->isa() || IsValueNode(op) || op->isa()) { MS_LOG(WARNING) << "Operator must be a primitive"; } else { GetOpNodeTypeAndAttrs(func_graph, op, node_proto); node_proto->set_name(std::to_string(apply_idx)); node_proto->set_scope(node->scope()->name()); // add full_name for debugger node_proto->set_full_name(node->fullname_with_scope()); // process OP inputs for (size_t i = 1; i < inputs.size(); ++i) { debugger::InputProto *input_proto = node_proto->add_input(); input_proto->set_type(debugger::InputProto_EdgeType_DATA_EDGE); std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr); input_proto->set_name(id); } // set node output type SetNodeOutputType(node, node_proto->mutable_output_type()); } } void DebuggerProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, const std::map &apply_map, std::map *const_map_ptr, debugger::GraphProto *graph_proto) { if (ret_node == nullptr || !ret_node->isa()) { MS_LOG(EXCEPTION) << "Graph return node is illegal"; } AnfNodePtr arg = ret_node->input(1); if (graph_proto == nullptr) { MS_LOG(EXCEPTION) << "graph_proto is nullptr"; } debugger::OutputProto *output_proto = graph_proto->add_outputs(); if (output_proto == nullptr) { MS_LOG(EXCEPTION) << "output_proto is nullptr"; } std::string id = GetOpNodeInputId(func_graph, arg, apply_map, const_map_ptr); output_proto->set_name(id); SetNodeOutputType(arg, output_proto->mutable_type()); } static bool CompareValue(const std::pair &x, const std::pair &y) { return x.second < y.second; } void DebuggerProtoExporter::ExportValueNodes(const std::map &const_map, debugger::GraphProto *graph_proto) { std::vector> nodes; (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes), [](const std::pair &item) { return item; }); sort(nodes.begin(), nodes.end(), CompareValue); for (auto &item : nodes) { if (graph_proto == nullptr) { MS_LOG(EXCEPTION) << "graph_proto is nullptr"; } debugger::NamedValueProto *named_value = graph_proto->add_const_vals(); MS_EXCEPTION_IF_NULL(named_value); named_value->set_key(GetConstNodeId(item.second)); SetValueToProto(GetValueNode(item.first), named_value->mutable_value()); } } void DebuggerProtoExporter::InitModelInfo() { model_.set_ir_version(debugger::IR_VERSION); } std::string GetDebuggerFuncGraphProtoString(const FuncGraphPtr &func_graph) { DebuggerProtoExporter exporter; return exporter.GetFuncGraphProtoString(func_graph); } debugger::ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph) { DebuggerProtoExporter exporter; return exporter.GetFuncGraphProto(func_graph); } debugger::DataType GetDebuggerNumberDataType(const TypePtr &type) { switch (type->type_id()) { case kNumberTypeBool: return debugger::DT_BOOL; case kNumberTypeInt8: return debugger::DT_INT8; case kNumberTypeInt16: return debugger::DT_INT16; case kNumberTypeInt32: return debugger::DT_INT32; case kNumberTypeInt64: return debugger::DT_INT64; case kNumberTypeUInt8: return debugger::DT_UINT8; case kNumberTypeUInt16: return debugger::DT_UINT16; case kNumberTypeUInt32: return debugger::DT_UINT32; case kNumberTypeUInt64: return debugger::DT_UINT64; case kNumberTypeFloat16: return debugger::DT_FLOAT16; case kNumberTypeFloat32: return debugger::DT_FLOAT32; case kNumberTypeFloat64: return debugger::DT_FLOAT64; case kNumberTypeInt: return debugger::DT_BASE_INT; case kNumberTypeUInt: return debugger::DT_BASE_UINT; case kNumberTypeFloat: return debugger::DT_BASE_FLOAT; default: MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name(); } } } // namespace mindspore