|
|
@@ -15,7 +15,11 @@ |
|
|
*/ |
|
|
*/ |
|
|
|
|
|
|
|
|
#include "runtime/device/ascend/ge_types_convert.h" |
|
|
#include "runtime/device/ascend/ge_types_convert.h" |
|
|
|
|
|
#include "graph/utils/type_utils.h" |
|
|
|
|
|
|
|
|
|
|
|
namespace { |
|
|
|
|
|
constexpr auto kInvalidFormat = "RESERVED"; |
|
|
|
|
|
} |
|
|
namespace mindspore { |
|
|
namespace mindspore { |
|
|
namespace device { |
|
|
namespace device { |
|
|
namespace ascend { |
|
|
namespace ascend { |
|
|
@@ -54,26 +58,31 @@ ge::DataType GeTypesConvert::TransTypeIdToGeDataType(TypeId type_id) { |
|
|
return iter->second; |
|
|
return iter->second; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
GeFormat GeTypesConvert::GetGeFormat(const std::string &format, size_t shape_size) { |
|
|
|
|
|
static const std::map<std::string, GeFormat> format_map = { |
|
|
|
|
|
|
|
|
ge::Format GeTypesConvert::GetGeFormat(const std::string &format, size_t shape_size) { |
|
|
|
|
|
static const std::map<std::string, ge::Format> format_map = { |
|
|
// default format: nchw, fractal_nz? |
|
|
// default format: nchw, fractal_nz? |
|
|
{kOpFormat_DEFAULT, kFormat_NCHW}, |
|
|
|
|
|
{kOpFormat_NC1KHKWHWC0, kFormat_NC1KHKWHWC0}, |
|
|
|
|
|
{kOpFormat_ND, kFormat_ND}, |
|
|
|
|
|
{kOpFormat_NCHW, kFormat_NCHW}, |
|
|
|
|
|
{kOpFormat_NHWC, kFormat_NHWC}, |
|
|
|
|
|
{kOpFormat_HWCN, kFormat_HWCN}, |
|
|
|
|
|
{kOpFormat_NC1HWC0, kFormat_NC1HWC0}, |
|
|
|
|
|
{kOpFormat_FRAC_Z, kFormat_FRACTAL_Z}, |
|
|
|
|
|
{kOpFormat_FRAC_NZ, kFormat_FRACTAL_NZ}, |
|
|
|
|
|
{kOpFormat_C1HWNCoC0, kFormat_C1HWNCoC0}, |
|
|
|
|
|
{kOpFormat_NC1HWC0_C04, kFormat_NC1HWC0_C04}, |
|
|
|
|
|
{kOpFormat_FRACTAL_Z_C04, kFormat_FRACTAL_Z_C04}, |
|
|
|
|
|
{kOpFormat_NDHWC, kFormat_NDHWC}, |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
{kOpFormat_DEFAULT, ge::Format::FORMAT_NCHW}, |
|
|
|
|
|
{kOpFormat_NC1KHKWHWC0, ge::Format::FORMAT_NC1KHKWHWC0}, |
|
|
|
|
|
{kOpFormat_ND, ge::Format::FORMAT_ND}, |
|
|
|
|
|
{kOpFormat_NCHW, ge::Format::FORMAT_NCHW}, |
|
|
|
|
|
{kOpFormat_NHWC, ge::Format::FORMAT_NHWC}, |
|
|
|
|
|
{kOpFormat_HWCN, ge::Format::FORMAT_HWCN}, |
|
|
|
|
|
{kOpFormat_NC1HWC0, ge::Format::FORMAT_NC1HWC0}, |
|
|
|
|
|
{kOpFormat_FRAC_Z, ge::Format::FORMAT_FRACTAL_Z}, |
|
|
|
|
|
{kOpFormat_FRAC_NZ, ge::Format::FORMAT_FRACTAL_NZ}, |
|
|
|
|
|
{kOpFormat_C1HWNCoC0, ge::Format::FORMAT_C1HWNCoC0}, |
|
|
|
|
|
{kOpFormat_NC1HWC0_C04, ge::Format::FORMAT_NC1HWC0_C04}, |
|
|
|
|
|
{kOpFormat_FRACTAL_Z_C04, ge::Format::FORMAT_FRACTAL_Z_C04}, |
|
|
|
|
|
{kOpFormat_NDHWC, ge::Format::FORMAT_NDHWC}, |
|
|
|
|
|
{kOpFormat_NCDHW, ge::Format::FORMAT_NCDHW}, |
|
|
|
|
|
{kOpFormat_DHWNC, ge::Format::FORMAT_DHWNC}, |
|
|
|
|
|
{kOpFormat_DHWCN, ge::Format::FORMAT_DHWCN}, |
|
|
|
|
|
{kOpFormat_NDC1HWC0, ge::Format::FORMAT_NDC1HWC0}, |
|
|
|
|
|
{kOpFormat_FRACTAL_Z_3D, ge::Format::FORMAT_FRACTAL_Z_3D}, |
|
|
|
|
|
{kOpFormat_FRACTAL_ZN_LSTM, ge::Format::FORMAT_FRACTAL_ZN_LSTM}}; |
|
|
MS_LOG(INFO) << "GetGeFormat format:" << format << " shape_size:" << shape_size; |
|
|
MS_LOG(INFO) << "GetGeFormat format:" << format << " shape_size:" << shape_size; |
|
|
if (format == kOpFormat_DEFAULT) { |
|
|
if (format == kOpFormat_DEFAULT) { |
|
|
return shape_size == 4 ? kFormat_NCHW : kFormat_ND; |
|
|
|
|
|
|
|
|
return shape_size == 4 ? ge::Format::FORMAT_NCHW : ge::Format::FORMAT_ND; |
|
|
} |
|
|
} |
|
|
auto iter = format_map.find(format); |
|
|
auto iter = format_map.find(format); |
|
|
if (iter == format_map.end()) { |
|
|
if (iter == format_map.end()) { |
|
|
@@ -82,55 +91,12 @@ GeFormat GeTypesConvert::GetGeFormat(const std::string &format, size_t shape_siz |
|
|
return iter->second; |
|
|
return iter->second; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
std::string GeTypesConvert::GetGeTilingFormat(GeFormat ge_format) { |
|
|
|
|
|
static const std::map<GeFormat, std::string> kFormatToStringMap = { |
|
|
|
|
|
{kFormat_NCHW, "NCHW"}, |
|
|
|
|
|
{kFormat_NHWC, "NHWC"}, |
|
|
|
|
|
{kFormat_ND, "ND"}, |
|
|
|
|
|
{kFormat_NC1HWC0, "NC1HWC0"}, |
|
|
|
|
|
{kFormat_FRACTAL_Z, "FRACTAL_Z"}, |
|
|
|
|
|
{kFormat_NC1C0HWPAD, "NC1C0HWPAD"}, |
|
|
|
|
|
{kFormat_NHWC1C0, "NHWC1C0"}, |
|
|
|
|
|
{kFormat_FSR_NCHW, "FSR_NCHW"}, |
|
|
|
|
|
{kFormat_FRACTAL_DECONV, "FRACTAL_DECONV"}, |
|
|
|
|
|
{kFormat_C1HWNC0, "C1HWNC0"}, |
|
|
|
|
|
{kFormat_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, |
|
|
|
|
|
{kFormat_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, |
|
|
|
|
|
{kFormat_NC1HWC0_C04, "NC1HWC0_C04"}, |
|
|
|
|
|
{kFormat_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, |
|
|
|
|
|
{kFormat_CHWN, "CHWN"}, |
|
|
|
|
|
{kFormat_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, |
|
|
|
|
|
{kFormat_NC1KHKWHWC0, "NC1KHKWHWC0"}, |
|
|
|
|
|
{kFormat_BN_WEIGHT, "BN_WEIGHT"}, |
|
|
|
|
|
{kFormat_FILTER_HWCK, "FILTER_HWCK"}, |
|
|
|
|
|
{kFormat_HWCN, "HWCN"}, |
|
|
|
|
|
{kFormat_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, |
|
|
|
|
|
{kFormat_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, |
|
|
|
|
|
{kFormat_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, |
|
|
|
|
|
{kFormat_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, |
|
|
|
|
|
{kFormat_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, |
|
|
|
|
|
{kFormat_MD, "MD"}, |
|
|
|
|
|
{kFormat_NDHWC, "NDHWC"}, |
|
|
|
|
|
{kFormat_NCDHW, "NCDHW"}, |
|
|
|
|
|
{kFormat_DHWCN, "DHWCN"}, |
|
|
|
|
|
{kFormat_DHWNC, "DHWNC"}, |
|
|
|
|
|
{kFormat_NDC1HWC0, "NDC1HWC0"}, |
|
|
|
|
|
{kFormat_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, |
|
|
|
|
|
{kFormat_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"}, |
|
|
|
|
|
{kFormat_C1HWNCoC0, "C1HWNCoC0"}, |
|
|
|
|
|
{kFormat_FRACTAL_NZ, "FRACTAL_NZ"}, |
|
|
|
|
|
{kFormat_CN, "CN"}, |
|
|
|
|
|
{kFormat_NC, "NC"}, |
|
|
|
|
|
{kFormat_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"}, |
|
|
|
|
|
{kFormat_FRACTAL_Z_G, "FRACTAL_Z_G"}, |
|
|
|
|
|
{kFormat_RESERVED, "FORMAT_RESERVED"}, |
|
|
|
|
|
{kFormat_ALL, "ALL"}}; |
|
|
|
|
|
|
|
|
|
|
|
auto iter = kFormatToStringMap.find(ge_format); |
|
|
|
|
|
if (iter == kFormatToStringMap.end()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid ge_format:" << ge_format; |
|
|
|
|
|
|
|
|
std::string GeTypesConvert::GetGeTilingFormat(ge::Format ge_format) { |
|
|
|
|
|
auto format_str = ge::TypeUtils::FormatToSerialString(ge_format); |
|
|
|
|
|
if (format_str == kInvalidFormat) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Not support format:" << ge_format; |
|
|
} |
|
|
} |
|
|
return iter->second; |
|
|
|
|
|
|
|
|
return format_str; |
|
|
} |
|
|
} |
|
|
} // namespace ascend |
|
|
} // namespace ascend |
|
|
} // namespace device |
|
|
} // namespace device |
|
|
|