|
|
|
@@ -33,7 +33,11 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
|
|
|
|
void CheckIfValidType(const TypePtr &type, debugger::TypeProto *type_proto) { |
|
|
|
using TypeInfoToProtoTypeMap = std::vector<std::pair<const char *, debugger::DataType>>; |
|
|
|
|
|
|
|
void SetOutputType(const TypePtr &node, const BaseShapePtr &shape, debugger::TypeProto *type_proto); |
|
|
|
|
|
|
|
void CheckIfValidType(const TypePtr &type, debugger::TypeProto *const type_proto) { |
|
|
|
if (!(type->isa<Number>() || type->isa<TensorType>() || type->isa<Tuple>() || type->isa<TypeType>() || |
|
|
|
type->isa<List>() || type->isa<TypeAnything>() || type->isa<RefKeyType>() || type->isa<RefType>() || |
|
|
|
type->isa<Function>() || type->isa<TypeNone>() || type->isa<String>() || type->isa<SymbolicKeyType>() || |
|
|
|
@@ -45,8 +49,40 @@ void CheckIfValidType(const TypePtr &type, debugger::TypeProto *type_proto) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void DebuggerProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, |
|
|
|
debugger::TypeProto *type_proto) { |
|
|
|
void SetTensorTypeProto(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto) { |
|
|
|
TypePtr elem_type = dyn_cast<TensorType>(type)->element(); |
|
|
|
type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type)); |
|
|
|
if (shape != nullptr && shape->isa<abstract::Shape>()) { |
|
|
|
abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape); |
|
|
|
for (const auto &elem : shape_info->shape()) { |
|
|
|
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void SetTupleTypeProto(const TypePtr &type, debugger::TypeProto *type_proto) { |
|
|
|
TuplePtr tuple_type = dyn_cast<Tuple>(type); |
|
|
|
for (const auto &elem_type : tuple_type->elements()) { |
|
|
|
SetOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void SetListTypeProto(const TypePtr &type, debugger::TypeProto *type_proto) { |
|
|
|
ListPtr list_type = dyn_cast<List>(type); |
|
|
|
for (const auto &elem_type : list_type->elements()) { |
|
|
|
SetOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
static TypeInfoToProtoTypeMap type_info_to_proto_type = { |
|
|
|
{typeid(TensorType).name(), debugger::DT_TENSOR}, {typeid(Tuple).name(), debugger::DT_TUPLE}, |
|
|
|
{typeid(TypeType).name(), debugger::DT_TYPE}, {typeid(List).name(), debugger::DT_LIST}, |
|
|
|
{typeid(TypeAnything).name(), debugger::DT_ANYTHING}, {typeid(RefKeyType).name(), debugger::DT_REFKEY}, |
|
|
|
{typeid(RefType).name(), debugger::DT_REF}, {typeid(Function).name(), debugger::DT_GRAPH}, |
|
|
|
{typeid(TypeNone).name(), debugger::DT_NONE}, {typeid(String).name(), debugger::DT_STRING}, |
|
|
|
{typeid(UMonadType).name(), debugger::DT_UMONAD}, {typeid(IOMonadType).name(), debugger::DT_IOMONAD}}; |
|
|
|
|
|
|
|
void SetOutputType(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto) { |
|
|
|
if (type_proto == nullptr) { |
|
|
|
return; |
|
|
|
} |
|
|
|
@@ -55,46 +91,22 @@ void DebuggerProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseSha |
|
|
|
return; |
|
|
|
} |
|
|
|
CheckIfValidType(type, type_proto); |
|
|
|
if (type->isa<TensorType>()) { |
|
|
|
TypePtr elem_type = dyn_cast<TensorType>(type)->element(); |
|
|
|
type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type)); |
|
|
|
type_proto->set_data_type(debugger::DT_TENSOR); |
|
|
|
if (shape != nullptr && shape->isa<abstract::Shape>()) { |
|
|
|
abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape); |
|
|
|
for (const auto &elem : shape_info->shape()) { |
|
|
|
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); |
|
|
|
} |
|
|
|
} |
|
|
|
} else if (type->isa<Tuple>()) { |
|
|
|
TuplePtr tuple_type = dyn_cast<Tuple>(type); |
|
|
|
type_proto->set_data_type(debugger::DT_TUPLE); |
|
|
|
for (const auto &elem_type : tuple_type->elements()) { |
|
|
|
SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); |
|
|
|
} |
|
|
|
} else if (type->isa<TypeType>()) { |
|
|
|
type_proto->set_data_type(debugger::DT_TYPE); |
|
|
|
} else if (type->isa<List>()) { |
|
|
|
ListPtr list_type = dyn_cast<List>(type); |
|
|
|
type_proto->set_data_type(debugger::DT_LIST); |
|
|
|
for (const auto &elem_type : list_type->elements()) { |
|
|
|
SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); |
|
|
|
for (auto &it : type_info_to_proto_type) { |
|
|
|
if (type->IsFromTypeId(Base::GetTypeId(it.first))) { |
|
|
|
type_proto->set_data_type(it.second); |
|
|
|
break; |
|
|
|
} |
|
|
|
} else if (type->isa<TypeAnything>()) { |
|
|
|
type_proto->set_data_type(debugger::DT_ANYTHING); |
|
|
|
} else if (type->isa<RefKeyType>()) { |
|
|
|
type_proto->set_data_type(debugger::DT_REFKEY); |
|
|
|
} else if (type->isa<RefType>()) { |
|
|
|
type_proto->set_data_type(debugger::DT_REF); |
|
|
|
} else if (type->isa<Function>()) { |
|
|
|
type_proto->set_data_type(debugger::DT_GRAPH); |
|
|
|
} else if (type->isa<TypeNone>()) { |
|
|
|
type_proto->set_data_type(debugger::DT_NONE); |
|
|
|
} else if (type->isa<String>()) { |
|
|
|
type_proto->set_data_type(debugger::DT_STRING); |
|
|
|
} else if (type->isa<UMonadType>()) { |
|
|
|
type_proto->set_data_type(debugger::DT_UMONAD); |
|
|
|
} else if (type->isa<IOMonadType>()) { |
|
|
|
type_proto->set_data_type(debugger::DT_IOMONAD); |
|
|
|
} |
|
|
|
if (type->isa<TensorType>()) { |
|
|
|
SetTensorTypeProto(type, shape, type_proto); |
|
|
|
return; |
|
|
|
} |
|
|
|
if (type->isa<Tuple>()) { |
|
|
|
SetTupleTypeProto(type, type_proto); |
|
|
|
return; |
|
|
|
} |
|
|
|
if (type->isa<List>()) { |
|
|
|
SetListTypeProto(type, type_proto); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -102,7 +114,7 @@ void DebuggerProtoExporter::SetNodeOutputType(const AnfNodePtr &node, debugger:: |
|
|
|
if (node == nullptr || type_proto == nullptr) { |
|
|
|
return; |
|
|
|
} |
|
|
|
SetNodeOutputType(node->Type(), node->Shape(), type_proto); |
|
|
|
SetOutputType(node->Type(), node->Shape(), type_proto); |
|
|
|
} |
|
|
|
|
|
|
|
void DebuggerProtoExporter::SetValueToProto(const ValuePtr &val, debugger::ValueProto *value_proto) { |
|
|
|
|