From dbbca6d8dde9f011b86cc83e3f13ef8bc238c5ac Mon Sep 17 00:00:00 2001 From: zk <694972388@qq.com> Date: Tue, 9 Mar 2021 10:09:30 +0800 Subject: [PATCH] Transdata --- .../format_transfer_fractal_z.cc | 28 ++++++++++++------- .../format_transfer_fractal_z.h | 2 ++ ge/host_kernels/transdata_kernel.cc | 2 +- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index dcb2cf70..b9703d09 100644 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -143,7 +143,7 @@ Status TransShapeHwcnToFz(const std::vector &src_shape, DataType data_t } TransShapeHwcnToFzWithGroups(const std::vector &src_shape, DataType data_type, std::vector &dst_shape -, args.groups){ +, groups){ if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; } @@ -153,7 +153,7 @@ TransShapeHwcnToFzWithGroups(const std::vector &src_shape, DataType dat auto c = src_shape.at(kHwcnC); auto n = src_shape.at(kHwcnN); - return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, args.groups); + return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, groups); } @@ -258,8 +258,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { return SUCCESS; } -Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result){ - int64_t groups = args.groups; +Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, groups){ int64_t h_dim = args.src_shape[kHwcnH]; int64_t w_dim = args.src_shape[kHwcnW]; int64_t c_dim = args.src_shape[kHwcnC]; @@ -488,14 +487,14 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str()); std::vector expect_shape; - auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape); + auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape, groups); if (ret != SUCCESS) { return ret; } if (!IsTransShapeDstCorrect(args, expect_shape)) { return PARAM_INVALID; } - return TransFormatHwcnToFzWithGroups(args, result); + return TransFormatHwcnToFzWithGroups(args, result, groups); } Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &result) { GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s", @@ -516,8 +515,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r } if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) { - return args.groups == 0 ? TransFormatHwcnToFz(args, result) : - TransFormatHwcnToFzWithGroups(args, result); + return TransFormatHwcnToFz(args, result); } if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { @@ -527,6 +525,17 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r return UNSUPPORTED; } +Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector &src_shape, DataType data_type, + Format dst_format, std::vector &dst_shape, groups){ +if (CheckDataTypeSupport(data_type) != SUCCESS) { + return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; + } + + if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) { + return TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, groups); + } +} + Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, std::vector &dst_shape) { if (CheckDataTypeSupport(data_type) != SUCCESS) { @@ -537,8 +546,7 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, std::vector &dst_shape) override; + Status TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, + std::vector &dst_shape, int64_t groups) override; }; } // namespace formats } // namespace ge diff --git a/ge/host_kernels/transdata_kernel.cc b/ge/host_kernels/transdata_kernel.cc index 8d9785cf..fa3c9944 100644 --- a/ge/host_kernels/transdata_kernel.cc +++ b/ge/host_kernels/transdata_kernel.cc @@ -114,7 +114,7 @@ Status TransdataKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector