| @@ -29,6 +29,42 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace formats { | namespace formats { | ||||
| namespace { | namespace { | ||||
| constexpr int64_t kCubeN = 16; | |||||
| constexpr int64_t kGroupNum = 1; | |||||
| constexpr int64_t kDim = 1; | |||||
| static int64_t Measure(int64_t x, int64_t y) { | |||||
| int64_t z = y; | |||||
| while (x % y != 0) { | |||||
| z = x % y; | |||||
| x = y; | |||||
| y = z; | |||||
| } | |||||
| return z; | |||||
| } | |||||
| // least common multiple | |||||
| static int64_t Lcm(int64_t a, int64_t b) { | |||||
| if (b == 0) { | |||||
| return -1; | |||||
| } | |||||
| int64_t temp = (a * b) / (Measure(a, b)); | |||||
| return temp; | |||||
| } | |||||
| // get the result of two number divisor and let result round up | |||||
| static int64_t DivCeil(int64_t a, int64_t b) { | |||||
| if (b == 0) { | |||||
| return -1; | |||||
| } else { | |||||
| int64_t ret = a / b; | |||||
| if ((a % b) != 0) { | |||||
| ret++; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; } | Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; } | ||||
| /** | /** | ||||
| @@ -61,6 +97,35 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_ | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape | |||||
| , int64_t groups) { | |||||
| auto c0 = GetCubeSizeByDataType(data_type); | |||||
| if (c0 < 0) { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| int64_t cin_ori = c; | |||||
| int64_t cout_ori = n / groups; | |||||
| int64_t cube_k = data_type == DT_INT8 ? 32 : 16; | |||||
| int64_t e_mult = std::min( | |||||
| Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), | |||||
| groups); | |||||
| int64_t cin_opt = DivCeil(e_mult * cin_ori, cube_k) * cube_k; | |||||
| int64_t c1_dim = cin_opt / cube_k; | |||||
| int64_t g_dim = DivCeil(groups, e_mult); | |||||
| auto n1 = DivCeil(cout_ori * e_mult, kCubeN); | |||||
| dst_shape.clear(); | |||||
| dst_shape.push_back(g_dim * c1_dim * h * w); | |||||
| dst_shape.push_back(n1); | |||||
| dst_shape.push_back(16); | |||||
| dst_shape.push_back(cube_k); | |||||
| if (!IsShapeValid(dst_shape)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TransShapeNchwToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { | Status TransShapeNchwToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { | ||||
| if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | ||||
| return ACL_ERROR_GE_SHAPE_INVALID; | return ACL_ERROR_GE_SHAPE_INVALID; | ||||
| @@ -82,10 +147,24 @@ Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_t | |||||
| auto w = src_shape.at(kHwcnW); | auto w = src_shape.at(kHwcnW); | ||||
| auto c = src_shape.at(kHwcnC); | auto c = src_shape.at(kHwcnC); | ||||
| auto n = src_shape.at(kHwcnN); | auto n = src_shape.at(kHwcnN); | ||||
| return TransShapeToFz(n, c, h, w, data_type, dst_shape); | return TransShapeToFz(n, c, h, w, data_type, dst_shape); | ||||
| } | } | ||||
| Status TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape | |||||
| , int64_t groups){ | |||||
| if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| auto h = src_shape.at(kHwcnH); | |||||
| auto w = src_shape.at(kHwcnW); | |||||
| auto c = src_shape.at(kHwcnC); | |||||
| auto n = src_shape.at(kHwcnN); | |||||
| return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, groups); | |||||
| } | |||||
| Status TransShapeNhwcToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { | Status TransShapeNhwcToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { | ||||
| if (!CheckShapeValid(src_shape, kNhwcDimsNum)) { | if (!CheckShapeValid(src_shape, kNhwcDimsNum)) { | ||||
| return ACL_ERROR_GE_SHAPE_INVALID; | return ACL_ERROR_GE_SHAPE_INVALID; | ||||
| @@ -187,6 +266,78 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result,int64_t groups){ | |||||
| int64_t h_dim = args.src_shape[kHwcnH]; | |||||
| int64_t w_dim = args.src_shape[kHwcnW]; | |||||
| int64_t c_dim = args.src_shape[kHwcnC]; | |||||
| int64_t n_dim = args.src_shape[kHwcnN]; | |||||
| int64_t cin_ori = c_dim; | |||||
| int64_t cout_ori = n_dim / groups; | |||||
| if (cin_ori == 0 || cout_ori == 0) { | |||||
| GELOGE(GRAPH_FAILED, | |||||
| "Cin_ori, cout_ori must not be equal 0, " | |||||
| "and current cin_ori, cout_ori, groups are %d %d %d", | |||||
| cin_ori, cout_ori, groups); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| const int64_t cube_k = args.src_data_type == DT_INT8 ? 32 : 16; | |||||
| int64_t e_mult = std::min( | |||||
| Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), | |||||
| groups); | |||||
| int64_t cin_opt = DivCeil(e_mult * cin_ori, cube_k) * cube_k; | |||||
| int64_t cout_opt = DivCeil(e_mult * cout_ori, kCubeN) * kCubeN; | |||||
| int64_t c1_dim = cin_opt / cube_k; | |||||
| int64_t g_dim = DivCeil(groups, e_mult); | |||||
| int64_t dim_cin = cin_opt / cube_k; | |||||
| int64_t data_size = GetCubeSizeByDataType(args.src_data_type); | |||||
| int64_t size_output_data = | |||||
| g_dim * kDim * dim_cin * h_dim * w_dim * cout_opt * cube_k * data_size; | |||||
| GE_CHK_BOOL_EXEC_NOLOG(size_output_data != 0, result.length = static_cast<size_t>(size_output_data); | |||||
| return SUCCESS;); | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[size_output_data], std::default_delete<uint8_t[]>()); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| dst == nullptr, | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), size_output_data); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION;); | |||||
| for (int64_t g = 0; g < groups; g++) { | |||||
| for (int64_t d = 0; d < kDim; d++) { | |||||
| for (int64_t c = 0; c < c_dim; c++) { | |||||
| for (int64_t h = 0; h < h_dim; h++) { | |||||
| for (int64_t w = 0; w < w_dim; w++) { | |||||
| for (int64_t n = 0; n < cout_ori; n++) { | |||||
| int64_t e_val = g % e_mult; | |||||
| int64_t dst_ci = e_val * cin_ori + c; | |||||
| int64_t dst_co = e_val * cout_ori + n; | |||||
| int64_t src_co = g * cout_ori + n; | |||||
| int64_t tempory = dst_ci % cube_k; | |||||
| int64_t srx_inx = 0; | |||||
| int64_t dst_inx = | |||||
| (g / e_mult) * kDim * c1_dim * h_dim * w_dim * cout_opt * | |||||
| cube_k + | |||||
| d * c1_dim * h_dim * w_dim * cout_opt * cube_k + | |||||
| (dst_ci / cube_k) * h_dim * w_dim * cout_opt * cube_k + | |||||
| h * w_dim * cout_opt * cube_k + w * cout_opt * cube_k + | |||||
| dst_co * cube_k + tempory; | |||||
| srx_inx = d * h_dim * w_dim * c_dim * n_dim + | |||||
| h * w_dim * c_dim * n_dim + w * c_dim * n_dim + | |||||
| c * n_dim + src_co; | |||||
| char *dst_data = reinterpret_cast<char *>(dst.get() + dst_inx * data_size); | |||||
| const char *src_data = reinterpret_cast<const char *>(args.data + srx_inx * data_size); | |||||
| for (int64_t index = 0; index < data_size; index++) { | |||||
| *dst_data++ = *src_data++; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(size_output_data); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | ||||
| int64_t h = args.src_shape[kHwcnH]; | int64_t h = args.src_shape[kHwcnH]; | ||||
| int64_t w = args.src_shape[kHwcnW]; | int64_t w = args.src_shape[kHwcnW]; | ||||
| @@ -355,15 +506,16 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r | |||||
| if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { | if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { | ||||
| return TransFormatNhwcToFz(args, result); | return TransFormatNhwcToFz(args, result); | ||||
| } | } | ||||
| if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) { | |||||
| if ((args.src_format == FORMAT_HWCN) && (GetPrimaryFormat(args.dst_format) == FORMAT_FRACTAL_Z)) { | |||||
| if (GetSubFormat(args.dst_format) >= 1) { | |||||
| return TransFormatHwcnToFzWithGroups(args, result, GetSubFormat(args.dst_format)); | |||||
| } | |||||
| return TransFormatHwcnToFz(args, result); | return TransFormatHwcnToFz(args, result); | ||||
| } | } | ||||
| if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { | if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { | ||||
| return TransFormatFromNchwToFz(args, result); | return TransFormatFromNchwToFz(args, result); | ||||
| } | } | ||||
| return ACL_ERROR_GE_FORMAT_INVALID; | return ACL_ERROR_GE_FORMAT_INVALID; | ||||
| } | } | ||||
| @@ -376,7 +528,10 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i | |||||
| if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) { | if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) { | ||||
| return TransShapeNhwcToFz(src_shape, data_type, dst_shape); | return TransShapeNhwcToFz(src_shape, data_type, dst_shape); | ||||
| } | } | ||||
| if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) { | |||||
| if ((src_format == FORMAT_HWCN) && (GetPrimaryFormat(dst_format) == FORMAT_FRACTAL_Z)) { | |||||
| if (GetSubFormat(dst_format) >= 1) { | |||||
| return TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, GetSubFormat(dst_format)); | |||||
| } | |||||
| return TransShapeHwcnToFz(src_shape, data_type, dst_shape); | return TransShapeHwcnToFz(src_shape, data_type, dst_shape); | ||||
| } | } | ||||
| if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) { | if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) { | ||||