|
|
|
@@ -200,7 +200,7 @@ size_t CubeSizeByType(const TypeId data_type) { |
|
|
|
namespace { |
|
|
|
bool CheckDims(const std::vector<size_t> &shape) { |
|
|
|
if (shape.size() != kNchwDims) { |
|
|
|
MS_LOG(ERROR) << "Host shape dims shoud be 4"; |
|
|
|
MS_LOG(ERROR) << "Host shape dims should be 4"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
@@ -370,7 +370,7 @@ std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) |
|
|
|
std::copy(shape.begin(), shape.end(), shape_4d.begin()); |
|
|
|
break; |
|
|
|
default: |
|
|
|
MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size(); |
|
|
|
MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size(); |
|
|
|
} |
|
|
|
return shape_4d; |
|
|
|
} |
|
|
|
@@ -545,7 +545,8 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { |
|
|
|
const std::map<std::string, FormatTransfer> format_trans_map{ |
|
|
|
{kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw}, |
|
|
|
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw}, |
|
|
|
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw}}; |
|
|
|
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw}, |
|
|
|
{kOpFormat_FRACTAL_Z_3D, FracZ3DToNcdhw}}; |
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Start trans format."; |
|
|
|
if (abstract::TypeIdSize(args.src_data_type) < 1) { |
|
|
|
@@ -1248,5 +1249,119 @@ bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) { |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool NcdhwToFracZ3D(const FormatArgs &args, void *result) { |
|
|
|
MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d"; |
|
|
|
MS_EXCEPTION_IF_NULL(result); |
|
|
|
|
|
|
|
if (args.host_shape.size() != 5) { |
|
|
|
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto size = abstract::TypeIdSize(args.src_data_type); |
|
|
|
if (size < 1) { |
|
|
|
MS_LOG(ERROR) << "Illegal dtype."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto total_size = abstract::ShapeSize(args.device_shape) * size; |
|
|
|
if (total_size != args.device_size) { |
|
|
|
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
auto n = args.host_shape[0]; |
|
|
|
auto c = args.host_shape[1]; |
|
|
|
auto d = args.host_shape[2]; |
|
|
|
auto h = args.host_shape[3]; |
|
|
|
auto w = args.host_shape[4]; |
|
|
|
|
|
|
|
auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize; |
|
|
|
auto c0 = CubeSizeByType(args.src_data_type); |
|
|
|
auto c1 = DivCeil(c, c0); |
|
|
|
auto hw = h * w; |
|
|
|
auto dhw = d * hw; |
|
|
|
auto cdhw = c * dhw; |
|
|
|
auto n1n0c0 = n1n0 * c0; |
|
|
|
auto wn1n0c0 = w * n1n0c0; |
|
|
|
auto hwn1n0c0 = h * wn1n0c0; |
|
|
|
auto dhwn1n0c0 = d * hwn1n0c0; |
|
|
|
|
|
|
|
for (size_t c1_i = 0; c1_i < c1; c1_i++) { |
|
|
|
for (size_t d_i = 0; d_i < d; d_i++) { |
|
|
|
for (size_t h_i = 0; h_i < h; h_i++) { |
|
|
|
for (size_t w_i = 0; w_i < w; w_i++) { |
|
|
|
for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) { |
|
|
|
for (size_t c0_i = 0; c0_i < c0; c0_i++) { |
|
|
|
size_t dst_i = c1_i * dhwn1n0c0 + d_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i; |
|
|
|
// ncdhw |
|
|
|
size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i; |
|
|
|
auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n); |
|
|
|
SetData(size, pad_zero, src_i, dst_i, args, result); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool FracZ3DToNcdhw(const FormatArgs &args, void *result) { |
|
|
|
MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw"; |
|
|
|
MS_EXCEPTION_IF_NULL(result); |
|
|
|
|
|
|
|
if (args.host_shape.size() != 5) { |
|
|
|
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto size = abstract::TypeIdSize(args.src_data_type); |
|
|
|
if (size < 1) { |
|
|
|
MS_LOG(ERROR) << "Illegal dtype."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto total_size = abstract::ShapeSize(args.device_shape) * size; |
|
|
|
if (total_size != args.device_size) { |
|
|
|
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size; |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto n = args.host_shape[0]; |
|
|
|
auto c = args.host_shape[1]; |
|
|
|
auto d = args.host_shape[2]; |
|
|
|
auto h = args.host_shape[3]; |
|
|
|
auto w = args.host_shape[4]; |
|
|
|
auto n0 = args.device_shape[1]; |
|
|
|
auto ni = args.device_shape[1]; |
|
|
|
auto c0 = args.device_shape[3]; |
|
|
|
auto hw = h * w; |
|
|
|
auto dhw = d * hw; |
|
|
|
auto cdhw = c * dhw; |
|
|
|
auto nc = ni * n0; |
|
|
|
auto ncc0 = nc * c0; |
|
|
|
auto wncc0 = w * ncc0; |
|
|
|
auto hwncc0 = h * wncc0; |
|
|
|
auto dhwncc0 = d * hwncc0; |
|
|
|
|
|
|
|
for (size_t n_i = 0; n_i < n; n_i++) { |
|
|
|
size_t n_head = n_i * cdhw; |
|
|
|
for (size_t c_i = 0; c_i < c; c_i++) { |
|
|
|
size_t c_head = n_head + c_i * dhw; |
|
|
|
for (size_t d_i = 0; d_i < d; d_i++) { |
|
|
|
size_t d_head = c_head + d_i * hw; |
|
|
|
for (size_t h_i = 0; h_i < h; h_i++) { |
|
|
|
size_t h_head = d_head + h_i * w; |
|
|
|
for (size_t w_i = 0; w_i < w; w_i++) { |
|
|
|
size_t dst_i = h_head + w_i; |
|
|
|
size_t c1_i = c_i / c0; |
|
|
|
size_t c0_i = c_i % c0; |
|
|
|
size_t nc_i = n_i; |
|
|
|
size_t src_i = c1_i * dhwncc0 + d_i * hwncc0 + h_i * wncc0 + w_i * ncc0 + nc_i * c0 + c0_i; |
|
|
|
SetData(size, false, src_i, dst_i, args, result); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
} // namespace trans |
|
|
|
} // namespace mindspore |