|
|
|
@@ -85,7 +85,7 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_ |
|
|
|
} |
|
|
|
|
|
|
|
Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape |
|
|
|
, args.groups) { |
|
|
|
, int64_t groups) { |
|
|
|
auto c0 = GetCubeSizeByDataType(data_type); |
|
|
|
if (c0 < 0) { |
|
|
|
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; |
|
|
|
@@ -100,7 +100,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data |
|
|
|
int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN; |
|
|
|
int64_t c1_dim = cin_opt / cube_k; |
|
|
|
int64_t g_dim = Ceil(groups, e_mult); |
|
|
|
auto n1 = Ceil(n, 16); |
|
|
|
auto n1 = Ceil(n, kCubeN); |
|
|
|
|
|
|
|
dst_shape.clear(); |
|
|
|
dst_shape.push_back(g_dim * c1_dim * h * w); |
|
|
|
@@ -142,7 +142,7 @@ Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_t |
|
|
|
} |
|
|
|
|
|
|
|
TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape |
|
|
|
, groups){ |
|
|
|
, int64_t groups){ |
|
|
|
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { |
|
|
|
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; |
|
|
|
} |
|
|
|
@@ -257,7 +257,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, groups){ |
|
|
|
Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result,int64_t groups){ |
|
|
|
int64_t h_dim = args.src_shape[kHwcnH]; |
|
|
|
int64_t w_dim = args.src_shape[kHwcnW]; |
|
|
|
int64_t c_dim = args.src_shape[kHwcnC]; |
|
|
|
@@ -525,7 +525,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r |
|
|
|
} |
|
|
|
|
|
|
|
Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, |
|
|
|
Format dst_format, std::vector<int64_t> &dst_shape, groups){ |
|
|
|
Format dst_format, std::vector<int64_t> &dst_shape,int64_t groups){ |
|
|
|
if (CheckDataTypeSupport(data_type) != SUCCESS) { |
|
|
|
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; |
|
|
|
} |
|
|
|
|