|
|
|
@@ -143,7 +143,7 @@ Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_t |
|
|
|
} |
|
|
|
|
|
|
|
TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape |
|
|
|
, args.groups){ |
|
|
|
, groups){ |
|
|
|
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { |
|
|
|
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; |
|
|
|
} |
|
|
|
@@ -153,7 +153,7 @@ TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType dat |
|
|
|
auto c = src_shape.at(kHwcnC); |
|
|
|
auto n = src_shape.at(kHwcnN); |
|
|
|
|
|
|
|
return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, args.groups); |
|
|
|
return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, groups); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@@ -258,8 +258,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result){ |
|
|
|
int64_t groups = args.groups; |
|
|
|
Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, groups){ |
|
|
|
int64_t h_dim = args.src_shape[kHwcnH]; |
|
|
|
int64_t w_dim = args.src_shape[kHwcnW]; |
|
|
|
int64_t c_dim = args.src_shape[kHwcnC]; |
|
|
|
@@ -488,14 +487,14 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r |
|
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), |
|
|
|
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str()); |
|
|
|
std::vector<int64_t> expect_shape; |
|
|
|
auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape); |
|
|
|
auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape, groups); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
return ret; |
|
|
|
} |
|
|
|
if (!IsTransShapeDstCorrect(args, expect_shape)) { |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
return TransFormatHwcnToFzWithGroups(args, result); |
|
|
|
return TransFormatHwcnToFzWithGroups(args, result, groups); |
|
|
|
} |
|
|
|
Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &result) { |
|
|
|
GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s", |
|
|
|
@@ -516,8 +515,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r |
|
|
|
} |
|
|
|
|
|
|
|
if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) { |
|
|
|
return args.groups == 0 ? TransFormatHwcnToFz(args, result) : |
|
|
|
TransFormatHwcnToFzWithGroups(args, result); |
|
|
|
return TransFormatHwcnToFz(args, result); |
|
|
|
} |
|
|
|
|
|
|
|
if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { |
|
|
|
@@ -527,6 +525,17 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r |
|
|
|
return UNSUPPORTED; |
|
|
|
} |
|
|
|
|
|
|
|
Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, |
|
|
|
Format dst_format, std::vector<int64_t> &dst_shape, groups){ |
|
|
|
if (CheckDataTypeSupport(data_type) != SUCCESS) { |
|
|
|
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; |
|
|
|
} |
|
|
|
|
|
|
|
if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) { |
|
|
|
return TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, groups); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, |
|
|
|
Format dst_format, std::vector<int64_t> &dst_shape) { |
|
|
|
if (CheckDataTypeSupport(data_type) != SUCCESS) { |
|
|
|
@@ -537,8 +546,7 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i |
|
|
|
return TransShapeNhwcToFz(src_shape, data_type, dst_shape); |
|
|
|
} |
|
|
|
if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) { |
|
|
|
return args.groups == 0 ? TransShapeHwcnToFz(src_shape, data_type, dst_shape) |
|
|
|
: TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, args.groups); |
|
|
|
return TransShapeHwcnToFz(src_shape, data_type, dst_shape); |
|
|
|
} |
|
|
|
if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) { |
|
|
|
return TransShapeNchwToFz(src_shape, data_type, dst_shape); |
|
|
|
|