Browse Source

!12864 Data dump support 3d format

From: @jojobugfree
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
0789c487ad
3 changed files with 35 additions and 114 deletions
  1. +2
    -2
      mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc
  2. +31
    -65
      mindspore/ccsrc/runtime/device/ascend/ge_types_convert.cc
  3. +2
    -47
      mindspore/ccsrc/runtime/device/ascend/ge_types_convert.h

+ 2
- 2
mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc View File

@@ -260,13 +260,13 @@ void DataDumper::SetOpDebugMappingInfo(const NotNull<aicpu::dump::OpMappingInfo

aicpu::dump::Output output;
output.set_data_type(ge::proto::DataType::DT_UINT8);
output.set_format(GeFormat::kFormat_ND);
output.set_format(ge::Format::FORMAT_ND);

output.mutable_shape()->add_dim(kOpDebugShape);

output.set_original_name(kNodeNameOpDebug);
output.set_original_output_index(0);
output.set_original_output_format(GeFormat::kFormat_ND);
output.set_original_output_format(ge::Format::FORMAT_ND);
output.set_original_output_data_type(ge::proto::DataType::DT_UINT8);
// due to lhisi virtual addr bug, cannot use args now
output.set_address(static_cast<uint64_t>(reinterpret_cast<uintptr_t>(op_debug_dump_args_)));


+ 31
- 65
mindspore/ccsrc/runtime/device/ascend/ge_types_convert.cc View File

@@ -15,7 +15,11 @@
*/

#include "runtime/device/ascend/ge_types_convert.h"
#include "graph/utils/type_utils.h"

namespace {
constexpr auto kInvalidFormat = "RESERVED";
}
namespace mindspore {
namespace device {
namespace ascend {
@@ -54,26 +58,31 @@ ge::DataType GeTypesConvert::TransTypeIdToGeDataType(TypeId type_id) {
return iter->second;
}

GeFormat GeTypesConvert::GetGeFormat(const std::string &format, size_t shape_size) {
static const std::map<std::string, GeFormat> format_map = {
ge::Format GeTypesConvert::GetGeFormat(const std::string &format, size_t shape_size) {
static const std::map<std::string, ge::Format> format_map = {
// default format: nchw, fractal_nz?
{kOpFormat_DEFAULT, kFormat_NCHW},
{kOpFormat_NC1KHKWHWC0, kFormat_NC1KHKWHWC0},
{kOpFormat_ND, kFormat_ND},
{kOpFormat_NCHW, kFormat_NCHW},
{kOpFormat_NHWC, kFormat_NHWC},
{kOpFormat_HWCN, kFormat_HWCN},
{kOpFormat_NC1HWC0, kFormat_NC1HWC0},
{kOpFormat_FRAC_Z, kFormat_FRACTAL_Z},
{kOpFormat_FRAC_NZ, kFormat_FRACTAL_NZ},
{kOpFormat_C1HWNCoC0, kFormat_C1HWNCoC0},
{kOpFormat_NC1HWC0_C04, kFormat_NC1HWC0_C04},
{kOpFormat_FRACTAL_Z_C04, kFormat_FRACTAL_Z_C04},
{kOpFormat_NDHWC, kFormat_NDHWC},
};
{kOpFormat_DEFAULT, ge::Format::FORMAT_NCHW},
{kOpFormat_NC1KHKWHWC0, ge::Format::FORMAT_NC1KHKWHWC0},
{kOpFormat_ND, ge::Format::FORMAT_ND},
{kOpFormat_NCHW, ge::Format::FORMAT_NCHW},
{kOpFormat_NHWC, ge::Format::FORMAT_NHWC},
{kOpFormat_HWCN, ge::Format::FORMAT_HWCN},
{kOpFormat_NC1HWC0, ge::Format::FORMAT_NC1HWC0},
{kOpFormat_FRAC_Z, ge::Format::FORMAT_FRACTAL_Z},
{kOpFormat_FRAC_NZ, ge::Format::FORMAT_FRACTAL_NZ},
{kOpFormat_C1HWNCoC0, ge::Format::FORMAT_C1HWNCoC0},
{kOpFormat_NC1HWC0_C04, ge::Format::FORMAT_NC1HWC0_C04},
{kOpFormat_FRACTAL_Z_C04, ge::Format::FORMAT_FRACTAL_Z_C04},
{kOpFormat_NDHWC, ge::Format::FORMAT_NDHWC},
{kOpFormat_NCDHW, ge::Format::FORMAT_NCDHW},
{kOpFormat_DHWNC, ge::Format::FORMAT_DHWNC},
{kOpFormat_DHWCN, ge::Format::FORMAT_DHWCN},
{kOpFormat_NDC1HWC0, ge::Format::FORMAT_NDC1HWC0},
{kOpFormat_FRACTAL_Z_3D, ge::Format::FORMAT_FRACTAL_Z_3D},
{kOpFormat_FRACTAL_ZN_LSTM, ge::Format::FORMAT_FRACTAL_ZN_LSTM}};
MS_LOG(INFO) << "GetGeFormat format:" << format << " shape_size:" << shape_size;
if (format == kOpFormat_DEFAULT) {
return shape_size == 4 ? kFormat_NCHW : kFormat_ND;
return shape_size == 4 ? ge::Format::FORMAT_NCHW : ge::Format::FORMAT_ND;
}
auto iter = format_map.find(format);
if (iter == format_map.end()) {
@@ -82,55 +91,12 @@ GeFormat GeTypesConvert::GetGeFormat(const std::string &format, size_t shape_siz
return iter->second;
}

std::string GeTypesConvert::GetGeTilingFormat(GeFormat ge_format) {
static const std::map<GeFormat, std::string> kFormatToStringMap = {
{kFormat_NCHW, "NCHW"},
{kFormat_NHWC, "NHWC"},
{kFormat_ND, "ND"},
{kFormat_NC1HWC0, "NC1HWC0"},
{kFormat_FRACTAL_Z, "FRACTAL_Z"},
{kFormat_NC1C0HWPAD, "NC1C0HWPAD"},
{kFormat_NHWC1C0, "NHWC1C0"},
{kFormat_FSR_NCHW, "FSR_NCHW"},
{kFormat_FRACTAL_DECONV, "FRACTAL_DECONV"},
{kFormat_C1HWNC0, "C1HWNC0"},
{kFormat_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"},
{kFormat_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"},
{kFormat_NC1HWC0_C04, "NC1HWC0_C04"},
{kFormat_FRACTAL_Z_C04, "FRACTAL_Z_C04"},
{kFormat_CHWN, "CHWN"},
{kFormat_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"},
{kFormat_NC1KHKWHWC0, "NC1KHKWHWC0"},
{kFormat_BN_WEIGHT, "BN_WEIGHT"},
{kFormat_FILTER_HWCK, "FILTER_HWCK"},
{kFormat_HWCN, "HWCN"},
{kFormat_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"},
{kFormat_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"},
{kFormat_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"},
{kFormat_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"},
{kFormat_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"},
{kFormat_MD, "MD"},
{kFormat_NDHWC, "NDHWC"},
{kFormat_NCDHW, "NCDHW"},
{kFormat_DHWCN, "DHWCN"},
{kFormat_DHWNC, "DHWNC"},
{kFormat_NDC1HWC0, "NDC1HWC0"},
{kFormat_FRACTAL_Z_3D, "FRACTAL_Z_3D"},
{kFormat_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"},
{kFormat_C1HWNCoC0, "C1HWNCoC0"},
{kFormat_FRACTAL_NZ, "FRACTAL_NZ"},
{kFormat_CN, "CN"},
{kFormat_NC, "NC"},
{kFormat_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"},
{kFormat_FRACTAL_Z_G, "FRACTAL_Z_G"},
{kFormat_RESERVED, "FORMAT_RESERVED"},
{kFormat_ALL, "ALL"}};

auto iter = kFormatToStringMap.find(ge_format);
if (iter == kFormatToStringMap.end()) {
MS_LOG(EXCEPTION) << "Invalid ge_format:" << ge_format;
std::string GeTypesConvert::GetGeTilingFormat(ge::Format ge_format) {
auto format_str = ge::TypeUtils::FormatToSerialString(ge_format);
if (format_str == kInvalidFormat) {
MS_LOG(EXCEPTION) << "Not support format:" << ge_format;
}
return iter->second;
return format_str;
}
} // namespace ascend
} // namespace device


+ 2
- 47
mindspore/ccsrc/runtime/device/ascend/ge_types_convert.h View File

@@ -27,58 +27,13 @@
namespace mindspore {
namespace device {
namespace ascend {
enum GeFormat {
kFormat_NCHW = 0, // NCHW
kFormat_NHWC, // NHWC
kFormat_ND, // Nd Tensor
kFormat_NC1HWC0, // NC1HWC0
kFormat_FRACTAL_Z, // FRACTAL_Z
kFormat_NC1C0HWPAD,
kFormat_NHWC1C0,
kFormat_FSR_NCHW,
kFormat_FRACTAL_DECONV,
kFormat_C1HWNC0,
kFormat_FRACTAL_DECONV_TRANSPOSE,
kFormat_FRACTAL_DECONV_SP_STRIDE_TRANS,
kFormat_NC1HWC0_C04, // NC1HWC0, C0 =4
kFormat_FRACTAL_Z_C04, // FRACZ, C0 =4
kFormat_CHWN,
kFormat_FRACTAL_DECONV_SP_STRIDE8_TRANS,
kFormat_HWCN,
kFormat_NC1KHKWHWC0, // KH,KW kernel h& kernel w maxpooling max output format
kFormat_BN_WEIGHT,
kFormat_FILTER_HWCK, // filter input tensor format
kFormat_HASHTABLE_LOOKUP_LOOKUPS = 20,
kFormat_HASHTABLE_LOOKUP_KEYS,
kFormat_HASHTABLE_LOOKUP_VALUE,
kFormat_HASHTABLE_LOOKUP_OUTPUT,
kFormat_HASHTABLE_LOOKUP_HITS = 24,
kFormat_C1HWNCoC0,
kFormat_MD,
kFormat_NDHWC,
kFormat_FRACTAL_ZZ,
kFormat_FRACTAL_NZ,
kFormat_NCDHW,
kFormat_DHWCN, // 3D filter input tensor format
kFormat_NDC1HWC0,
kFormat_FRACTAL_Z_3D,
kFormat_CN,
kFormat_NC,
kFormat_DHWNC,
kFormat_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format
kFormat_FRACTAL_ZN_LSTM,
kFormat_FRACTAL_Z_G,
kFormat_RESERVED,
kFormat_ALL
};

class GeTypesConvert {
public:
GeTypesConvert() = default;
~GeTypesConvert() = default;
static ge::proto::DataType GetGeDataType(TypeId type_id);
static GeFormat GetGeFormat(const std::string &format, size_t shape_size);
static std::string GetGeTilingFormat(GeFormat ge_format);
static ge::Format GetGeFormat(const std::string &format, size_t shape_size);
static std::string GetGeTilingFormat(ge::Format ge_format);
static ge::DataType TransTypeIdToGeDataType(TypeId type_id);
};
} // namespace ascend


Loading…
Cancel
Save