Browse Source

support uint for mindir datatype

tags/v1.4.0
lianliguang 4 years ago
parent
commit
96d656aa00
1 changed files with 22 additions and 1 deletions
  1. +22
    -1
      mindspore/ccsrc/transform/express_ir/mindir_exporter.cc

+ 22
- 1
mindspore/ccsrc/transform/express_ir/mindir_exporter.cc View File

@@ -31,7 +31,7 @@
namespace mindspore {
using FloatPtr = std::shared_ptr<Float>;
using IntPtr = std::shared_ptr<Int>;
using UIntPtr = std::shared_ptr<UInt>;
// anf type to mindir type map
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_type_map = {
{kNumberTypeBool, mind_ir::TensorProto_DataType_BOOL},
@@ -56,6 +56,13 @@ static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_int_ma
{64, mind_ir::TensorProto_DataType_INT64},
};

static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_uint_map = {
{8, mind_ir::TensorProto_DataType_UINT8},
{16, mind_ir::TensorProto_DataType_UINT16},
{32, mind_ir::TensorProto_DataType_UINT32},
{64, mind_ir::TensorProto_DataType_UINT64},
};

static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_float_map = {
{16, mind_ir::TensorProto_DataType_FLOAT16},
{32, mind_ir::TensorProto_DataType_FLOAT},
@@ -117,6 +124,7 @@ class IrExportBuilder {
mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id);
mind_ir::TensorProto_DataType GetMindirDataBitsIntType(int bits);
mind_ir::TensorProto_DataType GetMindirDataBitsFloatType(int bits);
mind_ir::TensorProto_DataType GetMindirDataBitsUIntType(int bits);
std::string GetNodeName(const AnfNodePtr &node);
std::string GetUniqueNodeName(const AnfNodePtr &node);
std::string GetOpTypeName(const AnfNodePtr &node);
@@ -243,6 +251,14 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits
return iter->second;
}

mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bits) {
auto iter = g_data_bits_uint_map.find(bits);
if (iter == g_data_bits_uint_map.end()) {
MS_LOG(EXCEPTION) << "Convert bits uint error, unsupported bits! " << bits;
}
return iter->second;
}

mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) {
auto iter = g_data_bits_float_map.find(bits);
if (iter == g_data_bits_float_map.end()) {
@@ -551,6 +567,11 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::At
tensor_proto->set_name("value0");
auto int_value = value->cast<IntPtr>();
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
} else if (value->isa<UInt>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
auto float_value = value->cast<UIntPtr>();
tensor_proto->set_data_type(GetMindirDataBitsUIntType(float_value->nbits()));
} else if (value->isa<Float>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");


Loading…
Cancel
Save