| @@ -31,7 +31,7 @@ | |||
| namespace mindspore { | |||
| using FloatPtr = std::shared_ptr<Float>; | |||
| using IntPtr = std::shared_ptr<Int>; | |||
| using UIntPtr = std::shared_ptr<UInt>; | |||
| // anf type to mindir type map | |||
| static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_type_map = { | |||
| {kNumberTypeBool, mind_ir::TensorProto_DataType_BOOL}, | |||
| @@ -56,6 +56,13 @@ static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_int_ma | |||
| {64, mind_ir::TensorProto_DataType_INT64}, | |||
| }; | |||
| static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_uint_map = { | |||
| {8, mind_ir::TensorProto_DataType_UINT8}, | |||
| {16, mind_ir::TensorProto_DataType_UINT16}, | |||
| {32, mind_ir::TensorProto_DataType_UINT32}, | |||
| {64, mind_ir::TensorProto_DataType_UINT64}, | |||
| }; | |||
| static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_float_map = { | |||
| {16, mind_ir::TensorProto_DataType_FLOAT16}, | |||
| {32, mind_ir::TensorProto_DataType_FLOAT}, | |||
| @@ -117,6 +124,7 @@ class IrExportBuilder { | |||
| mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id); | |||
| mind_ir::TensorProto_DataType GetMindirDataBitsIntType(int bits); | |||
| mind_ir::TensorProto_DataType GetMindirDataBitsFloatType(int bits); | |||
| mind_ir::TensorProto_DataType GetMindirDataBitsUIntType(int bits); | |||
| std::string GetNodeName(const AnfNodePtr &node); | |||
| std::string GetUniqueNodeName(const AnfNodePtr &node); | |||
| std::string GetOpTypeName(const AnfNodePtr &node); | |||
| @@ -243,6 +251,14 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits | |||
| return iter->second; | |||
| } | |||
| mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bits) { | |||
| auto iter = g_data_bits_uint_map.find(bits); | |||
| if (iter == g_data_bits_uint_map.end()) { | |||
| MS_LOG(EXCEPTION) << "Convert bits uint error, unsupported bits! " << bits; | |||
| } | |||
| return iter->second; | |||
| } | |||
| mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) { | |||
| auto iter = g_data_bits_float_map.find(bits); | |||
| if (iter == g_data_bits_float_map.end()) { | |||
| @@ -551,6 +567,11 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::At | |||
| tensor_proto->set_name("value0"); | |||
| auto int_value = value->cast<IntPtr>(); | |||
| tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits())); | |||
| } else if (value->isa<UInt>()) { | |||
| attr_proto->set_ref_attr_name("type:value0"); | |||
| tensor_proto->set_name("value0"); | |||
| auto float_value = value->cast<UIntPtr>(); | |||
| tensor_proto->set_data_type(GetMindirDataBitsUIntType(float_value->nbits())); | |||
| } else if (value->isa<Float>()) { | |||
| attr_proto->set_ref_attr_name("type:value0"); | |||
| tensor_proto->set_name("value0"); | |||