| @@ -85,7 +85,7 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_ | |||||
| } | } | ||||
| Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape | Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape | ||||
| , args.groups) { | |||||
| , int64_t groups) { | |||||
| auto c0 = GetCubeSizeByDataType(data_type); | auto c0 = GetCubeSizeByDataType(data_type); | ||||
| if (c0 < 0) { | if (c0 < 0) { | ||||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | ||||
| @@ -100,7 +100,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data | |||||
| int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN; | int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN; | ||||
| int64_t c1_dim = cin_opt / cube_k; | int64_t c1_dim = cin_opt / cube_k; | ||||
| int64_t g_dim = Ceil(groups, e_mult); | int64_t g_dim = Ceil(groups, e_mult); | ||||
| auto n1 = Ceil(n, 16); | |||||
| auto n1 = Ceil(n, kCubeN); | |||||
| dst_shape.clear(); | dst_shape.clear(); | ||||
| dst_shape.push_back(g_dim * c1_dim * h * w); | dst_shape.push_back(g_dim * c1_dim * h * w); | ||||
| @@ -142,7 +142,7 @@ Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_t | |||||
| } | } | ||||
| TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape | TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape | ||||
| , groups){ | |||||
| , int64_t groups){ | |||||
| if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { | if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { | ||||
| return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | ||||
| } | } | ||||
| @@ -257,7 +257,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, groups){ | |||||
| Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result,int64_t groups){ | |||||
| int64_t h_dim = args.src_shape[kHwcnH]; | int64_t h_dim = args.src_shape[kHwcnH]; | ||||
| int64_t w_dim = args.src_shape[kHwcnW]; | int64_t w_dim = args.src_shape[kHwcnW]; | ||||
| int64_t c_dim = args.src_shape[kHwcnC]; | int64_t c_dim = args.src_shape[kHwcnC]; | ||||
| @@ -525,7 +525,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r | |||||
| } | } | ||||
| Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | ||||
| Format dst_format, std::vector<int64_t> &dst_shape, groups){ | |||||
| Format dst_format, std::vector<int64_t> &dst_shape,int64_t groups){ | |||||
| if (CheckDataTypeSupport(data_type) != SUCCESS) { | if (CheckDataTypeSupport(data_type) != SUCCESS) { | ||||
| return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | ||||
| } | } | ||||