Browse Source

Transdata

pull/1211/head
zk 4 years ago
parent
commit
dbbca6d8dd
3 changed files with 21 additions and 11 deletions
  1. +18
    -10
      ge/common/formats/format_transfers/format_transfer_fractal_z.cc
  2. +2
    -0
      ge/common/formats/format_transfers/format_transfer_fractal_z.h
  3. +1
    -1
      ge/host_kernels/transdata_kernel.cc

+ 18
- 10
ge/common/formats/format_transfers/format_transfer_fractal_z.cc View File

@@ -143,7 +143,7 @@ Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_t
} }


TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape
, args.groups){
, groups){
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { if (!CheckShapeValid(src_shape, kHwcnDimsNum)) {
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID;
} }
@@ -153,7 +153,7 @@ TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType dat
auto c = src_shape.at(kHwcnC); auto c = src_shape.at(kHwcnC);
auto n = src_shape.at(kHwcnN); auto n = src_shape.at(kHwcnN);


return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, args.groups);
return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, groups);
} }




@@ -258,8 +258,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
return SUCCESS; return SUCCESS;
} }


Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result){
int64_t groups = args.groups;
Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, groups){
int64_t h_dim = args.src_shape[kHwcnH]; int64_t h_dim = args.src_shape[kHwcnH];
int64_t w_dim = args.src_shape[kHwcnW]; int64_t w_dim = args.src_shape[kHwcnW];
int64_t c_dim = args.src_shape[kHwcnC]; int64_t c_dim = args.src_shape[kHwcnC];
@@ -488,14 +487,14 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(),
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str()); TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str());
std::vector<int64_t> expect_shape; std::vector<int64_t> expect_shape;
auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape);
auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape, groups);
if (ret != SUCCESS) { if (ret != SUCCESS) {
return ret; return ret;
} }
if (!IsTransShapeDstCorrect(args, expect_shape)) { if (!IsTransShapeDstCorrect(args, expect_shape)) {
return PARAM_INVALID; return PARAM_INVALID;
} }
return TransFormatHwcnToFzWithGroups(args, result);
return TransFormatHwcnToFzWithGroups(args, result, groups);
} }
Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &result) { Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &result) {
GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s", GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s",
@@ -516,8 +515,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r
} }


if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) { if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) {
return args.groups == 0 ? TransFormatHwcnToFz(args, result) :
TransFormatHwcnToFzWithGroups(args, result);
return TransFormatHwcnToFz(args, result);
} }


if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) {
@@ -527,6 +525,17 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r
return UNSUPPORTED; return UNSUPPORTED;
} }


Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type,
Format dst_format, std::vector<int64_t> &dst_shape, groups){
if (CheckDataTypeSupport(data_type) != SUCCESS) {
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID;
}

if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, groups);
}
}

Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type,
Format dst_format, std::vector<int64_t> &dst_shape) { Format dst_format, std::vector<int64_t> &dst_shape) {
if (CheckDataTypeSupport(data_type) != SUCCESS) { if (CheckDataTypeSupport(data_type) != SUCCESS) {
@@ -537,8 +546,7 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i
return TransShapeNhwcToFz(src_shape, data_type, dst_shape); return TransShapeNhwcToFz(src_shape, data_type, dst_shape);
} }
if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) { if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) {
return args.groups == 0 ? TransShapeHwcnToFz(src_shape, data_type, dst_shape)
: TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, args.groups);
return TransShapeHwcnToFz(src_shape, data_type, dst_shape);
} }
if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) { if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeNchwToFz(src_shape, data_type, dst_shape); return TransShapeNchwToFz(src_shape, data_type, dst_shape);


+ 2
- 0
ge/common/formats/format_transfers/format_transfer_fractal_z.h View File

@@ -29,6 +29,8 @@ class FormatTransferFractalZ : public FormatTransfer {
Status TransFormat(const TransArgs &args, TransResult &result, int64_t groups) override; Status TransFormat(const TransArgs &args, TransResult &result, int64_t groups) override;
Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
std::vector<int64_t> &dst_shape) override; std::vector<int64_t> &dst_shape) override;
Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
std::vector<int64_t> &dst_shape, int64_t groups) override;
}; };
} // namespace formats } // namespace formats
} // namespace ge } // namespace ge


+ 1
- 1
ge/host_kernels/transdata_kernel.cc View File

@@ -114,7 +114,7 @@ Status TransdataKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<C
GELOGI("CheckSize failed, input size is not equal to weight size"); GELOGI("CheckSize failed, input size is not equal to weight size");
return NOT_CHANGED; return NOT_CHANGED;
} }
if((src_format == FOMAT_HWCN) && (data_format == FORMAT_FRACTAL_Z_3D)) {
if((src_format == FOMAT_HWCN) && (data_format == FORMAT_FRACTAL_Z)) {
if (formats::TransFormat(trans_args, trans_result , groups) != SUCCESS) { if (formats::TransFormat(trans_args, trans_result , groups) != SUCCESS) {
GELOGW("Failed to trans formats from %s to %s, shape %s to %s, data type %s", GELOGW("Failed to trans formats from %s to %s, shape %s to %s, data type %s",
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(data_format).c_str(), TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(data_format).c_str(),


Loading…
Cancel
Save