Browse Source

Transdata

pull/1211/head
zk 4 years ago
parent
commit
dbbca6d8dd
3 changed files with 21 additions and 11 deletions
  1. +18
    -10
      ge/common/formats/format_transfers/format_transfer_fractal_z.cc
  2. +2
    -0
      ge/common/formats/format_transfers/format_transfer_fractal_z.h
  3. +1
    -1
      ge/host_kernels/transdata_kernel.cc

+ 18
- 10
ge/common/formats/format_transfers/format_transfer_fractal_z.cc View File

@@ -143,7 +143,7 @@ Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_t
}

TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape
, args.groups){
, groups){
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) {
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID;
}
@@ -153,7 +153,7 @@ TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType dat
auto c = src_shape.at(kHwcnC);
auto n = src_shape.at(kHwcnN);

return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, args.groups);
return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, groups);
}


@@ -258,8 +258,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
return SUCCESS;
}

Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result){
int64_t groups = args.groups;
Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, groups){
int64_t h_dim = args.src_shape[kHwcnH];
int64_t w_dim = args.src_shape[kHwcnW];
int64_t c_dim = args.src_shape[kHwcnC];
@@ -488,14 +487,14 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(),
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str());
std::vector<int64_t> expect_shape;
auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape);
auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape, groups);
if (ret != SUCCESS) {
return ret;
}
if (!IsTransShapeDstCorrect(args, expect_shape)) {
return PARAM_INVALID;
}
return TransFormatHwcnToFzWithGroups(args, result);
return TransFormatHwcnToFzWithGroups(args, result, groups);
}
Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &result) {
GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s",
@@ -516,8 +515,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r
}

if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) {
return args.groups == 0 ? TransFormatHwcnToFz(args, result) :
TransFormatHwcnToFzWithGroups(args, result);
return TransFormatHwcnToFz(args, result);
}

if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) {
@@ -527,6 +525,17 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r
return UNSUPPORTED;
}

Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type,
Format dst_format, std::vector<int64_t> &dst_shape, groups){
if (CheckDataTypeSupport(data_type) != SUCCESS) {
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID;
}

if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, groups);
}
}

Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type,
Format dst_format, std::vector<int64_t> &dst_shape) {
if (CheckDataTypeSupport(data_type) != SUCCESS) {
@@ -537,8 +546,7 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i
return TransShapeNhwcToFz(src_shape, data_type, dst_shape);
}
if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) {
return args.groups == 0 ? TransShapeHwcnToFz(src_shape, data_type, dst_shape)
: TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, args.groups);
return TransShapeHwcnToFz(src_shape, data_type, dst_shape);
}
if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeNchwToFz(src_shape, data_type, dst_shape);


+ 2
- 0
ge/common/formats/format_transfers/format_transfer_fractal_z.h View File

@@ -29,6 +29,8 @@ class FormatTransferFractalZ : public FormatTransfer {
Status TransFormat(const TransArgs &args, TransResult &result, int64_t groups) override;
Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
std::vector<int64_t> &dst_shape) override;
Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
std::vector<int64_t> &dst_shape, int64_t groups) override;
};
} // namespace formats
} // namespace ge


+ 1
- 1
ge/host_kernels/transdata_kernel.cc View File

@@ -114,7 +114,7 @@ Status TransdataKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<C
GELOGI("CheckSize failed, input size is not equal to weight size");
return NOT_CHANGED;
}
if((src_format == FOMAT_HWCN) && (data_format == FORMAT_FRACTAL_Z_3D)) {
if((src_format == FOMAT_HWCN) && (data_format == FORMAT_FRACTAL_Z)) {
if (formats::TransFormat(trans_args, trans_result , groups) != SUCCESS) {
GELOGW("Failed to trans formats from %s to %s, shape %s to %s, data type %s",
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(data_format).c_str(),


Loading…
Cancel
Save