Browse Source

Pre Merge pull request !1258 from lichun/master

pull/1258/MERGE
lichun Gitee 4 years ago
parent
commit
04c811fcff
4 changed files with 53 additions and 49 deletions
  1. +1
    -1
      ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc
  2. +25
    -23
      ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc
  3. +26
    -24
      ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc
  4. +1
    -1
      ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc

+ 1
- 1
ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc View File

@@ -37,7 +37,7 @@ Status CheckArgsForFracZToNchw(const TransArgs &args) {
std::string error = "Dose not support trans format from " +
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " +
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format));
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str());
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str());
return ACL_ERROR_GE_FORMAT_INVALID;
}
if (!CheckDataTypeSupported(args.src_data_type)) {


+ 25
- 23
ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc View File

@@ -37,33 +37,33 @@ Status CheckArgsForFracZToNhwc(const TransArgs &args) {
std::string error = "Dose not support trans format from " +
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " +
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format));
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str());
return UNSUPPORTED;
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str());
return ACL_ERROR_GE_FORMAT_INVALID;
}
if (!CheckDataTypeSupported(args.src_data_type)) {
GELOGE(UNSUPPORTED, "Failed to trans shape from FORMAT_FRACTAL_Z to NHWC, invalid data type %s",
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to trans shape from FORMAT_FRACTAL_Z to NHWC, invalid data type %s",
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str());
return UNSUPPORTED;
return ACL_ERROR_GE_DATATYPE_INVALID;
}
if (!CheckShapeValid(src_shape, kFracZDimsNum)) {
GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str());
return PARAM_INVALID;
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str());
return ACL_ERROR_GE_SHAPE_INVALID;
}
if (!CheckShapeValid(dst_shape, kNhwcDimsNum)) {
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str());
return PARAM_INVALID;
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str());
return ACL_ERROR_GE_SHAPE_INVALID;
}
int64_t c0 = GetCubeSizeByDataType(args.src_data_type);
if (c0 < 0) {
return PARAM_INVALID;
return ACL_ERROR_GE_DATATYPE_INVALID;
}
int64_t c1 = Ceil(dst_shape.at(kNhwcC), c0);
int64_t n0 = Ceil(dst_shape.at(kNhwcN), static_cast<int64_t>(kNiSize));
if (src_shape.at(kFracZHWC1) != dst_shape.at(kNhwcH) * dst_shape.at(kNhwcW) * c1 || src_shape.at(kFracZC0) != c0 ||
src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) {
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s",
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s",
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str());
return PARAM_INVALID;
return ACL_ERROR_GE_SHAPE_INVALID;
}

return SUCCESS;
@@ -72,10 +72,10 @@ Status CheckArgsForFracZToNhwc(const TransArgs &args) {
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size, int64_t total_size) {
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>());
if (dst == nullptr) {
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s",
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str());
return OUT_OF_MEMORY;
return ACL_ERROR_GE_MEMORY_ALLOCATION;
}

auto n0 = args.src_shape.at(kFracZN0);
@@ -111,10 +111,10 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size));
if (ret != EOK) {
GELOGE(INTERNAL_ERROR,
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED,
"Failed to copy data from FracZ offset %ld to HHWC[%ld, %ld, %ld, %ld] offset %ld, err-code %d",
src_offset, n_idx, h_idx, w_idx, c_idx, dst_offset, ret);
return INTERNAL_ERROR;
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
}
}
}
@@ -127,8 +127,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size
} // namespace

Status FormatTransferFracZNhwc::TransFormat(const TransArgs &args, TransResult &result) {
if (CheckArgsForFracZToNhwc(args) != SUCCESS) {
return PARAM_INVALID;
Status ret = CheckArgsForFracZToNhwc(args);
if (ret != SUCCESS) {
return ret;
}
int size = GetSizeByDataType(args.src_data_type);
auto total_size = GetItemNumByShape(args.dst_shape) * size;
@@ -139,18 +140,19 @@ Status FormatTransferFracZNhwc::TransFormat(const TransArgs &args, TransResult &
return SUCCESS;
}

GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size,
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size,
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str());
return PARAM_INVALID;
return ACL_ERROR_GE_PARAM_INVALID;
}
GELOGD("Begin to trans format from FracZ to NHWC, src shape %s, data type %s, dst shape %s, memory size %ld",
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(),
ShapeToString(args.dst_shape).c_str(), total_size);
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld",
ret = GetDstDataAfterTrans(args, result, size, total_size);
if (ret != SUCCESS) {
GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld",
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(),
ShapeToString(args.dst_shape).c_str(), total_size);
return INTERNAL_ERROR;
return ret;
}
return SUCCESS;
}
@@ -158,7 +160,7 @@ Status FormatTransferFracZNhwc::TransFormat(const TransArgs &args, TransResult &
Status FormatTransferFracZNhwc::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type,
Format dst_format, std::vector<int64_t> &dst_shape) {
GELOGD("The shape derivation from FracZ to NHWC is not unique. Trans shape in this direction is not supported");
return UNSUPPORTED;
return ACL_ERROR_GE_FORMAT_INVALID;
}

REGISTER_FORMAT_TRANSFER(FormatTransferFracZNhwc, FORMAT_FRACTAL_Z, FORMAT_NHWC)


+ 26
- 24
ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc View File

@@ -37,33 +37,33 @@ Status CheckArgsForNc1hwc0ToNchw(const TransArgs &args) {
std::string error = "Dose not support trans format from " +
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " +
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format));
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str());
return UNSUPPORTED;
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str());
return ACL_ERROR_GE_FORMAT_INVALID;
}
if (!CheckDataTypeSupported(args.src_data_type)) {
GELOGE(UNSUPPORTED, "Failed to trans shape from NC1HWC0 to NCHW, invalid data type %s",
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to trans shape from NC1HWC0 to NCHW, invalid data type %s",
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str());
return UNSUPPORTED;
return ACL_ERROR_GE_DATATYPE_INVALID;
}
if (!CheckShapeValid(args.src_shape, kNc1hwc0DimsNum)) {
GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str());
return PARAM_INVALID;
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str());
return ACL_ERROR_GE_SHAPE_INVALID;
}
if (!CheckShapeValid(args.dst_shape, kNchwDimsNum)) {
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str());
return PARAM_INVALID;
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str());
return ACL_ERROR_GE_SHAPE_INVALID;
}
int64_t c0 = GetCubeSizeByDataType(args.src_data_type);
if (c0 <= 0) {
GELOGE(PARAM_INVALID, "Failed to get cube size, the data type is invalid");
return PARAM_INVALID;
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to get cube size, the data type is invalid");
return ACL_ERROR_GE_DATATYPE_INVALID;
}
if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNchwH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNchwW) ||
src_shape.at(kNc1hwc0N) != dst_shape.at(kNchwN) || src_shape.at(kNc1hwc0C0) != c0 ||
src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNchwC), c0))) {
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s",
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s",
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str());
return PARAM_INVALID;
return ACL_ERROR_GE_SHAPE_INVALID;
}

return SUCCESS;
@@ -72,10 +72,10 @@ Status CheckArgsForNc1hwc0ToNchw(const TransArgs &args) {
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) {
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>());
if (dst == nullptr) {
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s",
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str());
return OUT_OF_MEMORY;
return ACL_ERROR_GE_MEMORY_ALLOCATION;
}

auto h = args.src_shape.at(kNc1hwc0H);
@@ -109,11 +109,11 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size));
if (ret != EOK) {
GELOGE(INTERNAL_ERROR,
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED,
"Failed to copy data from NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld to NCHW[%ld, %ld, %ld, %ld]"
" offset %ld, err-code %d",
n_idx, c1_idx, h_idx, w_idx, c0_idx, src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret);
return INTERNAL_ERROR;
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
}
}
}
@@ -126,8 +126,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in
} // namespace

Status FormatTransferNc1hwc0Nchw::TransFormat(const TransArgs &args, TransResult &result) {
if (CheckArgsForNc1hwc0ToNchw(args) != SUCCESS) {
return PARAM_INVALID;
Status ret = CheckArgsForNc1hwc0ToNchw(args);
if (ret != SUCCESS) {
return ret;
}
int size = GetSizeByDataType(args.src_data_type);
auto total_size = GetItemNumByShape(args.dst_shape) * size;
@@ -138,18 +139,19 @@ Status FormatTransferNc1hwc0Nchw::TransFormat(const TransArgs &args, TransResult
return SUCCESS;
}

GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size,
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size,
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str());
return PARAM_INVALID;
return ACL_ERROR_GE_PARAM_INVALID;
}
GELOGD("Begin to trans format from NC1HWC0 to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld",
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(),
ShapeToString(args.dst_shape).c_str(), total_size);
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld",
ret = GetDstDataAfterTrans(args, result, size, total_size);
if (ret != SUCCESS) {
GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld",
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(),
ShapeToString(args.dst_shape).c_str(), total_size);
return INTERNAL_ERROR;
return ret;
}
return SUCCESS;
}
@@ -157,7 +159,7 @@ Status FormatTransferNc1hwc0Nchw::TransFormat(const TransArgs &args, TransResult
Status FormatTransferNc1hwc0Nchw::TransShape(Format src_format, const std::vector<int64_t> &src_shape,
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) {
GELOGD("The shape derivation from NC1HWC0 to NCHW is not unique. Trans shape in this direction is not supported");
return UNSUPPORTED;
return ACL_ERROR_GE_FORMAT_INVALID;
}

REGISTER_FORMAT_TRANSFER(FormatTransferNc1hwc0Nchw, FORMAT_NC1HWC0, FORMAT_NCHW)


+ 1
- 1
ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc View File

@@ -125,7 +125,7 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) {
return ACL_ERROR_GE_INTERNAL_ERROR);
auto t1 = h_o * w_o;
auto t2 = n_o * c_o;
GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", t1, t2);
GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", t1, t2);
return ACL_ERROR_GE_INTERNAL_ERROR);

int64_t total_ele_cnt = n_o * c_o * h_o * w_o;


Loading…
Cancel
Save