| @@ -111,7 +111,7 @@ Status CastKernel(const CastArgs &args, uint8_t *dst, const size_t data_size, co | |||
| }; | |||
| auto it = transfer_handle.find(trans_mode); | |||
| if (it == transfer_handle.end()) { | |||
| return UNSUPPORTED; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } else { | |||
| return (it->second)(args, dst, data_size); | |||
| } | |||
| @@ -127,8 +127,8 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||
| std::string error = "Failed to trans data from datatype " + | |||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | |||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + " , it is not supported."; | |||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||
| return UNSUPPORTED; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_DATATYPE_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| auto trans_mode = iter->second; | |||
| @@ -136,14 +136,14 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||
| if (size <= 0) { | |||
| std::string error = "Failed to calc size from data type" + | |||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", it is not supported."; | |||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||
| return PARAM_INVALID; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_DATATYPE_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (args.src_data_size > static_cast<size_t>(SIZE_MAX / size)) { | |||
| std::string error = "args.src_data_size" + FmtToStr(args.src_data_size) + | |||
| " or data type size" + FmtToStr(size) + " is too big"; | |||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||
| return PARAM_INVALID; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_PARAM_INVALID; | |||
| } | |||
| size_t total_size = static_cast<size_t>(args.src_data_size * size); | |||
| result.length = total_size; | |||
| @@ -154,8 +154,8 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size); | |||
| return OUT_OF_MEMORY; | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size); | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| if (CastKernel(args, dst.get(), args.src_data_size, trans_mode) != SUCCESS) { | |||
| @@ -163,8 +163,8 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | |||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", data size is " + | |||
| FmtToStr(std::to_string(args.src_data_size)); | |||
| GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error.c_str()); | |||
| return INTERNAL_ERROR; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_INTERNAL_ERROR, error.c_str()); | |||
| return ACL_ERROR_GE_INTERNAL_ERROR; | |||
| } | |||
| result.data = dst; | |||
| return SUCCESS; | |||
| @@ -39,22 +39,22 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||
| std::string error = "Dose not support trans format from " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||
| return UNSUPPORTED; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| if (!CheckDataTypeSupported(args.src_data_type)) { | |||
| std::string error = "Failed to trans shape from NC1HWNCoC0 to HWCN, invalid data type" + | |||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)); | |||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||
| return UNSUPPORTED; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_DATATYPE_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (!CheckShapeValid(src_shape, kC1hwncoc0DimsNum)) { | |||
| GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| if (!CheckShapeValid(dst_shape, kHwcnDimsNum)) { | |||
| GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| auto cube_size = GetCubeSizeByDataType(args.src_data_type); | |||
| if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / cube_size + 1 || | |||
| @@ -63,8 +63,8 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||
| src_shape.at(kC1hwncoc0C0) != cube_size) { | |||
| std::string error = "Failed to check relationship between src and dst shape, src shape" + | |||
| FmtToStr(ShapeToString(src_shape)) + ", dst shape" + FmtToStr(ShapeToString(dst_shape)); | |||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||
| return PARAM_INVALID; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| @@ -73,10 +73,10 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size, int64_t total_size) { | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| auto h = args.src_shape.at(kC1hwncoc0H); | |||
| @@ -114,12 +114,12 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size | |||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||
| static_cast<size_t>(size)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||
| "Failed to copy data from C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld to " | |||
| "HWCN[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||
| c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, src_offset, h_idx, w_idx, c_idx, n_idx, dst_offset, | |||
| ret); | |||
| return INTERNAL_ERROR; | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| } | |||
| @@ -132,8 +132,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size | |||
| } // namespace | |||
| Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResult &result) { | |||
| if (CheckArgsForC1hwncoc0ToHwcn(args) != SUCCESS) { | |||
| return PARAM_INVALID; | |||
| Status ret = CheckArgsForC1hwncoc0ToHwcn(args); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| } | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| int64_t total_size = GetItemNumByShape(args.dst_shape) * size; | |||
| @@ -143,18 +144,19 @@ Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResu | |||
| result.length = static_cast<size_t>(total_size); | |||
| return SUCCESS; | |||
| } | |||
| GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| GELOGD("Begin to trans format from C1HWNCoC0 to HWCN, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ret = GetDstDataAfterTrans(args, result, size, total_size); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| return INTERNAL_ERROR; | |||
| return ret; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -162,7 +164,7 @@ Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResu | |||
| Status FormatTransferC1hwncoc0Hwcn::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | |||
| DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | |||
| GELOGD("The shape derivation from C1HWNCoC0 to HWCN is not unique. Trans shape in this direction is not supported"); | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferC1hwncoc0Hwcn, FORMAT_C1HWNCoC0, FORMAT_HWCN) | |||
| @@ -32,7 +32,7 @@ Status TransShapeToFz(int64_t d, int64_t n, int64_t c, int64_t h, int64_t w, Dat | |||
| std::vector<int64_t> &dst_shape) { | |||
| auto c0 = GetCubeSizeByDataType(data_type); | |||
| if (c0 < 0) { | |||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| auto c1 = Ceil(c, c0); | |||
| @@ -50,7 +50,7 @@ Status TransShapeToFz(int64_t d, int64_t n, int64_t c, int64_t h, int64_t w, Dat | |||
| Status TransShapeDhwckToFz3D(const std::vector<int64_t> &src_shape, DataType data_type, | |||
| std::vector<int64_t> &dst_shape) { | |||
| if (!CheckShapeValid(src_shape, kDhwcnDimsNum)) { | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| auto d = src_shape.at(kDhwcnD); | |||
| auto h = src_shape.at(kDhwcnH); | |||
| @@ -62,7 +62,7 @@ Status TransShapeDhwckToFz3D(const std::vector<int64_t> &src_shape, DataType dat | |||
| } | |||
| Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { | |||
| if (!CheckShapeValid(args.src_shape, kDhwcnDimsNum)) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| int64_t d = args.src_shape[kDhwcnD]; | |||
| int64_t h = args.src_shape[kDhwcnH]; | |||
| @@ -94,10 +94,10 @@ Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| for (int64_t di = 0; di < d; di++) { | |||
| @@ -122,9 +122,9 @@ Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { | |||
| args.data + src_idx * data_size, static_cast<size_t>(data_size)); | |||
| } | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||
| dst_offset, ret, pad_zero); | |||
| return INTERNAL_ERROR; | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| } | |||
| @@ -149,28 +149,28 @@ Status FormatTransferDhwcnFractalZ3D::TransFormat(const TransArgs &args, TransRe | |||
| return ret; | |||
| } | |||
| if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| if (args.src_format == FORMAT_DHWCN && args.dst_format == FORMAT_FRACTAL_Z_3D) { | |||
| return TransFormatDhwckToFz3D(args, result); | |||
| } | |||
| return UNSUPPORTED; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| Status FormatTransferDhwcnFractalZ3D::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | |||
| DataType data_type, Format dst_format, | |||
| std::vector<int64_t> &dst_shape) { | |||
| if (CheckDataTypeSupport(data_type) != SUCCESS) { | |||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (src_format == FORMAT_DHWCN && dst_format == FORMAT_FRACTAL_Z_3D) { | |||
| return TransShapeDhwckToFz3D(src_shape, data_type, dst_shape); | |||
| } | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferDhwcnFractalZ3D, FORMAT_DHWCN, FORMAT_FRACTAL_Z_3D) | |||
| @@ -32,7 +32,7 @@ Status TransShapeToFz(int64_t d, int64_t n, int64_t c, int64_t h, int64_t w, Dat | |||
| std::vector<int64_t> &dst_shape) { | |||
| auto c0 = GetCubeSizeByDataType(data_type); | |||
| if (c0 < 0) { | |||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| auto c1 = Ceil(c, c0); | |||
| @@ -50,7 +50,7 @@ Status TransShapeToFz(int64_t d, int64_t n, int64_t c, int64_t h, int64_t w, Dat | |||
| Status TransShapeDhwncToFz3DTranspose(const std::vector<int64_t> &src_shape, DataType data_type, | |||
| std::vector<int64_t> &dst_shape) { | |||
| if (!CheckShapeValid(src_shape, kDhwncDimsNum)) { | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| auto d = src_shape.at(kDhwncD); | |||
| auto h = src_shape.at(kDhwncH); | |||
| @@ -62,7 +62,7 @@ Status TransShapeDhwncToFz3DTranspose(const std::vector<int64_t> &src_shape, Dat | |||
| } | |||
| Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &result) { | |||
| if (!CheckShapeValid(args.src_shape, kDhwncDimsNum)) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| int64_t d = args.src_shape[kDhwncD]; | |||
| int64_t h = args.src_shape[kDhwncH]; | |||
| @@ -95,10 +95,10 @@ Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &resul | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| for (int64_t di = 0; di < d; di++) { | |||
| @@ -123,9 +123,9 @@ Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &resul | |||
| args.data + src_idx * data_size, static_cast<size_t>(data_size)); | |||
| } | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||
| dst_offset, ret, pad_zero); | |||
| return INTERNAL_ERROR; | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| } | |||
| @@ -150,28 +150,28 @@ Status FormatTransferDhwncFractalZ3DTranspose::TransFormat(const TransArgs &args | |||
| return ret; | |||
| } | |||
| if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| if (args.src_format == ge::FORMAT_DHWNC && args.dst_format == ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE) { | |||
| return TransFormatDhwncToFz3DTranspose(args, result); | |||
| } | |||
| return UNSUPPORTED; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| Status FormatTransferDhwncFractalZ3DTranspose::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | |||
| DataType data_type, Format dst_format, | |||
| std::vector<int64_t> &dst_shape) { | |||
| if (CheckDataTypeSupport(data_type) != SUCCESS) { | |||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (src_format == FORMAT_DHWNC && dst_format == FORMAT_FRACTAL_Z_3D_TRANSPOSE) { | |||
| return TransShapeDhwncToFz3DTranspose(src_shape, data_type, dst_shape); | |||
| } | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferDhwncFractalZ3DTranspose, FORMAT_DHWNC, FORMAT_FRACTAL_Z_3D_TRANSPOSE) | |||
| @@ -87,8 +87,8 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap | |||
| hw_shape.push_back(DIM_DEFAULT_VALUE); | |||
| hw_shape.push_back(src_shape[kNdDimIndexN]); | |||
| if (!IsShapeValid(dst_shape)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| default: | |||
| @@ -106,8 +106,8 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap | |||
| hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]); | |||
| hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]); | |||
| if (!IsShapeValid(dst_shape)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -117,14 +117,14 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||
| ShapeVector expect_src_shape; | |||
| auto ret = TransShapeToFracNz(args.dst_shape, args.src_data_type, expect_src_shape, hw_shape); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Trans shape from %s to %s, shape %s to %s, data type %s failed", | |||
| GELOGE(ret, "Trans shape from %s to %s, shape %s to %s, data type %s failed", | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return INTERNAL_ERROR; | |||
| return ret; | |||
| } | |||
| if (!IsTransShapeSrcCorrect(args, expect_src_shape)) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -139,10 +139,10 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| // src&dst_shape can be written as times*H*W & times*W1*H1*H0*W0, respectively. dst_shape_size >= kDimNum4D | |||
| @@ -175,8 +175,8 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con | |||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||
| static_cast<size_t>(size * w0)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| auto w1_head = num_w1 * w0; | |||
| @@ -189,8 +189,8 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con | |||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||
| static_cast<size_t>(size)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| } | |||
| @@ -210,10 +210,10 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| auto times = dst_hw_shape.at(kNdDimIndexN); | |||
| @@ -246,8 +246,8 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con | |||
| ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||
| static_cast<size_t>(size * w0)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| auto w1_head = num_w1 * w0; | |||
| @@ -260,8 +260,8 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con | |||
| ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||
| static_cast<size_t>(size)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| } | |||
| @@ -273,13 +273,19 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con | |||
| } // namespace | |||
| Status FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult &result) { | |||
| if (!IsDataTypeSupport(args.src_data_type) || !CheckShape(args.src_format, args.src_shape) || | |||
| !IsShapeValid(args.dst_shape)) { | |||
| GELOGE(PARAM_INVALID, "Trans format from %s to %s, src shape %s, dst shape %s, data type %s is not supported", | |||
| if (!IsDataTypeSupport(args.src_data_type)) { | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Trans format from %s to %s, src shape %s, dst shape %s, data type %s is not supported", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (!CheckShape(args.src_format, args.src_shape) || !IsShapeValid(args.dst_shape)) { | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Trans format from %s to %s, src shape %s, dst shape %s, data type %s is not supported", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| @@ -292,7 +298,7 @@ Status FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult & | |||
| return ret; | |||
| } | |||
| if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return TransFormatFromNdToFracNz(args, result, hw_shape); | |||
| } | |||
| @@ -300,31 +306,38 @@ Status FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult & | |||
| Status FormatTransferFractalNz::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type, | |||
| Format dst_format, ShapeVector &dst_shape) { | |||
| if (!IsDataTypeSupport(data_type)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID, | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, | |||
| "Trans format from %s to %s, src shape %s, data type %s is not supported", | |||
| TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), | |||
| ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (!CheckShape(src_format, src_shape)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||
| "Trans format from %s to %s, src shape %s, data type %s is not supported", | |||
| TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), | |||
| ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| ShapeVector hw_shape; | |||
| return TransShapeToFracNz(src_shape, data_type, dst_shape, hw_shape); | |||
| } | |||
| Status FormatTransferFractalNzND::TransFormat(const TransArgs &args, TransResult &result) { | |||
| if (!IsDataTypeSupport(args.src_data_type) || !IsShapeValid(args.src_shape) || | |||
| !CheckShape(args.dst_format, args.dst_shape)) { | |||
| GELOGE(PARAM_INVALID, "Trans format from %s to %s, src shape %s, dst shape %s, data type %s is not supported", | |||
| if (!IsDataTypeSupport(args.src_data_type)) { | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Trans format from %s to %s, src shape %s, dst shape %s, data type %s is not supported", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (!IsShapeValid(args.src_shape) || !CheckShape(args.dst_format, args.dst_shape)) { | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Trans format from %s to %s, src shape %s, dst shape %s, data type %s is not supported", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| @@ -332,8 +345,9 @@ Status FormatTransferFractalNzND::TransFormat(const TransArgs &args, TransResult | |||
| ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| ShapeVector hw_shape; | |||
| if (CheckShapeRelation(args, hw_shape) != SUCCESS) { | |||
| return PARAM_INVALID; | |||
| Status ret = CheckShapeRelation(args, hw_shape); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| } | |||
| return TransFormatFromFracNzToNd(args, result, hw_shape); | |||
| } | |||
| @@ -342,7 +356,7 @@ Status FormatTransferFractalNzND::TransShape(Format src_format, const ShapeVecto | |||
| Format dst_format, ShapeVector &dst_shape) { | |||
| GELOGD("The shape derivation from %s to %s is not unique. Trans shape is not supported", | |||
| TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalNz, FORMAT_ND, FORMAT_FRACTAL_NZ) | |||
| @@ -84,38 +84,9 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_ | |||
| return SUCCESS; | |||
| } | |||
| Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape | |||
| , int64_t groups) { | |||
| auto c0 = GetCubeSizeByDataType(data_type); | |||
| if (c0 < 0) { | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| int64_t cin_ori = c; | |||
| int64_t cout_ori = n / groups; | |||
| int64_t cube_k = data_type == DT_INT8 ? 32 : 16; | |||
| int64_t e_mult = std::min( | |||
| Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), | |||
| groups); | |||
| int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k; | |||
| int64_t c1_dim = cin_opt / cube_k; | |||
| int64_t g_dim = Ceil(groups, e_mult); | |||
| auto n1 = Ceil(n , kCubeN); | |||
| dst_shape.clear(); | |||
| dst_shape.push_back(g_dim * c1_dim * h * w); | |||
| dst_shape.push_back(n1); | |||
| dst_shape.push_back(16); | |||
| dst_shape.push_back(cube_k); | |||
| if (!IsShapeValid(dst_shape)) { | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||
| ShapeToString(dst_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status TransShapeNchwToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { | |||
| if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| auto n = src_shape.at(kNchwN); | |||
| @@ -373,9 +344,9 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||
| static_cast<size_t>(data_size)); | |||
| } else { | |||
| if (protected_size < data_size) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||
| protected_size, data_size); | |||
| return INTERNAL_ERROR; | |||
| return ACL_ERROR_GE_PARAM_INVALID; | |||
| } | |||
| int64_t src_idx = hi * wcn + wi * cn + (c1i * c0 + c0i) * n + n1n0i; | |||
| char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset); | |||
| @@ -526,13 +497,14 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r | |||
| if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransFormatFromNchwToFz(args, result); | |||
| } | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | |||
| Format dst_format, std::vector<int64_t> &dst_shape) { | |||
| if (CheckDataTypeSupport(data_type) != SUCCESS) { | |||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) { | |||
| @@ -545,7 +517,7 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i | |||
| return TransShapeNchwToFz(src_shape, data_type, dst_shape); | |||
| } | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NCHW, FORMAT_FRACTAL_Z) | |||
| @@ -86,9 +86,9 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap | |||
| hw_shape.push_back(DIM_DEFAULT_VALUE); | |||
| hw_shape.push_back(src_shape[kNdDimIndexN]); | |||
| if (!IsShapeValid(dst_shape)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||
| ShapeToString(dst_shape).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| default: | |||
| @@ -106,9 +106,9 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap | |||
| hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]); | |||
| hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]); | |||
| if (!IsShapeValid(dst_shape)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||
| ShapeToString(dst_shape).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -118,14 +118,14 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||
| ShapeVector expect_src_shape; | |||
| auto ret = TransShapeToFracZz(args.dst_shape, args.src_data_type, expect_src_shape, hw_shape); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Trans shape from %s to %s, shape %s to %s, data type %s failed", | |||
| GELOGE(ret, "Trans shape from %s to %s, shape %s to %s, data type %s failed", | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return INTERNAL_ERROR; | |||
| return ret; | |||
| } | |||
| if (!IsTransShapeSrcCorrect(args, expect_src_shape)) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -140,10 +140,10 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D | |||
| auto times = hw_shape.at(kNdDimIndexN); | |||
| @@ -179,8 +179,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||
| static_cast<size_t>(size * w0)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| auto w1_head = num_w1 * w0; | |||
| @@ -195,8 +195,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||
| static_cast<size_t>(size)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| } | |||
| @@ -217,10 +217,10 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D | |||
| @@ -257,8 +257,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con | |||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||
| static_cast<size_t>(size * w0)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| auto w1_head = num_w1 * w0; | |||
| @@ -273,8 +273,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con | |||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||
| static_cast<size_t>(size)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| } | |||
| @@ -287,13 +287,19 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con | |||
| } // namespace | |||
| Status FormatTransferFractalZz::TransFormat(const TransArgs &args, TransResult &result) { | |||
| if (!IsDataTypeSupport(args.src_data_type) || !CheckShape(args.src_format, args.src_shape) || | |||
| !IsShapeValid(args.dst_shape)) { | |||
| GELOGE(PARAM_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||
| if (!IsDataTypeSupport(args.src_data_type)) { | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (!CheckShape(args.src_format, args.src_shape) || !IsShapeValid(args.dst_shape)) { | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| @@ -306,7 +312,7 @@ Status FormatTransferFractalZz::TransFormat(const TransArgs &args, TransResult & | |||
| return ret; | |||
| } | |||
| if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return TransFormatFromNdToFracZz(args, result, hw_shape); | |||
| } | |||
| @@ -314,31 +320,38 @@ Status FormatTransferFractalZz::TransFormat(const TransArgs &args, TransResult & | |||
| Status FormatTransferFractalZz::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type, | |||
| Format dst_format, ShapeVector &dst_shape) { | |||
| if (!IsDataTypeSupport(data_type)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID, | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, | |||
| "Not support trans format from %s to %s, src shape %s, data type %s", | |||
| TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), | |||
| ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (!CheckShape(src_format, src_shape)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||
| "Not support trans format from %s to %s, src shape %s, data type %s", | |||
| TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), | |||
| ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| ShapeVector hw_shape; | |||
| return TransShapeToFracZz(src_shape, data_type, dst_shape, hw_shape); | |||
| } | |||
| Status FormatTransferFractalZzND::TransFormat(const TransArgs &args, TransResult &result) { | |||
| if (!IsDataTypeSupport(args.src_data_type) || !IsShapeValid(args.src_shape) || | |||
| !CheckShape(args.dst_format, args.dst_shape)) { | |||
| GELOGE(PARAM_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||
| if (!IsDataTypeSupport(args.src_data_type)) { | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (!IsShapeValid(args.src_shape) || !CheckShape(args.dst_format, args.dst_shape)) { | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| @@ -346,8 +359,9 @@ Status FormatTransferFractalZzND::TransFormat(const TransArgs &args, TransResult | |||
| ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| ShapeVector hw_shape; | |||
| if (CheckShapeRelation(args, hw_shape) != SUCCESS) { | |||
| return PARAM_INVALID; | |||
| Status ret = CheckShapeRelation(args, hw_shape); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| } | |||
| return TransFormatFromFracZzToNd(args, result, hw_shape); | |||
| } | |||
| @@ -356,7 +370,7 @@ Status FormatTransferFractalZzND::TransShape(Format src_format, const ShapeVecto | |||
| Format dst_format, ShapeVector &dst_shape) { | |||
| GELOGD("The shape derivation from %s to %s is not unique. Trans shape is not supported", | |||
| TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZz, FORMAT_ND, FORMAT_FRACTAL_ZZ) | |||
| @@ -37,25 +37,25 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { | |||
| std::string error = "Dose not support trans format from " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||
| return UNSUPPORTED; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| if (!CheckDataTypeSupported(args.src_data_type)) { | |||
| GELOGE(UNSUPPORTED, "Failed to trans shape from FORMAT_FRACTAL_Z to HWCN, invalid data type %s", | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to trans shape from FORMAT_FRACTAL_Z to HWCN, invalid data type %s", | |||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return UNSUPPORTED; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (!CheckShapeValid(src_shape, kFracZDimsNum)) { | |||
| GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| if (!CheckShapeValid(dst_shape, kHwcnDimsNum)) { | |||
| GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||
| if (c0 < 0) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| int64_t c1 = Ceil(dst_shape.at(kHwcnC), c0); | |||
| int64_t n0 = Ceil(dst_shape.at(kHwcnN), static_cast<int64_t>(kNiSize)); | |||
| @@ -64,8 +64,8 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { | |||
| std::string error = "Failed to check relationship between src shape" + | |||
| FmtToStr(ShapeToString(src_shape)) + " and dst shape" + | |||
| FmtToStr(ShapeToString(dst_shape)); | |||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||
| return PARAM_INVALID; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| @@ -74,10 +74,10 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { | |||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| auto n0 = args.src_shape.at(kFracZN0); | |||
| @@ -113,11 +113,11 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||
| static_cast<size_t>(size)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||
| "Failed to copy data from FracZ offset %ld to HWCN[%ld, %ld, %ld, %ld] " | |||
| "offset %ld, err-code %d", | |||
| src_offset, h_idx, w_idx, c_idx, n_idx, dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| } | |||
| @@ -130,8 +130,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
| } // namespace | |||
| Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult &result) { | |||
| if (CheckArgsForFracZToHwcn(args) != SUCCESS) { | |||
| return PARAM_INVALID; | |||
| Status ret = CheckArgsForFracZToHwcn(args); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| } | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||
| @@ -142,18 +143,19 @@ Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult & | |||
| return SUCCESS; | |||
| } | |||
| GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| GELOGD("Begin to trans format from FracZ to HWCN, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ret = GetDstDataAfterTrans(args, result, size, total_size); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| return INTERNAL_ERROR; | |||
| return ret; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -161,7 +163,7 @@ Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult & | |||
| Status FormatTransferFracZHwcn::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | |||
| Format dst_format, std::vector<int64_t> &dst_shape) { | |||
| GELOGD("The shape derivation from FracZ to HWCN is not unique. Trans shape in this direction is not supported"); | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferFracZHwcn, FORMAT_FRACTAL_Z, FORMAT_HWCN) | |||
| @@ -38,32 +38,32 @@ Status CheckArgsForFracZToNchw(const TransArgs &args) { | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||
| return UNSUPPORTED; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| if (!CheckDataTypeSupported(args.src_data_type)) { | |||
| GELOGE(UNSUPPORTED, "Failed to trans shape from FORMAT_FRACTAL_Z to NCHW, invalid data type %s", | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to trans shape from FORMAT_FRACTAL_Z to NCHW, invalid data type %s", | |||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return UNSUPPORTED; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (!CheckShapeValid(src_shape, kFracZDimsNum)) { | |||
| GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| if (!CheckShapeValid(dst_shape, kNchwDimsNum)) { | |||
| GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||
| if (c0 < 0) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| int64_t c1 = Ceil(dst_shape.at(kNchwC), c0); | |||
| int64_t n0 = Ceil(dst_shape.at(kNchwN), static_cast<int64_t>(kNiSize)); | |||
| if (src_shape.at(kFracZHWC1) != dst_shape.at(kNchwH) * dst_shape.at(kNchwW) * c1 || src_shape.at(kFracZC0) != c0 || | |||
| src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) { | |||
| GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||
| ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| @@ -72,10 +72,10 @@ Status CheckArgsForFracZToNchw(const TransArgs &args) { | |||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| auto n0 = args.src_shape.at(kFracZN0); | |||
| @@ -111,11 +111,11 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||
| static_cast<size_t>(size)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||
| "Failed to copy data from FracZ offset %ld to NCHW[%ld, %ld, %ld, %ld] offset %ld, " | |||
| "err-code %d", | |||
| src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| } | |||
| @@ -128,8 +128,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
| } // namespace | |||
| Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult &result) { | |||
| if (CheckArgsForFracZToNchw(args) != SUCCESS) { | |||
| return PARAM_INVALID; | |||
| Status ret = CheckArgsForFracZToNchw(args); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| } | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||
| @@ -140,19 +141,20 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & | |||
| return SUCCESS; | |||
| } | |||
| GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ret = GetDstDataAfterTrans(args, result, size, total_size); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| return INTERNAL_ERROR; | |||
| return ret; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -160,7 +162,7 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & | |||
| Status FormatTransferFracZNchw::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | |||
| Format dst_format, std::vector<int64_t> &dst_shape) { | |||
| GELOGD("The shape derivation from FracZ to NCHW is not unique. Trans shape in this direction is not supported"); | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferFracZNchw, FORMAT_FRACTAL_Z, FORMAT_NCHW) | |||
| @@ -43,9 +43,9 @@ Status TransShapeHwcnToC1hwncoc0(const DataType &data_type, const std::vector<in | |||
| dst_shape.push_back(cube_size); | |||
| dst_shape.push_back(cube_size); | |||
| if (!CheckShapeValid(dst_shape, kC1hwncoc0DimsNum)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||
| ShapeToString(dst_shape).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -55,21 +55,21 @@ Status CheckArgsForHwcnToC1hwncoc0(const TransArgs &args) { | |||
| std::string error = "Dose not support trans format from " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||
| return UNSUPPORTED; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| if (!CheckDataTypeSupported(args.src_data_type)) { | |||
| GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to C1HWNCoC0, invalid data type %s", | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to trans shape from HWCN to C1HWNCoC0, invalid data type %s", | |||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return UNSUPPORTED; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (!CheckShapeValid(args.src_shape, kHwcnDimsNum)) { | |||
| GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| if (!CheckShapeValid(args.dst_shape, kC1hwncoc0DimsNum)) { | |||
| GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| std::vector<int64_t> expect_dst_shape; | |||
| auto ret = TransShapeHwcnToC1hwncoc0(args.src_data_type, args.src_shape, expect_dst_shape); | |||
| @@ -77,12 +77,12 @@ Status CheckArgsForHwcnToC1hwncoc0(const TransArgs &args) { | |||
| return ret; | |||
| } | |||
| if (args.dst_shape != expect_dst_shape) { | |||
| GELOGE(PARAM_INVALID, | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||
| "Failed to trans format, src and dst shape are not compatible. src shape %s, dst shape %s, " | |||
| "expect dst shape %s", | |||
| ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(), | |||
| ShapeToString(expect_dst_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| @@ -91,10 +91,10 @@ Status CheckArgsForHwcnToC1hwncoc0(const TransArgs &args) { | |||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| auto h = args.src_shape.at(kHwcnH); | |||
| @@ -135,22 +135,22 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||
| static_cast<size_t>(size)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||
| "Failed to copy data from HWCN[%ld, %ld, %ld, %ld] offset %ld to " | |||
| "C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||
| h_idx, w_idx, c_idx, n_idx, src_offset, c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, | |||
| dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } else { | |||
| auto ret = | |||
| memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||
| "Failed to set to 0 to C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld, " | |||
| "err-code %d", | |||
| c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| } | |||
| @@ -166,8 +166,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
| } // namespace | |||
| Status FormatTransferHwcnC1hwncoc0::TransFormat(const TransArgs &args, TransResult &result) { | |||
| if (CheckArgsForHwcnToC1hwncoc0(args) != SUCCESS) { | |||
| return PARAM_INVALID; | |||
| Status ret = CheckArgsForHwcnToC1hwncoc0(args); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| } | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||
| @@ -178,18 +179,20 @@ Status FormatTransferHwcnC1hwncoc0::TransFormat(const TransArgs &args, TransResu | |||
| return SUCCESS; | |||
| } | |||
| GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| GELOGD("Begin to trans format from HWCN to C1HWNCoC0, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ret = GetDstDataAfterTrans(args, result, size, total_size); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| return INTERNAL_ERROR; | |||
| return ret; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -198,15 +201,15 @@ Status FormatTransferHwcnC1hwncoc0::TransShape(Format src_format, const std::vec | |||
| DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | |||
| if (src_format == FORMAT_HWCN && CheckDataTypeSupported(data_type)) { | |||
| if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check src shape %s", | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", | |||
| ShapeToString(src_shape).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return TransShapeHwcnToC1hwncoc0(data_type, src_shape, dst_shape); | |||
| } else if (src_format != FORMAT_HWCN) { | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } else { | |||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| } | |||
| @@ -37,33 +37,33 @@ Status CheckArgsForNc1hwc0ToNhwc(const TransArgs &args) { | |||
| std::string error = "Dose not support trans format from " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||
| return UNSUPPORTED; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| if (!CheckDataTypeSupported(args.src_data_type)) { | |||
| GELOGE(UNSUPPORTED, "Failed to trans shape from NC1HWC0 to NHWC, invalid data type %s", | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to trans shape from NC1HWC0 to NHWC, invalid data type %s", | |||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return UNSUPPORTED; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (!CheckShapeValid(args.src_shape, kNc1hwc0DimsNum)) { | |||
| GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| if (!CheckShapeValid(args.dst_shape, kNhwcDimsNum)) { | |||
| GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||
| if (c0 <= 0) { | |||
| GELOGE(PARAM_INVALID, "Failed to get cube size, the data type is invalid"); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to get cube size, the data type is invalid"); | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNhwcH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNhwcW) || | |||
| src_shape.at(kNc1hwc0N) != dst_shape.at(kNhwcN) || src_shape.at(kNc1hwc0C0) != c0 || | |||
| src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNhwcC), c0))) { | |||
| GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||
| ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| @@ -72,10 +72,10 @@ Status CheckArgsForNc1hwc0ToNhwc(const TransArgs &args) { | |||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| auto h = args.src_shape.at(kNc1hwc0H); | |||
| @@ -109,11 +109,11 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||
| static_cast<size_t>(size)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||
| "Failed to copy data from NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld to NHWC[%ld, %ld, %ld, %ld]" | |||
| " offset %ld, err-code %d", | |||
| n_idx, c1_idx, h_idx, w_idx, c0_idx, src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| } | |||
| @@ -126,8 +126,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
| } // namespace | |||
| Status FormatTransferNc1hwc0Nhwc::TransFormat(const TransArgs &args, TransResult &result) { | |||
| if (CheckArgsForNc1hwc0ToNhwc(args) != SUCCESS) { | |||
| return PARAM_INVALID; | |||
| Status ret = CheckArgsForNc1hwc0ToNhwc(args); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| } | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||
| @@ -138,18 +139,20 @@ Status FormatTransferNc1hwc0Nhwc::TransFormat(const TransArgs &args, TransResult | |||
| return SUCCESS; | |||
| } | |||
| GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| GELOGD("Begin to trans format from NC1HWC0 to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ret = GetDstDataAfterTrans(args, result, size, total_size); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| return INTERNAL_ERROR; | |||
| return ret; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -157,7 +160,7 @@ Status FormatTransferNc1hwc0Nhwc::TransFormat(const TransArgs &args, TransResult | |||
| Status FormatTransferNc1hwc0Nhwc::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | |||
| DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | |||
| GELOGD("The shape derivation from NC1HWC0 to NHWC is not unique. Trans shape in this direction is not supported"); | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferNc1hwc0Nhwc, FORMAT_NC1HWC0, FORMAT_NHWC) | |||
| @@ -45,7 +45,7 @@ Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_ | |||
| Status TransShape(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape) { | |||
| auto c0 = GetCubeSizeByDataType(data_type); | |||
| if (c0 < 0) { | |||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| auto chw = c * h * w; | |||
| @@ -59,9 +59,9 @@ Status TransShape(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type | |||
| dst_shape.push_back(c0); | |||
| if (!IsShapeValid(dst_shape)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||
| ShapeToString(dst_shape).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -69,7 +69,7 @@ Status TransShape(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type | |||
| Status TransShapeNchwToFzC04(const std::vector<int64_t> &src_shape, DataType data_type, | |||
| std::vector<int64_t> &dst_shape) { | |||
| if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| auto n = src_shape.at(kNchwN); | |||
| @@ -94,8 +94,8 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||
| std::vector<int64_t> expect_shape = {n, h, w, c}; | |||
| auto ret = ge::formats::Transpose(data, args.src_shape, args.src_data_type, perm_arg_1, trans_result_1); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to Transpose from NCHW to HWCN"); | |||
| return NOT_CHANGED; | |||
| GELOGE(ret, "Failed to Transpose from NCHW to HWCN"); | |||
| return ret; | |||
| } | |||
| TransArgs args_tmp = args; | |||
| @@ -104,8 +104,8 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||
| // check size it should be same with original | |||
| size_t expect_size = n * c * h * w * size; // before has do check about mul | |||
| if (trans_result_1.length != expect_size) { | |||
| GELOGE(INTERNAL_ERROR, "size is not match after transpose!"); | |||
| return NOT_CHANGED; | |||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "size is not match after transpose!"); | |||
| return ACL_ERROR_GE_PARAM_INVALID; | |||
| } | |||
| // prepare for padding in chw | |||
| @@ -118,20 +118,20 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||
| // data overflow check totally | |||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(h_o, w_o), | |||
| GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", h_o, w_o); | |||
| return INTERNAL_ERROR); | |||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", h_o, w_o); | |||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(n_o, c_o), | |||
| GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", n_o, c_o); | |||
| return INTERNAL_ERROR); | |||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", n_o, c_o); | |||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||
| auto t1 = h_o * w_o; | |||
| auto t2 = n_o * c_o; | |||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", t1, t2); | |||
| return INTERNAL_ERROR); | |||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||
| int64_t total_ele_cnt = n_o * c_o * h_o * w_o; | |||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(total_ele_cnt, size), | |||
| GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%d]", total_ele_cnt, size); | |||
| return INTERNAL_ERROR); | |||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%d]", total_ele_cnt, size); | |||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||
| int64_t dst_size = total_ele_cnt * size; | |||
| if (dst_size == 0) { | |||
| result.length = 0; | |||
| @@ -140,15 +140,15 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| auto retMem = memset_s(dst.get(), dst_size, 0, dst_size); | |||
| if (retMem != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "memst failed!"); | |||
| return INTERNAL_ERROR; | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "memst failed!"); | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| // copy data | |||
| auto block = c * h * w * size; | |||
| @@ -159,8 +159,8 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||
| for (auto k = 0; k < n; k++) { | |||
| ret = memcpy_s(p_d + k * stride, protectSize, p_s + k * block, block); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "memcpy_s failed!"); | |||
| return INTERNAL_ERROR; | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "memcpy_s failed!"); | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| protectSize = protectSize - block; | |||
| } | |||
| @@ -169,8 +169,8 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||
| std::vector<int64_t> perm_arg_2 = {2, 0, 1, 3}; | |||
| ret = ge::formats::Transpose(dst.get(), shape_o, args.src_data_type, perm_arg_2, result); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to Transpose from NCHW to HWCN"); | |||
| return NOT_CHANGED; | |||
| GELOGE(ret, "Failed to Transpose from NCHW to HWCN"); | |||
| return ret; | |||
| } | |||
| return SUCCESS; | |||
| @@ -180,7 +180,7 @@ Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr<uin | |||
| args_tmp = args; | |||
| auto src_shape = args_tmp.src_shape; | |||
| if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||
| @@ -190,8 +190,8 @@ Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr<uin | |||
| auto w = src_shape.at(kNchwW); | |||
| if (c > kMaxDimsNumC) { | |||
| GELOGE(PARAM_INVALID, "Invalie dim c num[%lu].It should be in (0,4]", c); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Invalie dim c num[%lu].It should be in (0,4]", c); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| auto n_o = Ceil(n, c0) * c0; | |||
| @@ -205,21 +205,21 @@ Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr<uin | |||
| // data overflow check | |||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(h_o, w_o), | |||
| GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", h_o, w_o); | |||
| return INTERNAL_ERROR); | |||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", h_o, w_o); | |||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(n_o, c_o), | |||
| GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", n_o, c_o); | |||
| return INTERNAL_ERROR); | |||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", n_o, c_o); | |||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||
| auto t1 = h_o * w_o; | |||
| auto t2 = n_o * c_o; | |||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", t1, t2); | |||
| return INTERNAL_ERROR); | |||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", t1, t2); | |||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||
| int64_t total_ele_cnt = n_o * c_o * h_o * w_o; | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(total_ele_cnt, size), | |||
| GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%d]", total_ele_cnt, size); | |||
| return INTERNAL_ERROR); | |||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%d]", total_ele_cnt, size); | |||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||
| int64_t dst_size = total_ele_cnt * size; | |||
| if (dst_size == 0) { | |||
| @@ -228,15 +228,15 @@ Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr<uin | |||
| dst.reset(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| auto ret = memset_s(dst.get(), dst_size, 0, dst_size); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "memst failed!"); | |||
| return INTERNAL_ERROR; | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "memst failed!"); | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| auto p_s = args.data; | |||
| @@ -249,8 +249,8 @@ Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr<uin | |||
| ret = memcpy_s(p_d + (i * c_o * h_o * w_o + j * h_o * w_o) * size, protectSize, | |||
| p_s + (i * c * h * w + j * h * w) * size, block); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "memcpy_s failed!"); | |||
| return INTERNAL_ERROR; | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "memcpy_s failed!"); | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| protectSize = protectSize - block; | |||
| } | |||
| @@ -270,7 +270,7 @@ Status FormatTransferNchwToFZC04::TransFormat(const TransArgs &args, TransResult | |||
| std::shared_ptr<uint8_t> dst = nullptr; | |||
| auto ret = PaddingNC(args, args_tmp, dst); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Padding in NC axis failed!"); | |||
| GELOGE(ret, "Padding in NC axis failed!"); | |||
| return ret; | |||
| } | |||
| @@ -281,26 +281,26 @@ Status FormatTransferNchwToFZC04::TransFormat(const TransArgs &args, TransResult | |||
| } | |||
| if (!IsTransShapeDstCorrect(args_tmp, expect_shape)) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| if (args_tmp.src_format == FORMAT_NCHW && args_tmp.dst_format == FORMAT_FRACTAL_Z_C04) { | |||
| return TransFormatFromNchwToFzC04(args_tmp, result); | |||
| } | |||
| return UNSUPPORTED; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| Status FormatTransferNchwToFZC04::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | |||
| DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | |||
| if (CheckDataTypeSupport(data_type) != SUCCESS) { | |||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z_C04) { | |||
| return TransShapeNchwToFzC04(src_shape, data_type, dst_shape); | |||
| } | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferNchwToFZC04, FORMAT_NCHW, FORMAT_FRACTAL_Z_C04) | |||
| @@ -32,13 +32,13 @@ Status TransShapeNchwToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||
| std::vector<int64_t> &dst_shape) { | |||
| int64_t c0 = GetCubeSizeByDataType(data_type); | |||
| if (c0 <= 0) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID, "Failed to get cube size, the data type is invalid"); | |||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to get cube size, the data type is invalid"); | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check src shape %s", | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", | |||
| ShapeToString(src_shape).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| dst_shape.clear(); | |||
| dst_shape.push_back(src_shape.at(kNchwN)); | |||
| @@ -47,9 +47,9 @@ Status TransShapeNchwToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||
| dst_shape.push_back(src_shape.at(kNchwW)); | |||
| dst_shape.push_back(c0); | |||
| if (!CheckShapeValid(dst_shape, kNc1hwc0DimsNum)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||
| ShapeToString(dst_shape).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -59,8 +59,8 @@ Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { | |||
| std::string error = "Dose not support trans format from " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||
| return UNSUPPORTED; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| std::vector<int64_t> expect_5d_shape; | |||
| auto ret = TransShapeNchwToNc1hwc0(args.src_shape, args.src_data_type, expect_5d_shape); | |||
| @@ -68,12 +68,12 @@ Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { | |||
| return ret; | |||
| } | |||
| if (expect_5d_shape != args.dst_shape) { | |||
| GELOGE(PARAM_INVALID, | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||
| "Failed to trans format, the src and dst shape are not compatible. data" | |||
| " type %s, src shape %s, dst shape %s, expect dst shape %s", | |||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.src_shape).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(expect_5d_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| @@ -82,12 +82,12 @@ Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { | |||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, | |||
| "Failed to trans format from %s to %s, can not alloc the memory for" | |||
| " dst buf %ld, shape %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| auto n = args.src_shape.at(kNchwN); | |||
| @@ -97,8 +97,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||
| if (c0 <= 0) { | |||
| GELOGE(INTERNAL_ERROR, "The c0 is invalid %ld", c0); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "The c0 is invalid %ld", c0); | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| int64_t c1 = (c - 1) / c0 + 1; | |||
| int64_t hw = h * w; | |||
| @@ -129,21 +129,21 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||
| static_cast<size_t>(size)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||
| "Failed to copy data from NCHW[%ld] offset %ld to " | |||
| "NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||
| srcIdx, src_offset, n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } else { | |||
| auto ret = | |||
| memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||
| "Failed to set to 0 to " | |||
| "NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||
| n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| } | |||
| @@ -159,8 +159,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
| } // namespace | |||
| Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { | |||
| if (CheckArgsForNchwToNc1hwc0(args) != SUCCESS) { | |||
| return PARAM_INVALID; | |||
| Status ret = CheckArgsForNchwToNc1hwc0(args); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| } | |||
| // Guarantee the validity of parameters in check function | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| @@ -172,20 +173,21 @@ Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult | |||
| return SUCCESS; | |||
| } | |||
| GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| GELOGD( | |||
| "Begin to trans format from NCHW to NC1HWC0, src shape %s, data type " | |||
| "%s, dst shape %s memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ret = GetDstDataAfterTrans(args, result, size, total_size); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| return INTERNAL_ERROR; | |||
| return ret; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -195,7 +197,7 @@ Status FormatTransferNchwNc1hwc0::TransShape(Format src_format, const std::vecto | |||
| if (src_format == FORMAT_NCHW) { | |||
| return TransShapeNchwToNc1hwc0(src_shape, data_type, dst_shape); | |||
| } else { | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| } | |||
| @@ -34,8 +34,8 @@ Status TransShapeNhwcToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||
| std::vector<int64_t> &dst_shape) { | |||
| int64_t c0 = GetCubeSizeByDataType(data_type); | |||
| if (c0 <= 0) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID, "Failed to get cube size, the data type is invalid"); | |||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to get cube size, the data type is invalid"); | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| dst_shape.clear(); | |||
| dst_shape.push_back(src_shape.at(kNhwcN)); | |||
| @@ -44,9 +44,9 @@ Status TransShapeNhwcToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||
| dst_shape.push_back(src_shape.at(kNhwcW)); | |||
| dst_shape.push_back(c0); | |||
| if (!CheckShapeValid(dst_shape, kNc1hwc0DimsNum)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||
| ShapeToString(dst_shape).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -56,21 +56,21 @@ Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { | |||
| std::string error = "Dose not support trans format from " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||
| return UNSUPPORTED; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| if (!CheckDataTypeSupported(args.src_data_type)) { | |||
| GELOGE(UNSUPPORTED, "Failed to trans shape from NHWC to NC1HWC0, invalid data type %s", | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to trans shape from NHWC to NC1HWC0, invalid data type %s", | |||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||
| return UNSUPPORTED; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (!CheckShapeValid(args.src_shape, kNhwcDimsNum)) { | |||
| GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| if (!CheckShapeValid(args.dst_shape, kNc1hwc0DimsNum)) { | |||
| GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| std::vector<int64_t> expect_dst_shape; | |||
| auto ret = TransShapeNhwcToNc1hwc0(args.src_shape, args.src_data_type, expect_dst_shape); | |||
| @@ -78,12 +78,12 @@ Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { | |||
| return ret; | |||
| } | |||
| if (args.dst_shape != expect_dst_shape) { | |||
| GELOGE(PARAM_INVALID, | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||
| "Failed to trans format, the src and dst shape are not compatible. src shape %s, dst shape %s, " | |||
| "expect dst shape %s", | |||
| ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(), | |||
| ShapeToString(expect_dst_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| @@ -92,10 +92,10 @@ Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { | |||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | |||
| return OUT_OF_MEMORY; | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| } | |||
| auto n = args.src_shape.at(kNhwcN); | |||
| @@ -131,19 +131,19 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
| if (c_idx < c) { | |||
| auto ret = memcpy_s(dst.get() + dst_offset, protected_size, args.data + src_offset, size); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||
| "Failed to copy data from NHWC[%ld, %ld, %ld, %ld] offset %ld to " | |||
| "NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld err-code %d", | |||
| n_idx, h_idx, w_idx, c_idx, src_offset, n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } else { | |||
| auto ret = memset_s(dst.get() + dst_offset, protected_size, 0, size); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||
| "Failed to set 0 to NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld base err-code %d", n_idx, c1_idx, | |||
| h_idx, w_idx, c0_idx, dst_offset, ret); | |||
| return INTERNAL_ERROR; | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| } | |||
| } | |||
| @@ -158,8 +158,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||
| } // namespace | |||
| Status FormatTransferNhwcNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { | |||
| if (CheckArgsForNhwcToNc1hwc0(args) != SUCCESS) { | |||
| return PARAM_INVALID; | |||
| Status ret = CheckArgsForNhwcToNc1hwc0(args); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| } | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||
| @@ -170,18 +171,20 @@ Status FormatTransferNhwcNc1hwc0::TransFormat(const TransArgs &args, TransResult | |||
| return SUCCESS; | |||
| } | |||
| GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| GELOGD("Begin to trans format from NHWC to NC1HWC0, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ret = GetDstDataAfterTrans(args, result, size, total_size); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| return INTERNAL_ERROR; | |||
| return ret; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -190,15 +193,15 @@ Status FormatTransferNhwcNc1hwc0::TransShape(Format src_format, const std::vecto | |||
| DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | |||
| if (src_format == FORMAT_NHWC && CheckDataTypeSupported(data_type)) { | |||
| if (!CheckShapeValid(src_shape, kNhwcDimsNum)) { | |||
| GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check src shape %s", | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", | |||
| ShapeToString(src_shape).c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return TransShapeNhwcToNc1hwc0(src_shape, data_type, dst_shape); | |||
| } else if (src_format != FORMAT_NHWC) { | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } else { | |||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| } | |||
| @@ -141,7 +141,7 @@ std::vector<int64_t> TransShapeByPerm(const std::vector<int64_t> &src_shape, con | |||
| Status Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type, | |||
| const std::vector<int64_t> &perm_arg, TransResult &result) { | |||
| if (!IsTransposeArgValid(src, src_shape, src_data_type, perm_arg)) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_PARAM_INVALID; | |||
| } | |||
| auto dst_shape = TransShapeByPerm(src_shape, perm_arg); | |||
| @@ -172,12 +172,12 @@ Status Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, Data | |||
| auto ret = memcpy_s(dst.get() + dst_offset_bytes, static_cast<size_t>(protected_size), src + src_offset, | |||
| static_cast<size_t>(data_size)); | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, | |||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||
| "Failed to transpose, src shape %s, perm arg %s, dst shape %s, " | |||
| "failed to write to dst offset %ld, current dim offset %s", | |||
| ShapeToString(src_shape).c_str(), ShapeToString(perm_arg).c_str(), ShapeToString(dst_shape).c_str(), | |||
| dst_offset_bytes, ShapeToString(dst_indexes).c_str()); | |||
| return INTERNAL_ERROR; | |||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||
| } | |||
| AddOne(dst_shape, dst_indexes); | |||
| ++dst_index; | |||
| @@ -192,14 +192,14 @@ Status TransposeWithShapeCheck(const uint8_t *data, const std::vector<int64_t> & | |||
| const std::vector<int64_t> &dst_shape, DataType src_data_type, | |||
| const std::vector<int64_t> &perm_arg, TransResult &result) { | |||
| if (!IsTransposeArgValid(data, src_shape, src_data_type, perm_arg)) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_PARAM_INVALID; | |||
| } | |||
| auto expected_shape = TransShapeByPerm(src_shape, perm_arg); | |||
| if (dst_shape != expected_shape) { | |||
| std::string error = "Failed to trans axis for perm_arg" + | |||
| FmtToStr(ShapeToString(perm_arg)) + ", invalid dst shape" + | |||
| FmtToStr(ShapeToString(dst_shape)) + ", expect" + FmtToStr(ShapeToString(expected_shape)); | |||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||
| } | |||
| return Transpose(data, src_shape, src_data_type, perm_arg, result); | |||
| @@ -211,16 +211,16 @@ Status GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t | |||
| std::string error = "Failed to trans shape, do not support transpose from format " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(src_format)) + " to " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(dst_format)); | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| auto iter = dst_iter->second.find(dst_format); | |||
| if (iter == dst_iter->second.end()) { | |||
| std::string error = "Failed to trans shape, do not support transpose from format " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(src_format)) + " to " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(dst_format)); | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| perm = iter->second; | |||
| return SUCCESS; | |||
| @@ -233,7 +233,7 @@ Status FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult & | |||
| return ret; | |||
| } | |||
| if (!IsTransShapeDstCorrect(args, expected_shape)) { | |||
| return PARAM_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| return Transpose(args.data, args.src_shape, args.src_data_type, perm_args[args.src_format][args.dst_format], result); | |||
| @@ -244,7 +244,7 @@ Status FormatTransferTranspose::TransShape(Format src_format, const std::vector< | |||
| std::vector<int64_t> perm_arg; | |||
| GE_CHK_STATUS_RET_NOLOG(GetPermByForamt(src_format, dst_format, perm_arg)); | |||
| if (!IsShapeArgValid(src_shape, perm_arg)) { | |||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||
| } | |||
| dst_shape = TransShapeByPerm(src_shape, perm_arg); | |||
| return SUCCESS; | |||
| @@ -38,14 +38,14 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArg | |||
| std::string error = "Failed to trans data from format " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||
| return UNSUPPORTED; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| auto src_shape_size = GetItemNumByShape(args.src_shape); | |||
| if (args.data == nullptr && src_shape_size != 0) { | |||
| GELOGE(PARAM_INVALID, "Invalid input null data"); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Invalid input null data"); | |||
| return ACL_ERROR_GE_PARAM_INVALID; | |||
| } | |||
| return transfer->TransFormat(args, result); | |||
| @@ -64,8 +64,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransShape(Format src_form | |||
| std::string error = "Failed to trans data from format " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||
| } | |||
| return transfer->TransShape(src_format, src_shape, data_type, dst_format, dst_shape); | |||
| @@ -77,13 +77,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastAr | |||
| std::string error = "Failed to trans data from datatype " + | |||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | |||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)); | |||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||
| return UNSUPPORTED; | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_DATATYPE_INVALID, error.c_str()); | |||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||
| } | |||
| if (args.data == nullptr && args.src_data_size != 0) { | |||
| GELOGE(PARAM_INVALID, "Invalid input null data"); | |||
| return PARAM_INVALID; | |||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Invalid input null data"); | |||
| return ACL_ERROR_GE_PARAM_INVALID; | |||
| } | |||
| return transfer->TransDataType(args, result); | |||
| @@ -217,7 +217,7 @@ std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) { | |||
| std::string unsupported_reason; | |||
| // It will be replaced by engine' checksupport | |||
| uint64_t start_time = GetCurrentTimestamp(); | |||
| if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { | |||
| if (kernel_info_store->second->CheckSupported(node_ptr, unsupported_reason)) { | |||
| checksupport_cost_[kernel_name] += GetCurrentTimestamp() - start_time; | |||
| op_desc->SetOpEngineName(it.engine); | |||
| op_desc->SetOpKernelLibName(kernel_name); | |||
| @@ -47,7 +47,7 @@ void GetGeTensorDescFromDomiInfo(std::vector<ge::TensorDesc> &ge_descs, | |||
| uint32_t idx = 0; | |||
| for (auto desc_item : domi_descs) { | |||
| ge::TensorDesc ge_desc; | |||
| ge_desc.SetName(desc_item.name); | |||
| ge_desc.SetName(desc_item.name.c_str()); | |||
| ge_desc.SetDataType(static_cast<ge::DataType>(desc_item.data_type)); | |||
| ge_desc.SetFormat(static_cast<ge::Format>(formats[idx])); | |||
| std::vector<int64_t> shape_dims; | |||
| @@ -66,7 +66,8 @@ bool ContainsDynamicInpus(const ge::OpDesc &op_desc) { | |||
| } // namespace | |||
| namespace ge { | |||
| static Status CheckEngineTypeSupport(const OpDescPtr &op_desc, OpEngineType engine_type) { | |||
| static Status CheckEngineTypeSupport(const NodePtr &node, OpEngineType engine_type) { | |||
| const OpDescPtr &op_desc = node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); | |||
| if (engine_type == ENGINE_SYS) { | |||
| GELOGI("CheckEngineType: use default engine."); | |||
| @@ -123,7 +124,7 @@ static Status CheckEngineTypeSupport(const OpDescPtr &op_desc, OpEngineType engi | |||
| auto kernel_info_store = kernel_map.find(kernel_name); | |||
| if (kernel_info_store != kernel_map.end()) { | |||
| std::string unsupported_reason; | |||
| if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { | |||
| if (kernel_info_store->second->CheckSupported(node, unsupported_reason)) { | |||
| op_desc->SetOpEngineName(op_engine_name); | |||
| op_desc->SetOpKernelLibName(kernel_name); | |||
| GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), | |||
| @@ -697,22 +698,23 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in | |||
| OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc); | |||
| GE_CHECK_NOTNULL(op_desc_tmp); | |||
| // 1. check engine type when compile online | |||
| // 1. Create ComputeGraph. | |||
| string name = ge::CurrentTimeInStr() + "_" + model_file_name; | |||
| Graph graph; | |||
| GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), "make graph fail."); | |||
| // 2. check engine type when compile online | |||
| if (model_file_name == kFileNameSuffix) { | |||
| Status ret = CheckEngineTypeSupport(op_desc, engine_type); | |||
| auto comp_graph = GraphUtils::GetComputeGraph(graph); | |||
| GE_CHECK_NOTNULL(comp_graph); | |||
| auto node = comp_graph->FindNode(op_desc->GetName()); | |||
| Status ret = CheckEngineTypeSupport(node, engine_type); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "check engine type failed."); | |||
| return ret; | |||
| } | |||
| } | |||
| // 2. Create ComputeGraph. | |||
| string name = ge::CurrentTimeInStr() + "_" + model_file_name; | |||
| Graph graph; | |||
| if (BuildSingleOpGraph(op_desc, inputs, outputs, name, graph) != ge::SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "make graph fail."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GELOGI("ATC parser success in single op build."); | |||
| GeRootModelPtr ge_root_model = nullptr; | |||
| @@ -131,7 +131,7 @@ class GraphMemoryAssigner { | |||
| std::map<NodePtr, uint32_t> &node_2_continuous_type); | |||
| ge::Status AssignContinuousInputMemoryWithAtomicProcess(const NodePtr &input_continuous_node, | |||
| uint32_t continuous_type, bool reverse_refresh=false); | |||
| uint32_t continuous_type, bool reverse_refresh = false); | |||
| ge::Status FilterAtomicNodesForMemoryAssign(map<string, map<NodePtr, vector<NodePtr>>> &normal_atomic_nodes_map, | |||
| map<string, vector<NodePtr>> &connecting_output_atomic_nodes); | |||
| @@ -261,7 +261,9 @@ Status ModelBuilder::SetInputOutputDesc() { | |||
| GE_IF_BOOL_EXEC(n->GetInAllNodes().empty() && n->GetOutAllNodes().empty(), continue;); | |||
| SetInputIsConst(n); | |||
| if (IsGeLocalOp(n->GetOpDesc())) { | |||
| bool is_unknow = false; | |||
| (void)NodeUtils::GetNodeUnknownShapeStatus(*n, is_unknow); | |||
| if ((IsGeLocalOp(n->GetOpDesc())) && (!is_unknow)) { | |||
| GE_CHK_STATUS_RET(CalcOutputSize(n), "Calculate output size failed"); | |||
| } | |||
| ret = AdjustConstWeightSize(n, weight_offset_); | |||
| @@ -124,7 +124,7 @@ inline bool IsDataOp(const std::string &node_type) { | |||
| return (node_type == DATA_TYPE) || (node_type == AIPP_DATA_TYPE) || (node_type == ANN_DATA_TYPE); | |||
| } | |||
| inline bool IsTbeTask(const OpDescPtr &op_desc) { | |||
| bool IsTbeTask(const OpDescPtr &op_desc) { | |||
| uint32_t run_mode = static_cast<uint32_t>(domi::ImplyType::INVALID); | |||
| if (!AttrUtils::GetInt(op_desc, ATTR_NAME_IMPLY_TYPE, run_mode)) { | |||
| return false; | |||
| @@ -527,25 +527,25 @@ Status DavinciModel::DoTaskSink() { | |||
| } | |||
| GE_CHK_RT_RET(rtGetAicpuDeploy(&deploy_type_)); | |||
| GELOGI("do task_sink. AiCpu deploy type is: %x.", deploy_type_); | |||
| GELOGI("Do task_sink. AiCpu deploy type is: %x.", deploy_type_); | |||
| GE_CHK_STATUS_RET(BindModelStream(), "Bind model stream failed"); | |||
| GE_CHK_STATUS_RET(BindModelStream(), "Bind model stream failed."); | |||
| if (known_node_) { | |||
| GE_CHK_STATUS_RET(MallocKnownArgs(), "Mallloc known node args failed"); | |||
| GE_CHK_STATUS_RET(MallocKnownArgs(), "Mallloc known node args failed."); | |||
| } | |||
| GE_CHK_STATUS_RET(InitTaskInfo(*model_task_def.get()), "InitTaskInfo failed"); | |||
| GE_CHK_STATUS_RET(InitTaskInfo(*model_task_def.get()), "InitTaskInfo failed."); | |||
| GE_CHK_STATUS_RET(ModelManager::GetInstance()->LaunchCustAicpuSo(), "Launch cust aicpu so failed"); | |||
| GE_CHK_STATUS_RET(ModelManager::GetInstance()->LaunchCustAicpuSo(), "Launch cust aicpu so failed."); | |||
| GE_CHK_STATUS_RET(ModelManager::GetInstance()->CheckAicpuOpList(ge_model_), "Check aicpu op type failed"); | |||
| GE_CHK_STATUS_RET(ModelManager::GetInstance()->CheckAicpuOpList(ge_model_), "Check aicpu op type failed."); | |||
| GE_CHK_STATUS_RET(InitEntryTask(), "InitEntryTask failed"); | |||
| GE_CHK_STATUS_RET(InitEntryTask(), "InitEntryTask failed."); | |||
| GE_CHK_STATUS_RET(InitL1DataDumperArgs(), "InitL1DataDumperArgs failed"); | |||
| GE_CHK_STATUS_RET(InitL1DataDumperArgs(), "InitL1DataDumperArgs failed."); | |||
| GE_CHK_STATUS_RET(DistributeTask(), "Distribute failed"); | |||
| GE_CHK_STATUS_RET(DistributeTask(), "Distribute failed."); | |||
| GE_CHK_RT_RET(rtModelLoadComplete(rt_model_handle_)); | |||
| @@ -558,7 +558,7 @@ Status DavinciModel::SetTSDevice() { | |||
| int64_t value = 0; | |||
| bool ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_CORE_TYPE, value); | |||
| uint32_t core_type = ret ? static_cast<uint32_t>(value) : 0; | |||
| GELOGD("SetTSDevice: %u", core_type); | |||
| GELOGD("SetTSDevice: %u.", core_type); | |||
| rtError_t rt_ret = rtSetTSDevice(core_type); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "SetTSDevice failed, ret: 0x%X", rt_ret); | |||
| @@ -570,7 +570,7 @@ Status DavinciModel::SetTSDevice() { | |||
| Status DavinciModel::OpDebugRegister() { | |||
| bool is_op_debug = false; | |||
| (void)ge::AttrUtils::GetBool(ge_model_, ATTR_OP_DEBUG_FLAG, is_op_debug); | |||
| GELOGD("The value of op debug in ge_model is %d", is_op_debug); | |||
| GELOGD("The value of op debug in ge_model is %d.", is_op_debug); | |||
| if (is_op_debug) { | |||
| debug_reg_mutex_.lock(); | |||
| rtError_t rt_ret = rtMalloc(&op_debug_addr_, kOpDebugMemorySize, RT_MEMORY_DDR); | |||
| @@ -1214,7 +1214,7 @@ void DavinciModel::GetAllGearsInfo(const NodePtr &node) { | |||
| } | |||
| if (!gear_info.empty()) { | |||
| all_gears_info_.emplace_back(gear_info); | |||
| GELOGD("Init all gears info from %s, gaer info is %s.", node->GetName().c_str(), | |||
| GELOGD("Init all gears info from %s, gaer info is %s", node->GetName().c_str(), | |||
| formats::JoinToString(gear_info).c_str()); | |||
| } | |||
| } | |||
| @@ -1283,7 +1283,7 @@ Status DavinciModel::GetGearAndRealOutSizeInfo(const ComputeGraphPtr &graph, con | |||
| Status DavinciModel::GetRealOutputSizeOfCase(const ComputeGraphPtr &graph, size_t input_index, | |||
| const NodePtr &case_node) { | |||
| GELOGD("Start get output size of %s, which is %zu input to netoutput.", case_node->GetName().c_str(), input_index); | |||
| GELOGD("Start get output size of %s, which is %zu input to netoutput", case_node->GetName().c_str(), input_index); | |||
| const auto &func_desc = case_node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(func_desc); | |||
| std::map<vector<int32_t>, int64_t> gear_and_real_out_size_info; | |||
| @@ -1328,7 +1328,7 @@ Status DavinciModel::GetRealOutputSizeOfCase(const ComputeGraphPtr &graph, size_ | |||
| } | |||
| Status DavinciModel::GetGearAndRealOutShapeInfo(const ComputeGraphPtr &graph, const NodePtr &node) { | |||
| GELOGD("Start to get dynamic output dims of %s.", node->GetName().c_str()); | |||
| GELOGD("Start to get dynamic output dims of %s", node->GetName().c_str()); | |||
| merge_nodes_gear_and_real_out_shape_info_.clear(); | |||
| size_t idx = 0; | |||
| for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||
| @@ -1342,7 +1342,7 @@ Status DavinciModel::GetGearAndRealOutShapeInfo(const ComputeGraphPtr &graph, co | |||
| if ((peer_node->GetType() == CASE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { | |||
| std::vector<std::string> dynamic_output_shape_info; | |||
| if (!AttrUtils::GetListStr(node->GetOpDesc(), ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_shape_info)) { | |||
| GELOGD("Can not get dynamic output dims attr from %s.", node->GetName().c_str()); | |||
| GELOGD("Can not get dynamic output dims attr from %s", node->GetName().c_str()); | |||
| return SUCCESS; | |||
| } | |||
| GELOGI("Dynamic output shape info is %s", formats::JoinToString(dynamic_output_shape_info).c_str()); | |||
| @@ -1362,7 +1362,7 @@ Status DavinciModel::GetGearAndRealOutShapeInfo(const ComputeGraphPtr &graph, co | |||
| output_shape.emplace_back(it[i]); | |||
| } | |||
| gear_and_real_out_shape_info[all_gears_info_[gear_index]] = output_shape; | |||
| GELOGD("Get real gear index is: %zu, gear info is %s, output shape is %s.", | |||
| GELOGD("Get real gear index is: %zu, gear info is %s, output shape is %s", | |||
| gear_index, formats::JoinToString(all_gears_info_[gear_index]).c_str(), | |||
| formats::JoinToString(output_shape).c_str()); | |||
| } | |||
| @@ -1385,7 +1385,7 @@ void DavinciModel::ParseDynamicOutShape(const std::vector<std::string> &str_info | |||
| } | |||
| shape.emplace_back(std::strtol(dim.c_str(), nullptr, kDecimal)); | |||
| } | |||
| GELOGI("Shape from attr is %s.", formats::JoinToString(shape).c_str()); | |||
| GELOGI("Shape from attr is %s", formats::JoinToString(shape).c_str()); | |||
| vec_info.emplace_back(shape); | |||
| } | |||
| } | |||
| @@ -1428,7 +1428,7 @@ Status DavinciModel::InitLabelSet(const OpDescPtr &op_desc) { | |||
| return INTERNAL_ERROR; | |||
| } | |||
| GELOGI("InitLabelSet: label[%u]=%p stream[%u]=%p.", label_index, rt_label, stream_id, stream); | |||
| GELOGI("InitLabelSet: label[%u]=%p stream[%u]=%p", label_index, rt_label, stream_id, stream); | |||
| label_id_indication_.insert(label_index); | |||
| label_list_[label_index] = rt_label; | |||
| return SUCCESS; | |||
| @@ -1831,7 +1831,7 @@ void DavinciModel::GetUserDesignateShapeOrder(std::vector<std::string> &user_inp | |||
| /// | |||
| Status DavinciModel::InitAippInfo(uint32_t index, const OpDescPtr &op_desc) { | |||
| if (!op_desc->HasAttr(ATTR_NAME_AIPP)) { | |||
| GELOGW("There is not AIPP related with index %u.", index); | |||
| GELOGW("There is not AIPP related with index %u", index); | |||
| return SUCCESS; | |||
| } | |||
| @@ -1861,7 +1861,7 @@ Status DavinciModel::InitAippInfo(uint32_t index, const OpDescPtr &op_desc) { | |||
| Status DavinciModel::GetAippInfo(uint32_t index, AippConfigInfo &aipp_info) const { | |||
| const auto it = aipp_info_list_.find(index); | |||
| if (it == aipp_info_list_.end()) { | |||
| GELOGW("there is not AIPP related with index %u.", index); | |||
| GELOGW("there is not AIPP related with index %u", index); | |||
| return ACL_ERROR_GE_AIPP_NOT_EXIST; | |||
| } | |||
| @@ -1871,7 +1871,7 @@ Status DavinciModel::GetAippInfo(uint32_t index, AippConfigInfo &aipp_info) cons | |||
| Status DavinciModel::InitAippType(uint32_t index, const OpDescPtr &op_desc, const map<uint32_t, OpDescPtr> &data_list) { | |||
| if (!op_desc->HasAttr(ATTR_DATA_RELATED_AIPP_MODE)) { | |||
| GELOGW("There is no aipp releated info with index %u.", index); | |||
| GELOGW("There is no aipp releated info with index %u", index); | |||
| return SUCCESS; | |||
| } | |||
| @@ -1916,7 +1916,7 @@ Status DavinciModel::GetAippType(uint32_t index, InputAippType &aipp_type, size_ | |||
| GE_CHK_BOOL_RET_STATUS(index < input_addrs_list_.size(), PARAM_INVALID, "Index %u is invalid", index); | |||
| const auto it = aipp_type_list_.find(index); | |||
| if (it == aipp_type_list_.end()) { | |||
| GELOGW("There is no aipp releated info with index %u.", index); | |||
| GELOGW("There is no aipp releated info with index %u", index); | |||
| aipp_type = DATA_WITHOUT_AIPP; | |||
| aipp_index = 0xFFFFFFFF; | |||
| return SUCCESS; | |||
| @@ -271,7 +271,8 @@ ge::Status ModelManager::SetDynamicSize(uint32_t model_id, const std::vector<uin | |||
| return SUCCESS; | |||
| } | |||
| ge::Status ModelManager::DoLoadHybridModelOnline(uint32_t model_id, const string &model_name, const shared_ptr<ge::GeRootModel> &ge_root_model, | |||
| ge::Status ModelManager::DoLoadHybridModelOnline(uint32_t model_id, const string &model_name, | |||
| const shared_ptr<ge::GeRootModel> &ge_root_model, | |||
| const shared_ptr<ModelListener> &listener) { | |||
| auto hybrid_model = hybrid::HybridDavinciModel::Create(ge_root_model); | |||
| GE_CHECK_NOTNULL(hybrid_model); | |||
| @@ -73,7 +73,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||
| ge::Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeRootModel> &ge_root_model, | |||
| std::shared_ptr<ModelListener> listener); | |||
| ge::Status DoLoadHybridModelOnline(uint32_t model_id, const string &model_name, const shared_ptr<ge::GeRootModel> &ge_root_model, | |||
| ge::Status DoLoadHybridModelOnline(uint32_t model_id, const string &model_name, | |||
| const shared_ptr<ge::GeRootModel> &ge_root_model, | |||
| const std::shared_ptr<ModelListener> &listener); | |||
| /// | |||
| @@ -387,7 +387,7 @@ Status ModelUtils::GetVarAddr(const RuntimeParam &model_param, const ConstOpDesc | |||
| GELOGE(PARAM_INVALID, "rdma var addr is invalid, addr=%p", reinterpret_cast<uint8_t *>(offset)); | |||
| return PARAM_INVALID; | |||
| } | |||
| var_addr = reinterpret_cast<uint8_t *>(offset); | |||
| var_addr = reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(offset)); | |||
| break; | |||
| case RT_MEMORY_HBM: | |||
| VALIDATE_MEM_RANGE(op_desc, model_param.var_size, offset - model_param.logic_var_base); | |||
| @@ -3196,7 +3196,7 @@ Status GraphManager::SaveVariables(const Graph &graph, const std::vector<std::st | |||
| return FAILED; | |||
| } else { | |||
| auto var_tensor = var_results[var_name].GetTensorDesc(); | |||
| var_tensor.SetName(var_name); | |||
| var_tensor.SetName(var_name.c_str()); | |||
| var_results[var_name].SetTensorDesc(var_tensor); | |||
| var_values.emplace_back(var_results[var_name]); | |||
| } | |||
| @@ -3205,7 +3205,7 @@ Status GraphManager::SaveVariables(const Graph &graph, const std::vector<std::st | |||
| for (auto iter = var_results.begin(); iter != var_results.end(); ++iter) { | |||
| string var_name = iter->first; | |||
| auto var_tensor = iter->second.GetTensorDesc(); | |||
| var_tensor.SetName(var_name); | |||
| var_tensor.SetName(var_name.c_str()); | |||
| iter->second.SetTensorDesc(var_tensor); | |||
| var_values.emplace_back(iter->second); | |||
| } | |||
| @@ -601,6 +601,8 @@ std::string Cluster::DebugString() const { | |||
| case KNOWN_SHAPE: | |||
| ss << "KNOW"; | |||
| break; | |||
| default: | |||
| break; | |||
| } | |||
| ss << "[" << id_ << "](size:" << nodes_.size() << ")"; | |||
| ss << "(" << min_ << "," << max_ << ")("; | |||
| @@ -167,7 +167,7 @@ bool CastTranslatePass::IsOpSupportedOptimize(NodePtr &cast_node, NodePtr &trans | |||
| trans_op_outdesc->SetDataType(cast_out_datatype); | |||
| } | |||
| if (!TranslateCheckAccuracySupported(trans_op_desc)) { | |||
| if (!TranslateCheckAccuracySupported(trans_node)) { | |||
| if (is_src_cast) { | |||
| trans_op_desc->MutableInputDesc(0)->SetDataType(trans_in_datatype); | |||
| } else { | |||
| @@ -271,7 +271,8 @@ Status CastTranslatePass::FuseDstNTranslates(NodePtr &node) { | |||
| return SUCCESS; | |||
| } | |||
| bool CastTranslatePass::TranslateCheckAccuracySupported(const OpDescPtr &op_desc) { | |||
| bool CastTranslatePass::TranslateCheckAccuracySupported(NodePtr &node) { | |||
| const OpDescPtr &op_desc = node->GetOpDesc(); | |||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
| if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | |||
| GELOGW("GE is not initialized or is finalized."); | |||
| @@ -293,7 +294,7 @@ bool CastTranslatePass::TranslateCheckAccuracySupported(const OpDescPtr &op_desc | |||
| auto kernel_info_store = kernel_map.find(kernel_name); | |||
| if (kernel_info_store != kernel_map.end()) { | |||
| if (kernel_info_store->second != nullptr && | |||
| kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason)) { | |||
| kernel_info_store->second->CheckAccuracySupported(node, unsupported_reason)) { | |||
| return true; | |||
| } | |||
| } | |||
| @@ -35,7 +35,7 @@ class CastTranslatePass : public BaseNodePass { | |||
| bool IsOpSupportedOptimize(NodePtr &cast_node, NodePtr &trans_node, bool &is_src_cast); | |||
| bool CheckOpSupportOptimize(NodePtr &node, bool &is_src_cast); | |||
| Status FuseDstNTranslates(NodePtr &node); | |||
| bool TranslateCheckAccuracySupported(const OpDescPtr &op_desc); | |||
| bool TranslateCheckAccuracySupported(NodePtr &node); | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_PASSES_CAST_TRANSLATE_PASS_H_ | |||
| @@ -110,7 +110,7 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: | |||
| return ge::GE_GRAPH_PARAM_NULLPTR; | |||
| } | |||
| // begin accuracy supported check | |||
| if (!CheckAccuracySupport(kernel_info, instance, op_desc)) { | |||
| if (!CheckAccuracySupport(kernel_info, instance, node)) { | |||
| // if check accuracy support failed , try to go to other engine. | |||
| GELOGD("Check Accuracy Supported return not support, node name is %s. Try to go to other engine.", | |||
| op_desc->GetName().c_str()); | |||
| @@ -123,7 +123,7 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: | |||
| continue; | |||
| } | |||
| OpsKernelInfoStorePtr tmp_kernel_info = it->second; | |||
| if (CheckAccuracySupport(tmp_kernel_info, instance, op_desc)) { | |||
| if (CheckAccuracySupport(tmp_kernel_info, instance, node)) { | |||
| kernel_lib_name = tmp_kernel_name; | |||
| GELOGD("Find kernel lib %s support node:%s, type:%s , get kernel lib success.", tmp_kernel_name.c_str(), | |||
| node->GetName().c_str(), op_desc->GetType().c_str()); | |||
| @@ -138,14 +138,9 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: | |||
| } | |||
| bool CompileNodesPass::CheckAccuracySupport(const OpsKernelInfoStorePtr &kernel_info, | |||
| const std::shared_ptr<GELib> instance, OpDescPtr &op_desc) { | |||
| auto ge_desc = MakeShared<ge::OpDescPtr>(op_desc); | |||
| if (ge_desc == nullptr) { | |||
| GELOGE(GE_GRAPH_MEMORY_ALLOC_FAILED, "Fail to malloc op desc."); | |||
| return false; | |||
| } | |||
| const std::shared_ptr<GELib> instance, const NodePtr &node) { | |||
| string reason; | |||
| if (!(kernel_info->CheckAccuracySupported(*ge_desc, reason, true))) { | |||
| if (!(kernel_info->CheckAccuracySupported(node, reason, true))) { | |||
| return false; | |||
| } | |||
| return true; | |||
| @@ -39,7 +39,7 @@ class CompileNodesPass : public GraphPass { | |||
| private: | |||
| graphStatus GetSupportedKernel(const NodePtr &node, const std::shared_ptr<GELib> instance, string &kernel_lib_name); | |||
| bool CheckAccuracySupport(const OpsKernelInfoStorePtr &kernel_info, const std::shared_ptr<GELib> instance, | |||
| OpDescPtr &op_desc); | |||
| const NodePtr &node); | |||
| graphStatus CompileNodes(const std::shared_ptr<GELib> instance, | |||
| std::unordered_map<string, vector<NodePtr>> &kernel_to_compile_nodes); | |||
| }; | |||
| @@ -29,13 +29,13 @@ const int kRemoveInputIndex = 1; | |||
| Status DimensionAdjustPass::Run(ge::NodePtr &node) { | |||
| if (node == nullptr) { | |||
| GELOGE(PARAM_INVALID, "node is nullptr"); | |||
| GELOGE(PARAM_INVALID, "node is nullptr."); | |||
| return PARAM_INVALID; | |||
| } | |||
| OpDescPtr op_desc_ptr = node->GetOpDesc(); | |||
| if (op_desc_ptr == nullptr) { | |||
| GELOGE(PARAM_INVALID, "GetOpDesc return nullptr"); | |||
| GELOGE(PARAM_INVALID, "GetOpDesc return nullptr."); | |||
| return PARAM_INVALID; | |||
| } | |||
| @@ -33,11 +33,11 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) { | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| if (!PassUtils::IsNeedTrainIteFlowCtrl(compute_graph)) { | |||
| GELOGI("No need FlowCtrl for graph %u", compute_graph->GetGraphID()); | |||
| GELOGI("No need FlowCtrl for graph %u.", compute_graph->GetGraphID()); | |||
| return NOT_CHANGED; | |||
| } | |||
| GELOGI("FlowCtrl pass begin.graph is [%s]", compute_graph->GetName().c_str()); | |||
| GELOGI("FlowCtrl pass begin.graph is [%s].", compute_graph->GetName().c_str()); | |||
| bool graph_change = false; | |||
| // 1. Add FP/BP flow ctrl (big cycle) | |||
| for (auto &node : compute_graph->GetDirectNode()) { | |||
| @@ -347,11 +347,11 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c | |||
| NodePtr assign_node = | |||
| InsertAssignOp(compute_graph, ASSIGN, NODE_NAME_FLOWCTRL_LOOP_ASSIGN, loop_cond_node, loop_reset_node); | |||
| if (assign_node == nullptr || switch_node == nullptr) { | |||
| GELOGE(PARAM_INVALID, "assign_node or switch node is null"); | |||
| GELOGE(PARAM_INVALID, "assign_node or switch node is null."); | |||
| return FAILED; | |||
| } | |||
| GE_CHK_STATUS_RET(SetStreamLabel(assign_node, switch_node->GetName()), "set stream label failed"); | |||
| GE_CHK_STATUS_RET(SetStreamLabel(assign_node, switch_node->GetName()), "set stream label failed."); | |||
| graphStatus add_ret = GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), assign_node->GetInControlAnchor()); | |||
| if (add_ret != GRAPH_SUCCESS) { | |||
| @@ -370,7 +370,7 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c | |||
| } | |||
| GE_CHK_STATUS_RET(SetStreamLabel(active_node, switch_node->GetName()), "set stream label failed"); | |||
| GE_CHK_STATUS_RET(SetSwitchBranchNodeLabel(active_node, switch_node->GetName()), | |||
| "set switch branch node label failed"); | |||
| "set switch branch node label failed."); | |||
| string model_exit_name = switch_node->GetName() + "_ModelExit"; | |||
| GE_CHK_STATUS_RET(SetActiveLabelList(active_node, { model_exit_name }), "set active label list failed"); | |||
| @@ -401,7 +401,7 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c | |||
| } | |||
| Status FlowCtrlPass::AddFpBpIteratorCtrl(ComputeGraphPtr &compute_graph, NodePtr &pre_node) { | |||
| GE_IF_BOOL_EXEC(pre_node == nullptr, DOMI_LOGE("pre_node is nullptr"); return FAILED); | |||
| GE_IF_BOOL_EXEC(pre_node == nullptr, DOMI_LOGE("pre_node is nullptr."); return FAILED); | |||
| string pre_node_name = pre_node->GetName(); | |||
| GELOGI("Add FpBp Iterator ctrl, pre node:%s.", pre_node_name.c_str()); | |||
| // 1. Get or add variables | |||
| @@ -477,7 +477,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph, | |||
| * itersPerLoop loopCond | |||
| */ | |||
| GE_IF_BOOL_EXEC(loop_after_node == nullptr || compute_graph == nullptr, | |||
| DOMI_LOGE("loop after node or compute graph is null"); return FAILED); | |||
| DOMI_LOGE("loop after node or compute graph is null."); return FAILED); | |||
| InDataAnchorPtr in_anchor = loop_after_node->GetInDataAnchor(0); | |||
| if (in_anchor == nullptr || in_anchor->GetPeerOutAnchor() == nullptr) { | |||
| GELOGE(FAILED, "Find %s in data anchor failed.", loop_after_node->GetName().c_str()); | |||
| @@ -498,7 +498,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph, | |||
| } | |||
| // 2. Add StreamSwitch and edges to switch_node. | |||
| GE_IF_BOOL_EXEC(loop_pre_node == nullptr, DOMI_LOGE("loop pre node is null"); return FAILED); | |||
| GE_IF_BOOL_EXEC(loop_pre_node == nullptr, DOMI_LOGE("loop pre node is null."); return FAILED); | |||
| string switch_name = loop_pre_node->GetName() + "_" + NODE_NAME_STREAM_SWITCH; | |||
| NodePtr switch_node = InsertStreamSwitchOp(compute_graph, switch_name, loop_cond_node, iter_per_loop_node); | |||
| if (switch_node == nullptr) { | |||
| @@ -506,7 +506,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph, | |||
| return FAILED; | |||
| } | |||
| GE_CHK_STATUS_RET(SetStreamLabel(switch_node, switch_name), "set stream label failed"); | |||
| GE_CHK_STATUS_RET(SetStreamLabel(switch_node, switch_name), "set stream label failed."); | |||
| graphStatus add_ret = GraphUtils::AddEdge(loop_pre_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()); | |||
| if (add_ret != GRAPH_SUCCESS) { | |||
| @@ -529,7 +529,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph, | |||
| return FAILED; | |||
| } | |||
| GE_CHK_STATUS_RET(SetStreamLabel(active_node, active_name), "set stream label failed"); | |||
| GE_CHK_STATUS_RET(SetStreamLabel(active_node, active_name), "set stream label failed."); | |||
| GE_IF_BOOL_EXEC(!AttrUtils::SetBool(active_node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, true), | |||
| DOMI_LOGE("set ATTR_NAME_IS_LOOP_ACTIVE failed"); return FAILED); | |||
| @@ -542,7 +542,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph, | |||
| } | |||
| // used for stream assign to find true branch | |||
| GE_CHK_STATUS_RET(SetActiveLabelList(switch_node, { active_name }), "set active label list failed"); | |||
| GE_CHK_STATUS_RET(SetActiveLabelList(switch_node, { active_name }), "set active label list failed."); | |||
| // used for stream assign to find active stream | |||
| GE_CHK_STATUS_RET(SetActiveLabelList(active_node, { loop_pre_node->GetName() }), "set active label list failed"); | |||
| active_nodes_in_iter_loop_.push_back(active_node); | |||
| @@ -63,16 +63,17 @@ Status ResourcePairAddControlPass::Run(ComputeGraphPtr graph) { | |||
| NodePtr from_node = prefix_2_node.second; | |||
| GE_CHECK_NOTNULL(from_node); | |||
| auto to_item_prefix_2_node = prefix_2_node_per_type.find(resource_type_pair.second); | |||
| // stackpush and stackpop may exist in two subgraphs, no necessary to report error | |||
| if (to_item_prefix_2_node == prefix_2_node_per_type.end()) { | |||
| GELOGE(PARAM_INVALID, "find peer type node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(), | |||
| GELOGW("find peer type node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(), | |||
| resource_type_pair.first.c_str(), resource_type_pair.second.c_str()); | |||
| return PARAM_INVALID; | |||
| continue; | |||
| } | |||
| auto to_prefix_2_node = to_item_prefix_2_node->second.find(prefix); | |||
| if (to_prefix_2_node == to_item_prefix_2_node->second.end()) { | |||
| GELOGE(PARAM_INVALID, "find peer prefix node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(), | |||
| GELOGW("find peer prefix node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(), | |||
| resource_type_pair.first.c_str(), resource_type_pair.second.c_str()); | |||
| return PARAM_INVALID; | |||
| continue; | |||
| } | |||
| NodePtr to_node = to_prefix_2_node->second; | |||
| GE_CHECK_NOTNULL(to_node); | |||
| @@ -63,16 +63,17 @@ Status ResourcePairRemoveControlPass::Run(ComputeGraphPtr graph) { | |||
| NodePtr from_node = prefix_2_node.second; | |||
| GE_CHECK_NOTNULL(from_node); | |||
| auto to_item_prefix_2_node = prefix_2_node_per_type.find(resource_type_pair.second); | |||
| // stackpush and stackpop may exist in two subgraphs, no necessary to report error | |||
| if (to_item_prefix_2_node == prefix_2_node_per_type.end()) { | |||
| GELOGE(INTERNAL_ERROR, "find peer type node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(), | |||
| GELOGW("find peer type node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(), | |||
| resource_type_pair.first.c_str(), resource_type_pair.second.c_str()); | |||
| return domi::PARAM_INVALID; | |||
| continue; | |||
| } | |||
| auto to_prefix_2_node = to_item_prefix_2_node->second.find(prefix); | |||
| if (to_prefix_2_node == to_item_prefix_2_node->second.end()) { | |||
| GELOGE(INTERNAL_ERROR, "find peer prefix node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(), | |||
| GELOGW("find peer prefix node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(), | |||
| resource_type_pair.first.c_str(), resource_type_pair.second.c_str()); | |||
| return domi::PARAM_INVALID; | |||
| continue; | |||
| } | |||
| NodePtr to_node = to_prefix_2_node->second; | |||
| GE_CHECK_NOTNULL(to_node); | |||
| @@ -67,7 +67,7 @@ OpDescPtr SameTransdataBreadthFusionPass::GetCastOp(const GeTensorDesc &in_desc, | |||
| auto fusion_cast_op_count = atomic_fusion_cast_op_count.fetch_add(1); | |||
| std::stringstream cast_op_name; | |||
| cast_op_name << "fusion_cast_" << fusion_cast_op_count; | |||
| auto node_op = ge::OperatorFactory::CreateOperator(cast_op_name.str(), CAST); | |||
| auto node_op = ge::OperatorFactory::CreateOperator(cast_op_name.str().c_str(), CAST); | |||
| auto cast_op = ge::OpDescUtils::GetOpDescFromOperator(node_op); | |||
| node_op.BreakConnect(); | |||
| if (cast_op == nullptr) { | |||
| @@ -86,7 +86,7 @@ Status TransposeTransDataPass::Run(NodePtr &node) { | |||
| if (CheckOneInAndOneOutDataAnchor(out_node)) { | |||
| return FAILED; | |||
| } | |||
| if (!FusionIfNeed(op_desc, out_op_desc)) { | |||
| if (!FusionIfNeed(op_desc, out_node)) { | |||
| continue; | |||
| } | |||
| CopyInputEdges(node, out_node); | |||
| @@ -152,7 +152,8 @@ Status TransposeTransDataPass::RemoveTranspose(NodePtr &node) { | |||
| return SUCCESS; | |||
| } | |||
| bool TransposeTransDataPass::FusionIfNeed(OpDescPtr &op_desc, OpDescPtr &transdata_op_desc) { | |||
| bool TransposeTransDataPass::FusionIfNeed(OpDescPtr &op_desc, NodePtr &node) { | |||
| auto transdata_op_desc = node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| GE_CHECK_NOTNULL(transdata_op_desc); | |||
| auto out_input_desc = transdata_op_desc->MutableInputDesc(0); | |||
| @@ -187,7 +188,7 @@ bool TransposeTransDataPass::FusionIfNeed(OpDescPtr &op_desc, OpDescPtr &transda | |||
| out_input_desc->SetFormat(src_format); | |||
| out_input_desc->SetShape(src_shape); | |||
| if (!TransDataCheckAccuracySupported(transdata_op_desc)) { | |||
| if (!TransDataCheckAccuracySupported(node)) { | |||
| out_input_desc->SetFormat(out_input_format); | |||
| out_input_desc->SetShape(out_input_shape); | |||
| return false; | |||
| @@ -224,7 +225,8 @@ void TransposeTransDataPass::CopyInputEdges(NodePtr &origin_node, NodePtr &new_n | |||
| GraphUtils::CopyInCtrlEdges(origin_node, new_node) != GRAPH_SUCCESS, GELOGW("Copy in ctrl edges failed"); return); | |||
| } | |||
| bool TransposeTransDataPass::TransDataCheckAccuracySupported(const OpDescPtr &op_desc) { | |||
| bool TransposeTransDataPass::TransDataCheckAccuracySupported(NodePtr &node) { | |||
| const OpDescPtr &op_desc = node->GetOpDesc(); | |||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
| if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | |||
| GELOGW("GELib not initialized"); | |||
| @@ -244,7 +246,7 @@ bool TransposeTransDataPass::TransDataCheckAccuracySupported(const OpDescPtr &op | |||
| auto &kernel_name = it.opKernelLib; | |||
| auto kernel_info_store = kernel_map.find(kernel_name); | |||
| if (kernel_info_store != kernel_map.end()) { | |||
| if (kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason, true)) { | |||
| if (kernel_info_store->second->CheckAccuracySupported(node, unsupported_reason, true)) { | |||
| return true; | |||
| } | |||
| } | |||
| @@ -26,9 +26,9 @@ class TransposeTransDataPass : public BaseNodePass { | |||
| private: | |||
| Status CheckOneInAndOneOutDataAnchor(NodePtr &node) const; | |||
| Status RemoveTranspose(NodePtr &node); | |||
| bool FusionIfNeed(OpDescPtr &op_desc, OpDescPtr &transdata_op_desc); | |||
| bool FusionIfNeed(OpDescPtr &op_desc, NodePtr &node); | |||
| void CopyInputEdges(NodePtr &origin_node, NodePtr &new_node); | |||
| bool TransDataCheckAccuracySupported(const OpDescPtr &op_desc); | |||
| bool TransDataCheckAccuracySupported(NodePtr &node); | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_PASSES_TRANSPOSE_TRANSDATA_PASS_H_ | |||
| @@ -600,7 +600,7 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { | |||
| for (auto &in_node : node->GetInDataNodes()) { | |||
| if (origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end()) { | |||
| if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) { | |||
| origin_nodes_status_[node.get()] == kNodeInBatchBranch; | |||
| origin_nodes_status_[node.get()] = kNodeInBatchBranch; | |||
| ResetEnterStatus(frame_enters, node); | |||
| changed = true; | |||
| } | |||
| @@ -458,8 +458,8 @@ Status HybridModelAsyncExecutor::Execute(const std::vector<DataBuffer> &inputs, | |||
| i, outputs[i].length, output_real_size); | |||
| return FAILED; | |||
| } | |||
| GE_CHK_RT_RET(rtMemcpy(outputs[i].data, outputs[i].length, args.outputs[i].GetData(), output_real_size, | |||
| RT_MEMCPY_DEVICE_TO_DEVICE)); | |||
| GE_CHK_RT_RET(rtMemcpy(outputs[i].data, outputs[i].length, args.outputs[i].GetData(), output_real_size, | |||
| RT_MEMCPY_DEVICE_TO_DEVICE)); | |||
| } | |||
| outputs[i].length = output_real_size; | |||
| } | |||
| @@ -60,7 +60,7 @@ class StageExecutor { | |||
| BlockingQueue<StageTask> task_queue_; | |||
| std::unique_ptr<SubgraphExecutor> root_graph_executor_; | |||
| GraphExecutionContext context_; | |||
| StageExecutor *next_executor_; | |||
| StageExecutor *next_executor_ = nullptr; | |||
| rtStream_t stream_ = nullptr; | |||
| }; | |||
| @@ -30,7 +30,7 @@ namespace ge { | |||
| namespace hybrid { | |||
| class TbeHandleHolder { | |||
| public: | |||
| TbeHandleHolder(void *bin_handle); | |||
| explicit TbeHandleHolder(void *bin_handle); | |||
| ~TbeHandleHolder(); | |||
| void SetBinHandle(void *bin_handle) { bin_handle_ = bin_handle; } | |||
| @@ -360,6 +360,7 @@ Status AicpuTfNodeTask::Init(const HybridModel &model) { | |||
| need_sync_ = true; | |||
| } | |||
| auto task_defs = model.GetTaskDefs(node_item_->node); | |||
| GE_CHECK_NOTNULL(task_defs); | |||
| if (unknown_type_ == DEPEND_COMPUTE) { | |||
| GE_CHK_STATUS_RET_NOLOG(SetMemCopyTask((*task_defs)[1])); | |||
| } | |||
| @@ -669,7 +670,7 @@ Status AicpuNodeTask::Init(const HybridModel &model) { | |||
| auto kernel_type = static_cast<ccKernelType>(context.kernel_type()); | |||
| if (kernel_type == ccKernelType::CUST_AI_CPU) { | |||
| bool loaded = false; | |||
| GE_CHK_STATUS_RET(ModelManager::GetInstance()->LoadCustAicpuSo(op_desc, so_name, loaded), | |||
| GE_CHK_STATUS_RET(ModelManager::GetInstance()->LoadCustAicpuSo(op_desc, so_name, loaded), | |||
| "load cust aicpu so failed."); | |||
| if (!loaded) { | |||
| GE_CHK_STATUS_RET(ModelManager::GetInstance()->LaunchCustAicpuSo(), "Launch cust aicpu so failed."); | |||
| @@ -554,6 +554,7 @@ Status GEInit::Finalize() { | |||
| if (instance_ptr != nullptr) { | |||
| return instance_ptr->Finalize(); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| string GEInit::GetPath() { | |||
| @@ -64,7 +64,7 @@ const std::string kInputFormat = "input_format"; | |||
| * @param cfg_path [IN] the config file path | |||
| * @return graphStatus | |||
| */ | |||
| typedef graphStatus (*SetOpAttrFun)(ComputeGraphPtr &graph, const std::string &cfg_path); | |||
| using SetOpAttrFun = graphStatus (*)(ComputeGraphPtr &graph, const std::string &cfg_path); | |||
| const std::map<aclgrphAttrType, SetOpAttrFun> kAttrTypeFuncMap = { | |||
| {ATTR_TYPE_KEEP_DTYPE, KeepDtypeFunc}, | |||
| @@ -798,11 +798,17 @@ void SaveCustomCaffeProtoPath() { | |||
| Status CreateInputsForInference(const ge::Graph &graph, vector<ge::GeTensor> &inputs) { | |||
| auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| int64_t index = 0; | |||
| for (ge::NodePtr &input_node : compute_graph->GetAllNodes()) { | |||
| GE_CHECK_NOTNULL(input_node); | |||
| ge::OpDescPtr op = input_node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op); | |||
| if (op->GetType() == ge::DATA) { | |||
| if (!op->HasAttr(ge::ATTR_NAME_INDEX)) { | |||
| (void)ge::AttrUtils::SetInt(op, ge::ATTR_NAME_INDEX, index); | |||
| GELOGD("Set attr index:%ld for data op:%s", index, op->GetName().c_str()); | |||
| } | |||
| index++; | |||
| GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size()); | |||
| ge::GeTensorDesc tensor = op->GetInputDesc(0); | |||
| string data_op_name = op->GetName(); | |||
| @@ -70,7 +70,8 @@ Status OpTask::OpenDump(rtStream_t stream) { | |||
| uint64_t output_addr = arg_base[input_size + j]; | |||
| output_adds.emplace_back(output_addr); | |||
| } | |||
| dump_op_.SetDumpInfo(DumpManager::GetInstance().GetDumpProperties(kInferSessionId), op_desc_, input_addrs, output_adds, stream); | |||
| dump_op_.SetDumpInfo(DumpManager::GetInstance().GetDumpProperties(kInferSessionId), | |||
| op_desc_, input_addrs, output_adds, stream); | |||
| auto status = dump_op_.LaunchDumpOp(); | |||
| if (status != SUCCESS) { | |||
| GELOGE(status, "Launch dump op failed in single op"); | |||
| @@ -504,7 +505,7 @@ Status AiCpuBaseTask::UpdateOutputShape(vector<GeTensorDesc> &output_desc) { | |||
| "AiCpuCCTask Update [%zu]th output shape failed.", i); | |||
| if (DumpManager::GetInstance().GetDumpProperties(kInferSessionId).IsSingleOpNeedDump()) { | |||
| GE_CHK_STATUS_RET(op_desc_->UpdateOutputDesc(i, output_desc[i]), | |||
| "AiCpuCCTask Update [%zu]th output desc failed.", i); | |||
| "AiCpuCCTask Update [%zu]th output desc failed.", i); | |||
| } | |||
| } | |||
| GELOGD("Update DEPEND_SHAPE_RANGE AiCpuBaseTask outputshape finished."); | |||
| @@ -711,7 +712,7 @@ Status AiCpuTask::UpdateShapeByHbmBuffer(vector<GeTensorDesc> &output_desc) { | |||
| "AiCpuTask update [%zu]th output shape failed.", i); | |||
| if (DumpManager::GetInstance().GetDumpProperties(kInferSessionId).IsSingleOpNeedDump()) { | |||
| GE_CHK_STATUS_RET(op_desc_->UpdateOutputDesc(i, output_desc[i]), | |||
| "AiCpuTask update [%zu]th output desc failed.", i); | |||
| "AiCpuTask update [%zu]th output desc failed.", i); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| @@ -110,9 +110,9 @@ GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_AIPP_MODE_INVALID, "AIPP mode invalid."); | |||
| GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "Task type invalid."); | |||
| GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, "Kernel type invalid."); | |||
| GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_PLGMGR_PATH_INVALID, "Plugin path is invalid."); | |||
| GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID, "Format is invalid when transferring shape."); | |||
| GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Shape is invalid when transferring shape."); | |||
| GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID, "Datatype is invalid when transferring shape."); | |||
| GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_FORMAT_INVALID, "Format is invalid."); | |||
| GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_SHAPE_INVALID, "Shape is invalid."); | |||
| GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_DATATYPE_INVALID, "Datatype is invalid."); | |||
| GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_MEMORY_ALLOCATION, "Memory allocation error."); | |||
| GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate memory."); | |||
| @@ -53,9 +53,9 @@ static const uint32_t ACL_ERROR_GE_AIPP_MODE_INVALID = 145016; | |||
| static const uint32_t ACL_ERROR_GE_OP_TASK_TYPE_INVALID = 145017; | |||
| static const uint32_t ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID = 145018; | |||
| static const uint32_t ACL_ERROR_GE_PLGMGR_PATH_INVALID = 145019; | |||
| static const uint32_t ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID = 145020; | |||
| static const uint32_t ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID = 145021; | |||
| static const uint32_t ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID = 145022; | |||
| static const uint32_t ACL_ERROR_GE_FORMAT_INVALID = 145020; | |||
| static const uint32_t ACL_ERROR_GE_SHAPE_INVALID = 145021; | |||
| static const uint32_t ACL_ERROR_GE_DATATYPE_INVALID = 145022; | |||
| static const uint32_t ACL_ERROR_GE_MEMORY_ALLOCATION = 245000; | |||
| static const uint32_t ACL_ERROR_GE_MEMORY_OPERATE_FAILED = 245001; | |||
| static const uint32_t ACL_ERROR_GE_INTERNAL_ERROR = 545000; | |||
| @@ -1 +1 @@ | |||
| Subproject commit 7a51997cbd34e1869b9fb4ea5597a021e6427272 | |||
| Subproject commit 6b802ec3cf711e9942a7e2a74f04a53647aae473 | |||
| @@ -1 +1 @@ | |||
| Subproject commit 227b10355427038785e95c81a41cda99893eba08 | |||
| Subproject commit 6a07f1a8b9b8b4630a5b60d9d8d02ec4a6314d68 | |||
| @@ -690,6 +690,7 @@ set(PASS_TEST_FILES | |||
| "graph/passes/infershape_pass_unittest.cc" | |||
| "graph/passes/multi_batch_clone_pass_unittest.cc" | |||
| "graph/passes/replace_with_empty_const_pass_unittest.cc" | |||
| "graph/passes/transpose_transdata_pass_unittest.cc" | |||
| ) | |||
| set(KERNEL_TEST_FILES | |||
| @@ -365,7 +365,7 @@ TEST_F(UtestDataTypeTransfer, invalid_src_data_type) { | |||
| TransResult result; | |||
| DataTypeTransfer transfer; | |||
| EXPECT_EQ(transfer.TransDataType(args, result), UNSUPPORTED); | |||
| EXPECT_EQ(transfer.TransDataType(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| /* | |||
| @@ -386,8 +386,8 @@ TEST_F(UtestDataTypeTransfer, unsupprot_trans) { | |||
| TransResult result; | |||
| DataTypeTransfer transfer; | |||
| EXPECT_EQ(transfer.TransDataType(args, result), UNSUPPORTED); | |||
| EXPECT_EQ(TransDataType(args, result), UNSUPPORTED); | |||
| EXPECT_EQ(transfer.TransDataType(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| EXPECT_EQ(TransDataType(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestDataTypeTransfer, unsupprot_trans2) { | |||
| @@ -396,8 +396,8 @@ TEST_F(UtestDataTypeTransfer, unsupprot_trans2) { | |||
| TransResult result; | |||
| DataTypeTransfer transfer; | |||
| EXPECT_EQ(transfer.TransDataType(args, result), UNSUPPORTED); | |||
| EXPECT_EQ(TransDataType(args, result), UNSUPPORTED); | |||
| EXPECT_EQ(transfer.TransDataType(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| EXPECT_EQ(TransDataType(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -679,7 +679,7 @@ TEST_F(UtestFormatTransfer5dNhwc, nc1hwc0_to_nhwc_float2) { | |||
| } | |||
| Status status = | |||
| transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | |||
| EXPECT_EQ(status, ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||
| EXPECT_EQ(status, ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransfer5dNhwc, invalid_src_format) { | |||
| @@ -689,7 +689,7 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_src_format) { | |||
| TransResult result; | |||
| FormatTransferNc1hwc0Nhwc transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransfer5dNhwc, invalid_src_shape1) { | |||
| @@ -699,7 +699,7 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_src_shape1) { | |||
| TransResult result; | |||
| FormatTransferNc1hwc0Nhwc transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransfer5dNhwc, InvalidSrcShape2) { | |||
| @@ -709,7 +709,7 @@ TEST_F(UtestFormatTransfer5dNhwc, InvalidSrcShape2) { | |||
| TransResult result; | |||
| FormatTransferNc1hwc0Nhwc transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransfer5dNhwc, invalid_src_data_type) { | |||
| @@ -719,7 +719,7 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_src_data_type) { | |||
| TransResult result; | |||
| FormatTransferNc1hwc0Nhwc transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_format) { | |||
| @@ -729,7 +729,7 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_format) { | |||
| TransResult result; | |||
| FormatTransferNc1hwc0Nhwc transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_shape1) { | |||
| @@ -739,7 +739,7 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_shape1) { | |||
| TransResult result; | |||
| FormatTransferNc1hwc0Nhwc transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_shape2) { | |||
| @@ -749,7 +749,7 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_shape2) { | |||
| TransResult result; | |||
| FormatTransferNc1hwc0Nhwc transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransfer5dNhwc, invalid_src_dst_shape_relation) { | |||
| @@ -759,7 +759,7 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_src_dst_shape_relation) { | |||
| TransResult result; | |||
| FormatTransferNc1hwc0Nhwc transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -39,7 +39,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_data_type_uint8) { | |||
| TransResult result; | |||
| FormatTransferC1hwncoc0Hwcn transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_data_type_int32) { | |||
| @@ -50,7 +50,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_data_type_int32) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 3, 1}, DT_INT32}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_format_nc1khkwhwc0) { | |||
| @@ -61,7 +61,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_format_nc1khkw | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_NC1KHKWHWC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 3, 1}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_dst_format_nchw) { | |||
| @@ -72,7 +72,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_dst_format_nchw) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_NCHW, {1, 4, 4, 1, 16, 16}, {4, 4, 3, 1}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_shape) { | |||
| @@ -83,7 +83,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_shape) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16}, {4, 4, 3, 1}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_shape2) { | |||
| @@ -94,7 +94,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_shape2) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, -16}, {4, 4, 3, 1}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invali_dst_shape) { | |||
| @@ -105,7 +105,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invali_dst_shape) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 3}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_dst_shape2) { | |||
| @@ -116,7 +116,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_dst_shape2) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 3, -1}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_dst_shape_relation) { | |||
| @@ -127,7 +127,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_dst_shape_rela | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 17, 1}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_fp16_success_lt_cube) { | |||
| @@ -158,7 +158,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_fp16_success_lt_cube) { | |||
| } | |||
| Status status = | |||
| transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | |||
| EXPECT_EQ(status, ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||
| EXPECT_EQ(status, ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_gp16_success_eq_cube) { | |||
| @@ -2332,7 +2332,7 @@ TEST_F(UtestFormatTransferNdFractNz, nd_shape4_fp16) { | |||
| } | |||
| EXPECT_EQ( | |||
| transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||
| ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractNz, nd_shape5_fp16) { | |||
| @@ -4785,7 +4785,7 @@ TEST_F(UtestFormatTransferNdFractNz, nd_shape4_fp32) { | |||
| EXPECT_EQ((reinterpret_cast<float *>(result2.data.get()))[i], data[i]); | |||
| } | |||
| EXPECT_EQ(transfer2.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||
| ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractNz, nchw_shape4_fp32) { | |||
| @@ -9058,9 +9058,9 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_shape) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_NZ, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | |||
| TransResult result; | |||
| FormatTransferFractalNz transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||
| ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type) { | |||
| @@ -9078,9 +9078,9 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type) { | |||
| DT_UNDEFINED}; | |||
| TransResult result; | |||
| FormatTransferFractalNz transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID); | |||
| ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractNz, invalid_src_format) { | |||
| @@ -9093,9 +9093,9 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_format) { | |||
| DT_FLOAT16}; | |||
| TransResult result; | |||
| FormatTransferFractalNz transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||
| ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractNz, invalid_dst_shape) { | |||
| @@ -9104,7 +9104,7 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_dst_shape) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_NZ, {1, 1, 4, 4}, {1, 1, 16, 16}, DT_FLOAT16}; | |||
| TransResult result; | |||
| FormatTransferFractalNz transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | |||
| SUCCESS); | |||
| } | |||
| @@ -9115,7 +9115,7 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_dst_shape2) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_NZ, FORMAT_NHWC, {1, 1, 1, 1, 16, 16}, {1, 4, 4}, DT_FLOAT16}; | |||
| TransResult result; | |||
| FormatTransferFractalNzND transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type2) { | |||
| @@ -9133,7 +9133,7 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type2) { | |||
| DT_UNDEFINED}; | |||
| TransResult result; | |||
| FormatTransferFractalNzND transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type3) { | |||
| @@ -9151,7 +9151,7 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type3) { | |||
| DT_VARIANT}; | |||
| TransResult result; | |||
| FormatTransferFractalNzND transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractNz, invalid_dst_format2) { | |||
| @@ -9164,8 +9164,8 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_dst_format2) { | |||
| DT_FLOAT16}; | |||
| TransResult result; | |||
| FormatTransferFractalNzND transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(TransFormat(args, result), UNSUPPORTED); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| EXPECT_EQ(TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractNz, invalid_src_shape2) { | |||
| @@ -9174,7 +9174,7 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_shape2) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_NZ, FORMAT_NHWC, {1, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT16}; | |||
| TransResult result; | |||
| FormatTransferFractalNzND transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractNz, invalid_src_dst_shape_relation) { | |||
| @@ -9187,7 +9187,7 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_dst_shape_relation) { | |||
| DT_FLOAT16}; | |||
| TransResult result; | |||
| FormatTransferFractalNzND transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -1894,7 +1894,7 @@ TEST_F(UtestFormatTransferNdFractZz, nd_shape4_fp16_1) { | |||
| } | |||
| EXPECT_EQ( | |||
| transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||
| ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractZz, nd_shape4_fp16) { | |||
| @@ -2071,7 +2071,7 @@ TEST_F(UtestFormatTransferNdFractZz, nd_shape4_fp16) { | |||
| } | |||
| EXPECT_EQ( | |||
| transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||
| ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractZz, nd_shape5_fp16) { | |||
| @@ -7877,9 +7877,9 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_src_shape) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_ZZ, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | |||
| TransResult result; | |||
| FormatTransferFractalZz transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||
| ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractZz, invalid_src_data_type) { | |||
| @@ -7897,9 +7897,9 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_src_data_type) { | |||
| DT_UNDEFINED}; | |||
| TransResult result; | |||
| FormatTransferFractalZz transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID); | |||
| ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractZz, invalid_src_format) { | |||
| @@ -7912,10 +7912,10 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_src_format) { | |||
| DT_FLOAT16}; | |||
| TransResult result; | |||
| FormatTransferFractalZz transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||
| EXPECT_EQ(TransFormat(args, result), UNSUPPORTED); | |||
| ACL_ERROR_GE_SHAPE_INVALID); | |||
| EXPECT_EQ(TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractZz, invalid_dst_shape) { | |||
| @@ -7924,7 +7924,7 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_dst_shape) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_ZZ, {1, 1, 4, 4}, {1, 1, 16, 16}, DT_FLOAT16}; | |||
| TransResult result; | |||
| FormatTransferFractalZz transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | |||
| SUCCESS); | |||
| } | |||
| @@ -7935,7 +7935,7 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_dst_shape2) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_ZZ, FORMAT_NHWC, {1, 1, 1, 1, 16, 16}, {1, 4, 4}, DT_FLOAT16}; | |||
| TransResult result; | |||
| FormatTransferFractalZzND transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractZz, invalid_src_data_type2) { | |||
| @@ -7953,7 +7953,7 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_src_data_type2) { | |||
| DT_UNDEFINED}; | |||
| TransResult result; | |||
| FormatTransferFractalZzND transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractZz, invalid_dst_format2) { | |||
| @@ -7966,8 +7966,8 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_dst_format2) { | |||
| DT_FLOAT16}; | |||
| TransResult result; | |||
| FormatTransferFractalZzND transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(TransFormat(args, result), UNSUPPORTED); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| EXPECT_EQ(TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractZz, invalid_src_shape2) { | |||
| @@ -7976,7 +7976,7 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_src_shape2) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_ZZ, FORMAT_NHWC, {1, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT16}; | |||
| TransResult result; | |||
| FormatTransferFractalZzND transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNdFractZz, invalid_src_dst_shape_relation) { | |||
| @@ -7989,7 +7989,7 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_src_dst_shape_relation) { | |||
| DT_FLOAT16}; | |||
| TransResult result; | |||
| FormatTransferFractalZzND transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -39,7 +39,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_data_type_invalid_dat | |||
| TransResult result; | |||
| FormatTransferFracZHwcn transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_format_reserved) { | |||
| @@ -50,7 +50,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_format_reserved) | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_RESERVED, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, 1, 1}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_format_reserved) { | |||
| @@ -61,7 +61,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_format_reserved) | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_RESERVED, {16, 1, 16, 16}, {4, 4, 1, 1}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_shape) { | |||
| @@ -72,7 +72,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_shape) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 1, 16, 16}, {4, 4, 1, 1}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_shape2) { | |||
| @@ -83,7 +83,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_shape2) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, -1, 16, 16}, {4, 4, 1, 1}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_shape) { | |||
| @@ -94,7 +94,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_shape) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, 1}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_shape2) { | |||
| @@ -105,7 +105,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_shape2) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, -1, 1}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_dst_shape_relation1) { | |||
| @@ -116,7 +116,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_dst_shape_relatio | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, 17, 1}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_dst_shape_relation2) { | |||
| @@ -127,7 +127,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_dst_shape_relatio | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, 1, 17}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_fp16_success_lt_cube) { | |||
| @@ -302,7 +302,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_fp16_success_eq_cube) { | |||
| } | |||
| Status status = | |||
| transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | |||
| EXPECT_EQ(status, ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||
| EXPECT_EQ(status, ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_fp16_success_gt_cube) { | |||
| @@ -39,7 +39,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_data_type) { | |||
| TransResult result; | |||
| FormatTransferFracZNchw transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_Invalid_src_format_reserved) { | |||
| @@ -50,7 +50,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_Invalid_src_format_reserved) | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_RESERVED, FORMAT_NCHW, {16, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_format_reserved) { | |||
| @@ -61,7 +61,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_format_reserved) | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_RESERVED, {16, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_shape) { | |||
| @@ -72,7 +72,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_shape) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_shape2) { | |||
| @@ -83,7 +83,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_shape2) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, -16, 16}, {1, 1, 4, 4}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_shape) { | |||
| @@ -94,7 +94,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_shape) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 16, 16}, {1, 4, 4}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_shape2) { | |||
| @@ -105,7 +105,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_shape2) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 16, 16}, {1, -1, 4, 4}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_dst_shape_relation1) { | |||
| @@ -116,7 +116,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_dst_shape_relatio | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 16, 16}, {1, 17, 4, 4}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_dst_shape_relation2) { | |||
| @@ -127,7 +127,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_dst_shape_relatio | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 16, 16}, {17, 1, 4, 4}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_fp16_success_lt_cube) { | |||
| @@ -302,7 +302,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_fp16_success_eq_cube) { | |||
| } | |||
| Status status = | |||
| transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | |||
| EXPECT_EQ(status, ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||
| EXPECT_EQ(status, ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_fp16_success_gt_cube) { | |||
| @@ -42,7 +42,7 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_data_type_uint8) { | |||
| TransResult result; | |||
| FormatTransferHwcnC1hwncoc0 transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_data_type_int32) { | |||
| @@ -57,7 +57,7 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_data_type_int32) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_C1HWNCoC0, {4, 4, 3, 1}, {1, 4, 4, 1, 16, 16}, DT_INT32}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_format_nchw) { | |||
| @@ -72,10 +72,10 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_format_nchw) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_NCHW, FORMAT_C1HWNCoC0, {4, 4, 3, 1}, {1, 4, 4, 1, 16, 16}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| Status status = | |||
| transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | |||
| EXPECT_EQ(status, ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||
| EXPECT_EQ(status, ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_format_nc1khkwhwc0) { | |||
| @@ -90,7 +90,7 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_format_nc1khkwhw | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_NC1KHKWHWC0, {4, 4, 3, 1}, {1, 4, 4, 1, 16, 16}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape) { | |||
| @@ -105,7 +105,7 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_NC1KHKWHWC0, {4, 4, 3}, {1, 4, 4, 1, 16, 16}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape2) { | |||
| @@ -120,7 +120,7 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape2) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_C1HWNCoC0, {4, 4}, {1, 4, 4, 1, 16, 16}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape3) { | |||
| @@ -139,10 +139,10 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape3) { | |||
| DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| Status status = | |||
| transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | |||
| EXPECT_EQ(status, ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||
| EXPECT_EQ(status, ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_format) { | |||
| @@ -157,7 +157,7 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_format) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_NC1KHKWHWC0, {4, 4, 3, 1}, {1, 1, 4, 4, 16, 16}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_shape2) { | |||
| @@ -172,7 +172,7 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_shape2) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_C1HWNCoC0, {4, 4, 3, 1}, {2, 4, 4, 1, 16, 16}, DT_FLOAT}; | |||
| TransResult result; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_fp16_success_lt_cube) { | |||
| @@ -640,7 +640,7 @@ TEST_F(UtestFormatTransferNchw5d, invalid_data_format) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | |||
| FormatTransferNchwNc1hwc0 transfer; | |||
| EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||
| ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -691,7 +691,7 @@ TEST_F(UtestFormatTransferNhwc5d, invalid_src_shape1) { | |||
| TransResult result; | |||
| FormatTransferNhwcNc1hwc0 transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| /* | |||
| @@ -716,10 +716,10 @@ TEST_F(UtestFormatTransferNhwc5d, invalid_src_format) { | |||
| TransResult result; | |||
| FormatTransferNhwcNc1hwc0 transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| Status status = | |||
| transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | |||
| EXPECT_EQ(status, ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||
| EXPECT_EQ(status, ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNhwc5d, invalid_dst_shape2) { | |||
| @@ -729,7 +729,7 @@ TEST_F(UtestFormatTransferNhwc5d, invalid_dst_shape2) { | |||
| TransResult result; | |||
| FormatTransferNhwcNc1hwc0 transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNhwc5d, invalid_src_data_type) { | |||
| @@ -739,7 +739,7 @@ TEST_F(UtestFormatTransferNhwc5d, invalid_src_data_type) { | |||
| TransResult result; | |||
| FormatTransferNhwcNc1hwc0 transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNhwc5d, unsupport_dst_format) { | |||
| @@ -749,7 +749,7 @@ TEST_F(UtestFormatTransferNhwc5d, unsupport_dst_format) { | |||
| TransResult result; | |||
| FormatTransferNhwcNc1hwc0 transfer; | |||
| EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
| EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNhwc5d, invalid_data_shape) { | |||
| @@ -758,13 +758,13 @@ TEST_F(UtestFormatTransferNhwc5d, invalid_data_shape) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | |||
| FormatTransferNhwcNc1hwc0 transfer; | |||
| EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||
| ACL_ERROR_GE_SHAPE_INVALID); | |||
| TransArgs args2{ | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_STRING}; | |||
| FormatTransferNhwcNc1hwc0 transfer2; | |||
| EXPECT_EQ(transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID); | |||
| ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -5360,7 +5360,7 @@ TEST_F(UtestFormatTransferNhwcFz, invalid_data_type) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_NZ, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_VARIANT}; | |||
| FormatTransferFractalZ transfer; | |||
| EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID); | |||
| ACL_ERROR_GE_DATATYPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNhwcFz, invalid_data_format) { | |||
| @@ -5369,7 +5369,7 @@ TEST_F(UtestFormatTransferNhwcFz, invalid_data_format) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_CHWN, FORMAT_FRACTAL_NZ, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | |||
| FormatTransferFractalZ transfer; | |||
| EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||
| ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTransferNhwcFz, invalid_data_shape) { | |||
| @@ -5378,19 +5378,19 @@ TEST_F(UtestFormatTransferNhwcFz, invalid_data_shape) { | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | |||
| FormatTransferFractalZ transfer; | |||
| EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||
| ACL_ERROR_GE_SHAPE_INVALID); | |||
| TransArgs args2{ | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | |||
| FormatTransferFractalZ transfer2; | |||
| EXPECT_EQ(transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||
| ACL_ERROR_GE_SHAPE_INVALID); | |||
| TransArgs args3{ | |||
| reinterpret_cast<uint8_t *>(data), FORMAT_NCHW, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | |||
| FormatTransferFractalZ transfer3; | |||
| EXPECT_EQ(transfer3.TransShape(args3.src_format, args3.src_shape, args3.src_data_type, args3.dst_format, args3.dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||
| ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -4659,14 +4659,14 @@ TEST_F(UtestFormatTranspose, invalid_data_shape) { | |||
| FormatTransferTranspose transfer; | |||
| std::vector<int64_t> dst_shape; | |||
| EXPECT_EQ(transfer.TransShape(FORMAT_NCHW, std::vector<int64_t>({}), DT_FLOAT16, FORMAT_HWCN, dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||
| ACL_ERROR_GE_SHAPE_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTranspose, invalid_src_format) { | |||
| FormatTransferTranspose transfer; | |||
| std::vector<int64_t> dst_shape; | |||
| EXPECT_EQ(transfer.TransShape(FORMAT_NC1HWC0, std::vector<int64_t>({1, 3, 8, 8}), DT_FLOAT16, FORMAT_HWCN, dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||
| ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTranspose, invalid_dst_format) { | |||
| @@ -4674,7 +4674,7 @@ TEST_F(UtestFormatTranspose, invalid_dst_format) { | |||
| std::vector<int64_t> dst_shape; | |||
| std::vector<int64_t> src_shape; | |||
| EXPECT_EQ(transfer.TransShape(FORMAT_NCHW, src_shape, DT_FLOAT16, FORMAT_C1HWNC0, dst_shape), | |||
| ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||
| ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -31,6 +31,7 @@ class UtestGeGenerator : public testing::Test { | |||
| void TearDown() {} | |||
| }; | |||
| /* | |||
| TEST_F(UtestGeGenerator, test_build_single_op_offline) { | |||
| GeTensorDesc tensor_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||
| TensorUtils::SetSize(tensor_desc, 512); | |||
| @@ -52,27 +53,22 @@ TEST_F(UtestGeGenerator, test_build_single_op_offline) { | |||
| generator.Initialize({}); | |||
| EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, "offline_"), GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); | |||
| } | |||
| */ | |||
| /* | |||
| TEST_F(UtestGeGenerator, test_build_single_op_online) { | |||
| GeTensorDesc tensor_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||
| TensorUtils::SetSize(tensor_desc, 512); | |||
| GeTensorDesc tensor_desc; | |||
| shared_ptr<OpDesc> op_desc = make_shared<OpDesc>("Add", "add"); | |||
| EXPECT_EQ(op_desc->AddInputDesc(tensor_desc), GRAPH_SUCCESS); | |||
| EXPECT_EQ(op_desc->AddInputDesc(tensor_desc), GRAPH_SUCCESS); | |||
| EXPECT_EQ(op_desc->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); | |||
| op_desc->AddInputDesc(tensor_desc); | |||
| op_desc->AddInputDesc(tensor_desc); | |||
| op_desc->AddOutputDesc(tensor_desc); | |||
| GeTensor tensor(tensor_desc); | |||
| const vector<GeTensor> inputs = { tensor, tensor }; | |||
| const vector<GeTensor> outputs = { tensor }; | |||
| // not Initialize, impl is null. | |||
| GeGenerator generator; | |||
| generator.Initialize({}); | |||
| ModelBufferData model_buffer; | |||
| EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, ENGINE_SYS, model_buffer), GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); | |||
| EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, ENGINE_AIVECTOR, model_buffer), FAILED); | |||
| } | |||
| */ | |||
| } // namespace ge | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <gtest/gtest.h> | |||
| #define protected public | |||
| #define private public | |||
| #include "graph/passes/transpose_transdata_pass.h" | |||
| #include "graph_builder_utils.h" | |||
| #undef private | |||
| #undef protected | |||
| #include "graph/graph.h" | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "common/types.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| namespace ge { | |||
| class UtestGraphPassesTransposeTransdataPass : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| }; | |||
| static ComputeGraphPtr BuildGraphTransposeD() { | |||
| auto builder = ut::GraphBuilder("g1"); | |||
| auto transdata1 = builder.AddNode("transdata1", "TransData", 1, 1, FORMAT_NC1HWC0, DT_FLOAT, std::vector<int64_t>({1, 1, 224, 224, 16})); | |||
| transdata1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NHWC); | |||
| transdata1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 224, 224, 3}))); | |||
| auto transpose1 = builder.AddNode("transpose1", "TransposeD", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({1, 3, 224, 224})); | |||
| transpose1->GetOpDesc()->MutableInputDesc(0)->SetFormat(FORMAT_NHWC); | |||
| transpose1->GetOpDesc()->MutableInputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 224, 224, 3}))); | |||
| auto transdata2 = builder.AddNode("transdata2", "TransData", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({1, 3, 224, 224})); | |||
| transdata2->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NC1HWC0); | |||
| transdata2->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 1, 224, 224, 16}))); | |||
| builder.AddDataEdge(transdata1, 0, transpose1, 0); | |||
| builder.AddDataEdge(transpose1, 0, transdata2, 0); | |||
| return builder.GetGraph(); | |||
| } | |||
| TEST_F(UtestGraphPassesTransposeTransdataPass, test_run) { | |||
| auto compute_graph = BuildGraphTransposeD(); | |||
| compute_graph->SetSessionID(0); | |||
| auto transpose = compute_graph->FindNode("transpose1"); | |||
| TransposeTransDataPass pass; | |||
| EXPECT_EQ(pass.Run(transpose), SUCCESS); | |||
| } | |||
| } // namespace ge | |||