Browse Source

Transdata

pull/1211/head
zk 4 years ago
parent
commit
004daded46
1 changed files with 5 additions and 5 deletions
  1. +5
    -5
      ge/common/formats/format_transfers/format_transfer_fractal_z.cc

+ 5
- 5
ge/common/formats/format_transfers/format_transfer_fractal_z.cc View File

@@ -85,7 +85,7 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_
}

Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape
, args.groups) {
, int64_t groups) {
auto c0 = GetCubeSizeByDataType(data_type);
if (c0 < 0) {
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID;
@@ -100,7 +100,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data
int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN;
int64_t c1_dim = cin_opt / cube_k;
int64_t g_dim = Ceil(groups, e_mult);
auto n1 = Ceil(n, 16);
auto n1 = Ceil(n, kCubeN);

dst_shape.clear();
dst_shape.push_back(g_dim * c1_dim * h * w);
@@ -142,7 +142,7 @@ Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_t
}

TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape
, groups){
, int64_t groups){
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) {
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID;
}
@@ -257,7 +257,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
return SUCCESS;
}

Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, groups){
Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result,int64_t groups){
int64_t h_dim = args.src_shape[kHwcnH];
int64_t w_dim = args.src_shape[kHwcnW];
int64_t c_dim = args.src_shape[kHwcnC];
@@ -525,7 +525,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r
}

Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type,
Format dst_format, std::vector<int64_t> &dst_shape, groups){
Format dst_format, std::vector<int64_t> &dst_shape,int64_t groups){
if (CheckDataTypeSupport(data_type) != SUCCESS) {
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID;
}


Loading…
Cancel
Save