From 792c51fed584c55f585843ed3ca190849eb766ca Mon Sep 17 00:00:00 2001 From: zk <694972388@qq.com> Date: Sat, 6 Mar 2021 16:08:00 +0800 Subject: [PATCH 01/11] Transdata --- .../format_transfer_fractal_z.cc | 149 +++++++++++++++++- 1 file changed, 147 insertions(+), 2 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index 45c6d157..348a84c4 100644 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -29,6 +29,29 @@ namespace ge { namespace formats { namespace { +constexpr int64_t kCubeN = 16; +constexpr int64_t kGroupNum = 1; +constexpr int64_t kDim = 1; + +static int64_t Measure(int64_t x, int64_t y) { + int64_t z = y; + while (x % y != 0) { + z = x % y; + x = y; + y = z; + } + return z; +} + +// least common multiple +static int64_t Lcm(int64_t a, int64_t b) { + if (b == 0) { + return -1; + } + int64_t temp = (a * b) / (Measure(a, b)); + return temp; +} + Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; } /** @@ -61,6 +84,39 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_ return SUCCESS; } +Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector &dst_shape + , args.groups) { + auto c0 = GetCubeSizeByDataType(data_type); + if (c0 < 0) { + return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; + } + int64_t groups = args.groups; + int64_t cin_ori = c; + int64_t cout_ori = n / groups; + int64_t cube_k = args.src_data_type == DT_INT8 ? 32 : 16; + int64_t e_mult = std::min( + Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), + groups); + int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k; + int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN; + int64_t c1_dim = cin_opt / cube_k; + int64_t g_dim = Ceil(groups, e_mult); + auto n1 = Ceil(n, 16); + + dst_shape.clear(); + dst_shape.push_back(g_dim * c1_dim * h * w); + dst_shape.push_back(n1); + dst_shape.push_back(16); + dst_shape.push_back(cube_k); + if (!IsShapeValid(dst_shape)) { + GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", + ShapeToString(dst_shape).c_str()); + return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; + } + return SUCCESS; +} + + Status TransShapeNchwToFz(const std::vector &src_shape, DataType data_type, std::vector &dst_shape) { if (!CheckShapeValid(src_shape, kNchwDimsNum)) { return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; @@ -86,6 +142,21 @@ Status TransShapeHwcnToFz(const std::vector &src_shape, DataType data_t return TransShapeToFz(n, c, h, w, data_type, dst_shape); } +TransShapeHwcnToFzWithGroups(const std::vector &src_shape, DataType data_type, std::vector &dst_shape +, args.groups){ + if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { + return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; + } + + auto h = src_shape.at(kHwcnH); + auto w = src_shape.at(kHwcnW); + auto c = src_shape.at(kHwcnC); + auto n = src_shape.at(kHwcnN); + + return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, args.groups); +} + + Status TransShapeNhwcToFz(const std::vector &src_shape, DataType data_type, std::vector &dst_shape) { if (!CheckShapeValid(src_shape, kNhwcDimsNum)) { return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; @@ -187,6 +258,78 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { return SUCCESS; } +Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result){ + int64_t groups = args.groups; + int64_t h_dim = args.src_shape[kHwcnH]; + int64_t w_dim = args.src_shape[kHwcnW]; + int64_t c_dim = args.src_shape[kHwcnC]; + int64_t n_dim = args.src_shape[kHwcnN]; + int64_t cin_ori = c_dim; + int64_t cout_ori = n_dim / groups; + if (cin_ori == 0 || cout_ori == 0) { + GELOGE(GRAPH_FAILED, + "Cin_ori, cout_ori must not be equal 0, " + "and current cin_ori, cout_ori, groups are %d %d %d", + cin_ori, cout_ori, groups); + return GRAPH_FAILED; + } + const int64_t cube_k = args.src_data_type == DT_INT8 ? 32 : 16; + int64_t e_mult = std::min( + Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), + groups); + int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k; + int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN; + int64_t c1_dim = cin_opt / cube_k; + int64_t g_dim = Ceil(groups, e_mult); + int64_t dim_cin = cin_opt / cube_k; + int64_t data_size = GetCubeSizeByDataType(args.src_data_type); + int64_t size_output_data = + g_dim * kDim * dim_cin * h_dim * w_dim * cout_opt * cube_k * data_size; + GE_CHK_BOOL_EXEC_NOLOG(size_output_data != 0, result.length = static_cast(size_output_data); + return SUCCESS;); + std::shared_ptr dst(new (std::nothrow) uint8_t[size_output_data], std::default_delete()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + dst == nullptr, + GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", + TypeUtils::FormatToSerialString(args.src_format).c_str(), + TypeUtils::FormatToSerialString(args.dst_format).c_str(), size_output_data); + return OUT_OF_MEMORY;); + for (int64_t g = 0; g < groups; g++) { + for (int64_t d = 0; d < kDim; d++) { + for (int64_t c = 0; c < c_dim; c++) { + for (int64_t h = 0; h < h_dim; h++) { + for (int64_t w = 0; w < w_dim; w++) { + for (int64_t n = 0; n < cout_ori; n++) { + int64_t e_val = g % e_mult; + int64_t dst_ci = e_val * cin_ori + c; + int64_t dst_co = e_val * cout_ori + n; + int64_t src_co = g * cout_ori + n; + int64_t tempory = dst_ci % cube_k; + int64_t srx_inx = 0; + int64_t dst_inx = + (g / e_mult) * kDim * c1_dim * h_dim * w_dim * cout_opt * + cube_k + + d * c1_dim * h_dim * w_dim * cout_opt * cube_k + + (dst_ci / cube_k) * h_dim * w_dim * cout_opt * cube_k + + h * w_dim * cout_opt * cube_k + w * cout_opt * cube_k + + dst_co * cube_k + tempory; + srx_inx = d * h_dim * w_dim * c_dim * n_dim + + h * w_dim * c_dim * n_dim + w * c_dim * n_dim + + c * n_dim + src_co; + char *dst_data = reinterpret_cast(dst.get() + dst_inx * data_size); + const char *src_data = reinterpret_cast(args.data + src_idx * data_size); + for (int64_t index = 0; index < data_size; index++) { + *dst_data++ = *src_data++; + } + } + } + } + } + } + } + result.data = dst; + result.length = static_cast(size_output_data); +} Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { int64_t h = args.src_shape[kHwcnH]; int64_t w = args.src_shape[kHwcnW]; @@ -357,7 +500,8 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r } if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) { - return TransFormatHwcnToFz(args, result); + return args.groups == 0 ? TransFormatHwcnToFz(args, result) : + TransFormatHwcnToFzWithGroups(args, result); } if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { @@ -377,7 +521,8 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector Date: Mon, 8 Mar 2021 09:51:51 +0800 Subject: [PATCH 02/11] Transdata --- .../common/format_transfer_fracz_hwcn_unittest.cc | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/ut/ge/common/format_transfer_fracz_hwcn_unittest.cc b/tests/ut/ge/common/format_transfer_fracz_hwcn_unittest.cc index 25caa741..6217c813 100644 --- a/tests/ut/ge/common/format_transfer_fracz_hwcn_unittest.cc +++ b/tests/ut/ge/common/format_transfer_fracz_hwcn_unittest.cc @@ -6897,5 +6897,20 @@ TEST_F(UtestFormatTransferFracZHwcn, fp16_1c_1n_pad_cn) { EXPECT_EQ((reinterpret_cast(result.data.get()))[i], ret[i]); } } +TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_fp16_success_with_groups) { + uint16_t data_4d[2 * 2 * 2 * 2] = {1}; + uint16_t data[1 * 1 * 1 * 2 *2 * 16 *16] = {1024}; + int64_t groups = 1; + TransArgs args{ + reinterpret_cast(data_4d), FORMAT_FRACTAL_Z, FORMAT_HWCN, {1, 1, 1, 2, 2, 16, 6}, {2, 2, 2, 2}, DT_FLOAT16 ,groups}; + TransResult result; + + FormatTransferFracZHwcn transfer; + EXPECT_EQ(transfer.TransFormat(args, result), SUCCESS); + EXPECT_EQ(result.length, sizeof(data)); + for (int i = 0; i < sizeof(data) / sizeof(data[0]); ++i) { + EXPECT_EQ((reinterpret_cast(result.data.get()))[i], data[i]); + } +} } // namespace formats } // namespace ge From 3925716297387188dd5c71e3ba11f68de3864c28 Mon Sep 17 00:00:00 2001 From: zk <694972388@qq.com> Date: Mon, 8 Mar 2021 11:27:18 +0800 Subject: [PATCH 03/11] Transdata --- .../format_transfer_fractal_z.cc | 18 +++++++++++++++++- .../format_transfer_fractal_z.h | 1 + ge/host_kernels/transdata_kernel.cc | 14 ++++++++++++-- 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index 348a84c4..dcb2cf70 100644 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -317,7 +317,7 @@ Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result) h * w_dim * c_dim * n_dim + w * c_dim * n_dim + c * n_dim + src_co; char *dst_data = reinterpret_cast(dst.get() + dst_inx * data_size); - const char *src_data = reinterpret_cast(args.data + src_idx * data_size); + const char *src_data = reinterpret_cast(args.data + srx_inx * data_size); for (int64_t index = 0; index < data_size; index++) { *dst_data++ = *src_data++; } @@ -481,6 +481,22 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { } } // namespace +Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &result, + int64_t groups) { + GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s", + TypeUtils::FormatToSerialString(args.src_format).c_str(), + TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), + TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str()); + std::vector expect_shape; + auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape); + if (ret != SUCCESS) { + return ret; + } + if (!IsTransShapeDstCorrect(args, expect_shape)) { + return PARAM_INVALID; + } + return TransFormatHwcnToFzWithGroups(args, result); +} Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &result) { GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s", TypeUtils::FormatToSerialString(args.src_format).c_str(), diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.h b/ge/common/formats/format_transfers/format_transfer_fractal_z.h index d640eb60..3d286135 100755 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.h +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.h @@ -26,6 +26,7 @@ namespace formats { class FormatTransferFractalZ : public FormatTransfer { public: Status TransFormat(const TransArgs &args, TransResult &result) override; + Status TransFormat(const TransArgs &args, TransResult &result, int64_t groups) override; Status TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, std::vector &dst_shape) override; }; diff --git a/ge/host_kernels/transdata_kernel.cc b/ge/host_kernels/transdata_kernel.cc index 2b16b075..8d9785cf 100644 --- a/ge/host_kernels/transdata_kernel.cc +++ b/ge/host_kernels/transdata_kernel.cc @@ -82,16 +82,17 @@ Status TransdataKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetShape().GetDims(); const auto &data_format = op_desc->GetFormat(); const auto &data_type = op_desc->GetDataType(); + const in64_t groups = op_desc_ptr->GetAttr("groups", groups) GELOGD( "current node %s, format %s, input shape %s, data type %s, weight format %s, shape %s, data type %s. " - "output format %s, shape %s, data type %s", + "output format %s, shape %s, data type %s, groups %d", op_desc_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(src_format).c_str(), formats::ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(src_data_type).c_str(), TypeUtils::FormatToSerialString(const_weight_ptr->GetTensorDesc().GetFormat()).c_str(), formats::ShapeToString(const_weight_ptr->GetTensorDesc().GetShape()).c_str(), TypeUtils::DataTypeToSerialString(const_weight_ptr->GetTensorDesc().GetDataType()).c_str(), TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(data_shape).c_str(), - TypeUtils::DataTypeToSerialString(data_type).c_str()); + TypeUtils::DataTypeToSerialString(data_type).c_str(), groups); const uint8_t *src_data = const_weight_ptr->GetData().data(); const formats::TransArgs trans_args{src_data, src_format, data_format, src_shape, data_shape, src_data_type}; @@ -113,6 +114,15 @@ Status TransdataKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector Date: Tue, 9 Mar 2021 10:09:30 +0800 Subject: [PATCH 04/11] Transdata --- .../format_transfer_fractal_z.cc | 28 ++++++++++++------- .../format_transfer_fractal_z.h | 2 ++ ge/host_kernels/transdata_kernel.cc | 2 +- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index dcb2cf70..b9703d09 100644 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -143,7 +143,7 @@ Status TransShapeHwcnToFz(const std::vector &src_shape, DataType data_t } TransShapeHwcnToFzWithGroups(const std::vector &src_shape, DataType data_type, std::vector &dst_shape -, args.groups){ +, groups){ if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; } @@ -153,7 +153,7 @@ TransShapeHwcnToFzWithGroups(const std::vector &src_shape, DataType dat auto c = src_shape.at(kHwcnC); auto n = src_shape.at(kHwcnN); - return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, args.groups); + return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, groups); } @@ -258,8 +258,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { return SUCCESS; } -Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result){ - int64_t groups = args.groups; +Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, groups){ int64_t h_dim = args.src_shape[kHwcnH]; int64_t w_dim = args.src_shape[kHwcnW]; int64_t c_dim = args.src_shape[kHwcnC]; @@ -488,14 +487,14 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str()); std::vector expect_shape; - auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape); + auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape, groups); if (ret != SUCCESS) { return ret; } if (!IsTransShapeDstCorrect(args, expect_shape)) { return PARAM_INVALID; } - return TransFormatHwcnToFzWithGroups(args, result); + return TransFormatHwcnToFzWithGroups(args, result, groups); } Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &result) { GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s", @@ -516,8 +515,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r } if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) { - return args.groups == 0 ? TransFormatHwcnToFz(args, result) : - TransFormatHwcnToFzWithGroups(args, result); + return TransFormatHwcnToFz(args, result); } if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { @@ -527,6 +525,17 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r return UNSUPPORTED; } +Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector &src_shape, DataType data_type, + Format dst_format, std::vector &dst_shape, groups){ +if (CheckDataTypeSupport(data_type) != SUCCESS) { + return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; + } + + if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) { + return TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, groups); + } +} + Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, std::vector &dst_shape) { if (CheckDataTypeSupport(data_type) != SUCCESS) { @@ -537,8 +546,7 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, std::vector &dst_shape) override; + Status TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, + std::vector &dst_shape, int64_t groups) override; }; } // namespace formats } // namespace ge diff --git a/ge/host_kernels/transdata_kernel.cc b/ge/host_kernels/transdata_kernel.cc index 8d9785cf..fa3c9944 100644 --- a/ge/host_kernels/transdata_kernel.cc +++ b/ge/host_kernels/transdata_kernel.cc @@ -114,7 +114,7 @@ Status TransdataKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector Date: Tue, 9 Mar 2021 11:01:15 +0800 Subject: [PATCH 05/11] Tranadata --- ge/common/formats/format_transfers/format_transfer_fractal_z.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index b9703d09..8b733d53 100644 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -90,7 +90,6 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data if (c0 < 0) { return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; } - int64_t groups = args.groups; int64_t cin_ori = c; int64_t cout_ori = n / groups; int64_t cube_k = args.src_data_type == DT_INT8 ? 32 : 16; From 004daded46a39cc934b117e347aa9e09bf31effc Mon Sep 17 00:00:00 2001 From: zk <694972388@qq.com> Date: Tue, 9 Mar 2021 11:20:39 +0800 Subject: [PATCH 06/11] Transdata --- .../format_transfers/format_transfer_fractal_z.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index 8b733d53..15312fad 100644 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -85,7 +85,7 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_ } Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector &dst_shape - , args.groups) { + , int64_t groups) { auto c0 = GetCubeSizeByDataType(data_type); if (c0 < 0) { return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; @@ -100,7 +100,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN; int64_t c1_dim = cin_opt / cube_k; int64_t g_dim = Ceil(groups, e_mult); - auto n1 = Ceil(n, 16); + auto n1 = Ceil(n, kCubeN); dst_shape.clear(); dst_shape.push_back(g_dim * c1_dim * h * w); @@ -142,7 +142,7 @@ Status TransShapeHwcnToFz(const std::vector &src_shape, DataType data_t } TransShapeHwcnToFzWithGroups(const std::vector &src_shape, DataType data_type, std::vector &dst_shape -, groups){ +, int64_t groups){ if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; } @@ -257,7 +257,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { return SUCCESS; } -Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, groups){ +Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result,int64_t groups){ int64_t h_dim = args.src_shape[kHwcnH]; int64_t w_dim = args.src_shape[kHwcnW]; int64_t c_dim = args.src_shape[kHwcnC]; @@ -525,7 +525,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r } Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector &src_shape, DataType data_type, - Format dst_format, std::vector &dst_shape, groups){ + Format dst_format, std::vector &dst_shape,int64_t groups){ if (CheckDataTypeSupport(data_type) != SUCCESS) { return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; } From 70b4500345b68e40c22df6c409821738bfac606d Mon Sep 17 00:00:00 2001 From: zk <694972388@qq.com> Date: Tue, 9 Mar 2021 14:57:11 +0800 Subject: [PATCH 07/11] Transdata --- .../formats/format_transfers/format_transfer_fractal_z.cc | 5 ++--- .../formats/format_transfers/format_transfer_fractal_z.h | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index 15312fad..6d055641 100644 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -92,7 +92,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data } int64_t cin_ori = c; int64_t cout_ori = n / groups; - int64_t cube_k = args.src_data_type == DT_INT8 ? 32 : 16; + int64_t cube_k = data_type == DT_INT8 ? 32 : 16; int64_t e_mult = std::min( Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), groups); @@ -100,8 +100,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN; int64_t c1_dim = cin_opt / cube_k; int64_t g_dim = Ceil(groups, e_mult); - auto n1 = Ceil(n, kCubeN); - + auto n1 = Ceil(n , kCubeN); dst_shape.clear(); dst_shape.push_back(g_dim * c1_dim * h * w); dst_shape.push_back(n1); diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.h b/ge/common/formats/format_transfers/format_transfer_fractal_z.h index f59378c9..b962eb69 100755 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.h +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.h @@ -26,11 +26,11 @@ namespace formats { class FormatTransferFractalZ : public FormatTransfer { public: Status TransFormat(const TransArgs &args, TransResult &result) override; - Status TransFormat(const TransArgs &args, TransResult &result, int64_t groups) override; + Status TransFormat(const TransArgs &args, TransResult &result, int64_t groups); Status TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, std::vector &dst_shape) override; Status TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, - std::vector &dst_shape, int64_t groups) override; + std::vector &dst_shape, int64_t groups); }; } // namespace formats } // namespace ge From 8fce848f87d489a36ea054862fe1d1dd357b00e8 Mon Sep 17 00:00:00 2001 From: zk <694972388@qq.com> Date: Tue, 9 Mar 2021 15:11:53 +0800 Subject: [PATCH 08/11] Transdata --- .../format_transfers/format_transfer_fractal_z.cc | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index 6d055641..71571194 100644 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -114,7 +114,6 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data return SUCCESS; } - Status TransShapeNchwToFz(const std::vector &src_shape, DataType data_type, std::vector &dst_shape) { if (!CheckShapeValid(src_shape, kNchwDimsNum)) { return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; @@ -140,7 +139,7 @@ Status TransShapeHwcnToFz(const std::vector &src_shape, DataType data_t return TransShapeToFz(n, c, h, w, data_type, dst_shape); } -TransShapeHwcnToFzWithGroups(const std::vector &src_shape, DataType data_type, std::vector &dst_shape +Status TransShapeHwcnToFzWithGroups(const std::vector &src_shape, DataType data_type, std::vector &dst_shape , int64_t groups){ if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; @@ -326,6 +325,7 @@ Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, } result.data = dst; result.length = static_cast(size_output_data); + return SUCCESS; } Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { int64_t h = args.src_shape[kHwcnH]; @@ -528,12 +528,8 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, std::vector &dst_shape) { if (CheckDataTypeSupport(data_type) != SUCCESS) { From 0c993dd696a98c5b580ad7932fe917eea46e20b8 Mon Sep 17 00:00:00 2001 From: zk <694972388@qq.com> Date: Tue, 9 Mar 2021 15:22:16 +0800 Subject: [PATCH 09/11] Transdata --- .../format_transfer_fractal_z.cc | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index 71571194..3c11a6f5 100644 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -494,6 +494,15 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r } return TransFormatHwcnToFzWithGroups(args, result, groups); } + +Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector &src_shape, DataType data_type, + Format dst_format, std::vector &dst_shape, int64_t groups){ +if (CheckDataTypeSupport(data_type) != SUCCESS) { + return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; + } + return TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, groups); +} + Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &result) { GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s", TypeUtils::FormatToSerialString(args.src_format).c_str(), @@ -523,13 +532,6 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r return UNSUPPORTED; } -Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector &src_shape, DataType data_type, - Format dst_format, std::vector &dst_shape,int64_t groups){ -if (CheckDataTypeSupport(data_type) != SUCCESS) { - return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; - } - return TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, groups); -} Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, std::vector &dst_shape) { if (CheckDataTypeSupport(data_type) != SUCCESS) { From 4caf87b0055da8254b28d224a998bb4df4a65f8b Mon Sep 17 00:00:00 2001 From: zk <694972388@qq.com> Date: Tue, 9 Mar 2021 15:24:14 +0800 Subject: [PATCH 10/11] transdata --- ge/common/formats/format_transfers/format_transfer_fractal_z.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index 3c11a6f5..bc71d27d 100644 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -97,7 +97,6 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), groups); int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k; - int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN; int64_t c1_dim = cin_opt / cube_k; int64_t g_dim = Ceil(groups, e_mult); auto n1 = Ceil(n , kCubeN); @@ -528,7 +527,6 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { return TransFormatFromNchwToFz(args, result); } - return UNSUPPORTED; } From 27b678d81efb993b1dbcb783343ec275bcbeeac3 Mon Sep 17 00:00:00 2001 From: zk <694972388@qq.com> Date: Tue, 9 Mar 2021 15:42:14 +0800 Subject: [PATCH 11/11] trandata --- .../format_transfer_fractal_z.cc | 61 +++++++++---------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index bc71d27d..3beb6e46 100644 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -65,7 +65,7 @@ Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector &dst_shape) { auto c0 = GetCubeSizeByDataType(data_type); if (c0 < 0) { - return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; + return ACL_ERROR_GE_DATATYPE_INVALID; } auto c1 = Ceil(c, c0); @@ -77,9 +77,9 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_ dst_shape.push_back(kNiSize); dst_shape.push_back(c0); if (!IsShapeValid(dst_shape)) { - GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", + GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); - return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; + return ACL_ERROR_GE_SHAPE_INVALID; } return SUCCESS; } @@ -88,7 +88,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data , int64_t groups) { auto c0 = GetCubeSizeByDataType(data_type); if (c0 < 0) { - return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; + return ACL_ERROR_GE_DATATYPE_INVALID; } int64_t cin_ori = c; int64_t cout_ori = n / groups; @@ -106,9 +106,9 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data dst_shape.push_back(16); dst_shape.push_back(cube_k); if (!IsShapeValid(dst_shape)) { - GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", + GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); - return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; + return ACL_ERROR_GE_SHAPE_INVALID; } return SUCCESS; } @@ -127,21 +127,20 @@ Status TransShapeNchwToFz(const std::vector &src_shape, DataType data_t Status TransShapeHwcnToFz(const std::vector &src_shape, DataType data_type, std::vector &dst_shape) { if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { - return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; + return ACL_ERROR_GE_SHAPE_INVALID; } auto h = src_shape.at(kHwcnH); auto w = src_shape.at(kHwcnW); auto c = src_shape.at(kHwcnC); auto n = src_shape.at(kHwcnN); - return TransShapeToFz(n, c, h, w, data_type, dst_shape); } Status TransShapeHwcnToFzWithGroups(const std::vector &src_shape, DataType data_type, std::vector &dst_shape , int64_t groups){ if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { - return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; + return ACL_ERROR_GE_SHAPE_INVALID; } auto h = src_shape.at(kHwcnH); @@ -155,7 +154,7 @@ Status TransShapeHwcnToFzWithGroups(const std::vector &src_shape, DataT Status TransShapeNhwcToFz(const std::vector &src_shape, DataType data_type, std::vector &dst_shape) { if (!CheckShapeValid(src_shape, kNhwcDimsNum)) { - return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; + return ACL_ERROR_GE_SHAPE_INVALID; } auto n = src_shape.at(kNhwcN); @@ -194,10 +193,10 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( dst == nullptr, - GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", + GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); - return OUT_OF_MEMORY;); + return ACL_ERROR_GE_MEMORY_ALLOCATION;); for (int64_t vfi = 0; vfi < vf_cnt; vfi++) { // vertical fractal matrix base index @@ -230,8 +229,8 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { if (protected_size < size) { std::string error = "Failed to operate the dst memory, protected_size is " + FmtToStr(protected_size) + " and size is " + FmtToStr(size); - GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error.c_str()); - return INTERNAL_ERROR; + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str()); + return ACL_ERROR_GE_PARAM_INVALID; } char *dst_data = reinterpret_cast(dst.get() + offset); const char *src_data = reinterpret_cast(args.data + src_offset * size); @@ -240,9 +239,9 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { } } if (ret != EOK) { - GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset, + GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset, ret, need_pad_zero); - return INTERNAL_ERROR; + return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; } } } @@ -285,10 +284,10 @@ Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, std::shared_ptr dst(new (std::nothrow) uint8_t[size_output_data], std::default_delete()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( dst == nullptr, - GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", + GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.dst_format).c_str(), size_output_data); - return OUT_OF_MEMORY;); + return ACL_ERROR_GE_MEMORY_ALLOCATION;); for (int64_t g = 0; g < groups; g++) { for (int64_t d = 0; d < kDim; d++) { for (int64_t c = 0; c < c_dim; c++) { @@ -352,10 +351,10 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( dst == nullptr, - GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", + GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); - return OUT_OF_MEMORY;); + return ACL_ERROR_GE_MEMORY_ALLOCATION;); for (int64_t c1i = 0; c1i < c1; c1i++) { for (int64_t hi = 0; hi < h; hi++) { @@ -386,9 +385,9 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { } } if (ret != EOK) { - GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", + GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", dst_offset, ret, pad_zero); - return INTERNAL_ERROR; + return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; } } } @@ -427,10 +426,10 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( dst == nullptr, - GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", + GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); - return OUT_OF_MEMORY;); + return ACL_ERROR_GE_MEMORY_ALLOCATION;); for (int64_t c1i = 0; c1i < c1; c1i++) { for (int64_t hi = 0; hi < h; hi++) { @@ -449,9 +448,9 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { static_cast(data_size)); } else { if (protected_size < data_size) { - GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to operate the dst memory, protected_size is %ld and size is %ld", protected_size, data_size); - return INTERNAL_ERROR; + return ACL_ERROR_GE_PARAM_INVALID; } int64_t src_idx = n1n0i * hwc + hi * wc + wi * c + (c1i * c0 + c0i); char *dst_data = reinterpret_cast(dst.get() + dst_offset); @@ -461,9 +460,9 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { } } if (ret != EOK) { - GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", + GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", dst_offset, ret, pad_zero); - return INTERNAL_ERROR; + return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; } } } @@ -489,7 +488,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r return ret; } if (!IsTransShapeDstCorrect(args, expect_shape)) { - return PARAM_INVALID; + return ACL_ERROR_GE_SHAPE_INVALID; } return TransFormatHwcnToFzWithGroups(args, result, groups); } @@ -513,7 +512,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r return ret; } if (!IsTransShapeDstCorrect(args, expect_shape)) { - return PARAM_INVALID; + return ACL_ERROR_GE_SHAPE_INVALID; } if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { @@ -527,7 +526,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { return TransFormatFromNchwToFz(args, result); } - return UNSUPPORTED; + return ACL_ERROR_GE_FORMAT_INVALID; } Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector &src_shape, DataType data_type,