diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc index b108dccbe1..00e45a839c 100644 --- a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc +++ b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc @@ -260,13 +260,13 @@ void DataDumper::SetOpDebugMappingInfo(const NotNulladd_dim(kOpDebugShape); output.set_original_name(kNodeNameOpDebug); output.set_original_output_index(0); - output.set_original_output_format(GeFormat::kFormat_ND); + output.set_original_output_format(ge::Format::FORMAT_ND); output.set_original_output_data_type(ge::proto::DataType::DT_UINT8); // due to lhisi virtual addr bug, cannot use args now output.set_address(static_cast(reinterpret_cast(op_debug_dump_args_))); diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_types_convert.cc b/mindspore/ccsrc/runtime/device/ascend/ge_types_convert.cc index 8cd5fa9aed..098cf59cb9 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ge_types_convert.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ge_types_convert.cc @@ -15,7 +15,11 @@ */ #include "runtime/device/ascend/ge_types_convert.h" +#include "graph/utils/type_utils.h" +namespace { +constexpr auto kInvalidFormat = "RESERVED"; +} namespace mindspore { namespace device { namespace ascend { @@ -54,26 +58,31 @@ ge::DataType GeTypesConvert::TransTypeIdToGeDataType(TypeId type_id) { return iter->second; } -GeFormat GeTypesConvert::GetGeFormat(const std::string &format, size_t shape_size) { - static const std::map format_map = { +ge::Format GeTypesConvert::GetGeFormat(const std::string &format, size_t shape_size) { + static const std::map format_map = { // default format: nchw, fractal_nz? - {kOpFormat_DEFAULT, kFormat_NCHW}, - {kOpFormat_NC1KHKWHWC0, kFormat_NC1KHKWHWC0}, - {kOpFormat_ND, kFormat_ND}, - {kOpFormat_NCHW, kFormat_NCHW}, - {kOpFormat_NHWC, kFormat_NHWC}, - {kOpFormat_HWCN, kFormat_HWCN}, - {kOpFormat_NC1HWC0, kFormat_NC1HWC0}, - {kOpFormat_FRAC_Z, kFormat_FRACTAL_Z}, - {kOpFormat_FRAC_NZ, kFormat_FRACTAL_NZ}, - {kOpFormat_C1HWNCoC0, kFormat_C1HWNCoC0}, - {kOpFormat_NC1HWC0_C04, kFormat_NC1HWC0_C04}, - {kOpFormat_FRACTAL_Z_C04, kFormat_FRACTAL_Z_C04}, - {kOpFormat_NDHWC, kFormat_NDHWC}, - }; + {kOpFormat_DEFAULT, ge::Format::FORMAT_NCHW}, + {kOpFormat_NC1KHKWHWC0, ge::Format::FORMAT_NC1KHKWHWC0}, + {kOpFormat_ND, ge::Format::FORMAT_ND}, + {kOpFormat_NCHW, ge::Format::FORMAT_NCHW}, + {kOpFormat_NHWC, ge::Format::FORMAT_NHWC}, + {kOpFormat_HWCN, ge::Format::FORMAT_HWCN}, + {kOpFormat_NC1HWC0, ge::Format::FORMAT_NC1HWC0}, + {kOpFormat_FRAC_Z, ge::Format::FORMAT_FRACTAL_Z}, + {kOpFormat_FRAC_NZ, ge::Format::FORMAT_FRACTAL_NZ}, + {kOpFormat_C1HWNCoC0, ge::Format::FORMAT_C1HWNCoC0}, + {kOpFormat_NC1HWC0_C04, ge::Format::FORMAT_NC1HWC0_C04}, + {kOpFormat_FRACTAL_Z_C04, ge::Format::FORMAT_FRACTAL_Z_C04}, + {kOpFormat_NDHWC, ge::Format::FORMAT_NDHWC}, + {kOpFormat_NCDHW, ge::Format::FORMAT_NCDHW}, + {kOpFormat_DHWNC, ge::Format::FORMAT_DHWNC}, + {kOpFormat_DHWCN, ge::Format::FORMAT_DHWCN}, + {kOpFormat_NDC1HWC0, ge::Format::FORMAT_NDC1HWC0}, + {kOpFormat_FRACTAL_Z_3D, ge::Format::FORMAT_FRACTAL_Z_3D}, + {kOpFormat_FRACTAL_ZN_LSTM, ge::Format::FORMAT_FRACTAL_ZN_LSTM}}; MS_LOG(INFO) << "GetGeFormat format:" << format << " shape_size:" << shape_size; if (format == kOpFormat_DEFAULT) { - return shape_size == 4 ? kFormat_NCHW : kFormat_ND; + return shape_size == 4 ? ge::Format::FORMAT_NCHW : ge::Format::FORMAT_ND; } auto iter = format_map.find(format); if (iter == format_map.end()) { @@ -82,55 +91,12 @@ GeFormat GeTypesConvert::GetGeFormat(const std::string &format, size_t shape_siz return iter->second; } -std::string GeTypesConvert::GetGeTilingFormat(GeFormat ge_format) { - static const std::map kFormatToStringMap = { - {kFormat_NCHW, "NCHW"}, - {kFormat_NHWC, "NHWC"}, - {kFormat_ND, "ND"}, - {kFormat_NC1HWC0, "NC1HWC0"}, - {kFormat_FRACTAL_Z, "FRACTAL_Z"}, - {kFormat_NC1C0HWPAD, "NC1C0HWPAD"}, - {kFormat_NHWC1C0, "NHWC1C0"}, - {kFormat_FSR_NCHW, "FSR_NCHW"}, - {kFormat_FRACTAL_DECONV, "FRACTAL_DECONV"}, - {kFormat_C1HWNC0, "C1HWNC0"}, - {kFormat_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, - {kFormat_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, - {kFormat_NC1HWC0_C04, "NC1HWC0_C04"}, - {kFormat_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, - {kFormat_CHWN, "CHWN"}, - {kFormat_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, - {kFormat_NC1KHKWHWC0, "NC1KHKWHWC0"}, - {kFormat_BN_WEIGHT, "BN_WEIGHT"}, - {kFormat_FILTER_HWCK, "FILTER_HWCK"}, - {kFormat_HWCN, "HWCN"}, - {kFormat_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, - {kFormat_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, - {kFormat_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, - {kFormat_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, - {kFormat_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, - {kFormat_MD, "MD"}, - {kFormat_NDHWC, "NDHWC"}, - {kFormat_NCDHW, "NCDHW"}, - {kFormat_DHWCN, "DHWCN"}, - {kFormat_DHWNC, "DHWNC"}, - {kFormat_NDC1HWC0, "NDC1HWC0"}, - {kFormat_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, - {kFormat_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"}, - {kFormat_C1HWNCoC0, "C1HWNCoC0"}, - {kFormat_FRACTAL_NZ, "FRACTAL_NZ"}, - {kFormat_CN, "CN"}, - {kFormat_NC, "NC"}, - {kFormat_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"}, - {kFormat_FRACTAL_Z_G, "FRACTAL_Z_G"}, - {kFormat_RESERVED, "FORMAT_RESERVED"}, - {kFormat_ALL, "ALL"}}; - - auto iter = kFormatToStringMap.find(ge_format); - if (iter == kFormatToStringMap.end()) { - MS_LOG(EXCEPTION) << "Invalid ge_format:" << ge_format; +std::string GeTypesConvert::GetGeTilingFormat(ge::Format ge_format) { + auto format_str = ge::TypeUtils::FormatToSerialString(ge_format); + if (format_str == kInvalidFormat) { + MS_LOG(EXCEPTION) << "Not support format:" << ge_format; } - return iter->second; + return format_str; } } // namespace ascend } // namespace device diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_types_convert.h b/mindspore/ccsrc/runtime/device/ascend/ge_types_convert.h index 0881ced928..5a24f85d2b 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ge_types_convert.h +++ b/mindspore/ccsrc/runtime/device/ascend/ge_types_convert.h @@ -27,58 +27,13 @@ namespace mindspore { namespace device { namespace ascend { -enum GeFormat { - kFormat_NCHW = 0, // NCHW - kFormat_NHWC, // NHWC - kFormat_ND, // Nd Tensor - kFormat_NC1HWC0, // NC1HWC0 - kFormat_FRACTAL_Z, // FRACTAL_Z - kFormat_NC1C0HWPAD, - kFormat_NHWC1C0, - kFormat_FSR_NCHW, - kFormat_FRACTAL_DECONV, - kFormat_C1HWNC0, - kFormat_FRACTAL_DECONV_TRANSPOSE, - kFormat_FRACTAL_DECONV_SP_STRIDE_TRANS, - kFormat_NC1HWC0_C04, // NC1HWC0, C0 =4 - kFormat_FRACTAL_Z_C04, // FRACZ, C0 =4 - kFormat_CHWN, - kFormat_FRACTAL_DECONV_SP_STRIDE8_TRANS, - kFormat_HWCN, - kFormat_NC1KHKWHWC0, // KH,KW kernel h& kernel w maxpooling max output format - kFormat_BN_WEIGHT, - kFormat_FILTER_HWCK, // filter input tensor format - kFormat_HASHTABLE_LOOKUP_LOOKUPS = 20, - kFormat_HASHTABLE_LOOKUP_KEYS, - kFormat_HASHTABLE_LOOKUP_VALUE, - kFormat_HASHTABLE_LOOKUP_OUTPUT, - kFormat_HASHTABLE_LOOKUP_HITS = 24, - kFormat_C1HWNCoC0, - kFormat_MD, - kFormat_NDHWC, - kFormat_FRACTAL_ZZ, - kFormat_FRACTAL_NZ, - kFormat_NCDHW, - kFormat_DHWCN, // 3D filter input tensor format - kFormat_NDC1HWC0, - kFormat_FRACTAL_Z_3D, - kFormat_CN, - kFormat_NC, - kFormat_DHWNC, - kFormat_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format - kFormat_FRACTAL_ZN_LSTM, - kFormat_FRACTAL_Z_G, - kFormat_RESERVED, - kFormat_ALL -}; - class GeTypesConvert { public: GeTypesConvert() = default; ~GeTypesConvert() = default; static ge::proto::DataType GetGeDataType(TypeId type_id); - static GeFormat GetGeFormat(const std::string &format, size_t shape_size); - static std::string GetGeTilingFormat(GeFormat ge_format); + static ge::Format GetGeFormat(const std::string &format, size_t shape_size); + static std::string GetGeTilingFormat(ge::Format ge_format); static ge::DataType TransTypeIdToGeDataType(TypeId type_id); }; } // namespace ascend