diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index 254155ab4d..bc3d5096e1 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -200,7 +200,7 @@ size_t CubeSizeByType(const TypeId data_type) { namespace { bool CheckDims(const std::vector &shape) { if (shape.size() != kNchwDims) { - MS_LOG(ERROR) << "Host shape dims shoud be 4"; + MS_LOG(ERROR) << "Host shape dims should be 4"; return false; } return true; @@ -370,7 +370,7 @@ std::vector PaddingShapeTo4dByDefault(const std::vector &shape) std::copy(shape.begin(), shape.end(), shape_4d.begin()); break; default: - MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size(); + MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size(); } return shape_4d; } @@ -545,7 +545,8 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { const std::map format_trans_map{ {kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw}, {kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw}, - {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw}}; + {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw}, + {kOpFormat_FRACTAL_Z_3D, FracZ3DToNcdhw}}; MS_LOG(DEBUG) << "Start trans format."; if (abstract::TypeIdSize(args.src_data_type) < 1) { @@ -1248,5 +1249,119 @@ bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) { } return true; } + +bool NcdhwToFracZ3D(const FormatArgs &args, void *result) { + MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d"; + MS_EXCEPTION_IF_NULL(result); + + if (args.host_shape.size() != 5) { + MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size(); + return false; + } + auto size = abstract::TypeIdSize(args.src_data_type); + if (size < 1) { + MS_LOG(ERROR) << "Illegal dtype."; + return false; + } + auto total_size = abstract::ShapeSize(args.device_shape) * size; + if (total_size != args.device_size) { + MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size; + return false; + } + + auto n = args.host_shape[0]; + auto c = args.host_shape[1]; + auto d = args.host_shape[2]; + auto h = args.host_shape[3]; + auto w = args.host_shape[4]; + + auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize; + auto c0 = CubeSizeByType(args.src_data_type); + auto c1 = DivCeil(c, c0); + auto hw = h * w; + auto dhw = d * hw; + auto cdhw = c * dhw; + auto n1n0c0 = n1n0 * c0; + auto wn1n0c0 = w * n1n0c0; + auto hwn1n0c0 = h * wn1n0c0; + auto dhwn1n0c0 = d * hwn1n0c0; + + for (size_t c1_i = 0; c1_i < c1; c1_i++) { + for (size_t d_i = 0; d_i < d; d_i++) { + for (size_t h_i = 0; h_i < h; h_i++) { + for (size_t w_i = 0; w_i < w; w_i++) { + for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) { + for (size_t c0_i = 0; c0_i < c0; c0_i++) { + size_t dst_i = c1_i * dhwn1n0c0 + d_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i; + // ncdhw + size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i; + auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n); + SetData(size, pad_zero, src_i, dst_i, args, result); + } + } + } + } + } + } + return true; +} + +bool FracZ3DToNcdhw(const FormatArgs &args, void *result) { + MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw"; + MS_EXCEPTION_IF_NULL(result); + + if (args.host_shape.size() != 5) { + MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size(); + return false; + } + auto size = abstract::TypeIdSize(args.src_data_type); + if (size < 1) { + MS_LOG(ERROR) << "Illegal dtype."; + return false; + } + auto total_size = abstract::ShapeSize(args.device_shape) * size; + if (total_size != args.device_size) { + MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size; + return false; + } + auto n = args.host_shape[0]; + auto c = args.host_shape[1]; + auto d = args.host_shape[2]; + auto h = args.host_shape[3]; + auto w = args.host_shape[4]; + auto n0 = args.device_shape[1]; + auto ni = args.device_shape[1]; + auto c0 = args.device_shape[3]; + auto hw = h * w; + auto dhw = d * hw; + auto cdhw = c * dhw; + auto nc = ni * n0; + auto ncc0 = nc * c0; + auto wncc0 = w * ncc0; + auto hwncc0 = h * wncc0; + auto dhwncc0 = d * hwncc0; + + for (size_t n_i = 0; n_i < n; n_i++) { + size_t n_head = n_i * cdhw; + for (size_t c_i = 0; c_i < c; c_i++) { + size_t c_head = n_head + c_i * dhw; + for (size_t d_i = 0; d_i < d; d_i++) { + size_t d_head = c_head + d_i * hw; + for (size_t h_i = 0; h_i < h; h_i++) { + size_t h_head = d_head + h_i * w; + for (size_t w_i = 0; w_i < w; w_i++) { + size_t dst_i = h_head + w_i; + size_t c1_i = c_i / c0; + size_t c0_i = c_i % c0; + size_t nc_i = n_i; + size_t src_i = c1_i * dhwncc0 + d_i * hwncc0 + h_i * wncc0 + w_i * ncc0 + nc_i * c0 + c0_i; + SetData(size, false, src_i, dst_i, args, result); + } + } + } + } + } + return true; +} } // namespace trans } // namespace mindspore diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index 153b773f26..9014f3c051 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -63,6 +63,7 @@ bool NchwTo4D(const FormatArgs &args, void *result); bool NchwToFracZ(const FormatArgs &args, void *result); bool NchwToFracNz(const FormatArgs &args, void *result); bool NchwToNc1hwc0(const FormatArgs &args, void *result); +bool NcdhwToFracZ3D(const FormatArgs &args, void *result); bool NchwToFracZc04(const FormatArgs &args, void *result); bool NchwToNc1hwc04(const FormatArgs &args, void *result); bool NchwToC1hwncoc0(const FormatArgs &args, void *result); @@ -74,6 +75,7 @@ bool FracZToNchw(const FormatArgs &args, void *result); bool FracNzToNchw(const FormatArgs &args, void *result); bool Nc1hwc0ToNchw(const FormatArgs &args, void *result); bool Nc1hwc04ToNchw(const FormatArgs &args, void *result); +bool FracZ3DToNcdhw(const FormatArgs &args, void *result); bool C1hwncoc0ToNchw(const FormatArgs &args, void *result); bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result); using FormatTransfer = std::function; @@ -81,7 +83,7 @@ const std::map kTransFormatMapOfHostToDevice{ {kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz}, {kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0}, {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}, - {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}}; + {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}, {kOpFormat_FRACTAL_Z_3D, NcdhwToFracZ3D}}; } // namespace trans } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc index c6147e28b0..9c669fd510 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -93,9 +93,9 @@ namespace device { namespace ascend { const int FLOAT_LEN = sizeof(float); const int FLOAT16_LEN = 2; // sizeof(float16); -const std::set kOpNeedTransFormat = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_NC1HWC0, - kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, - kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDC1HWC0}; +const std::set kOpNeedTransFormat = { + kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, + kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; void SyncMemory(void *dst, const void *src, uint64_t size, rtMemcpyKind_t kind) { auto ms_context = MsContext::GetInstance(); @@ -575,7 +575,8 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh host_shape.emplace_back(1); } std::vector device_shape; - if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW || format_ == kOpFormat_NDC1HWC0) { + if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW || format_ == kOpFormat_NDC1HWC0 || + format_ == kOpFormat_FRACTAL_Z_3D) { device_shape = trans::TransShapeToDevice(host_shape, format_); } else { host_shape = trans::PaddingShapeTo4d(host_shape);