From 6f4b1880dfecfe4183ca24b9d5c799e92e6f3ad6 Mon Sep 17 00:00:00 2001 From: liubuyu Date: Sun, 28 Feb 2021 15:12:12 +0800 Subject: [PATCH] update frac_z trans func compute, change the order c and d --- mindspore/ccsrc/common/trans.cc | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index 2853aa9e09..9532b6fb52 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -1284,15 +1284,15 @@ bool NcdhwToFracZ3D(const FormatArgs &args, void *result) { auto n1n0c0 = n1n0 * c0; auto wn1n0c0 = w * n1n0c0; auto hwn1n0c0 = h * wn1n0c0; - auto dhwn1n0c0 = d * hwn1n0c0; + auto c1hwn1n0c0 = c1 * hwn1n0c0; - for (size_t c1_i = 0; c1_i < c1; c1_i++) { - for (size_t d_i = 0; d_i < d; d_i++) { + for (size_t d_i = 0; d_i < d; d_i++) { + for (size_t c1_i = 0; c1_i < c1; c1_i++) { for (size_t h_i = 0; h_i < h; h_i++) { for (size_t w_i = 0; w_i < w; w_i++) { for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) { for (size_t c0_i = 0; c0_i < c0; c0_i++) { - size_t dst_i = c1_i * dhwn1n0c0 + d_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i; + auto dst_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i; // ncdhw size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i; auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n); @@ -1329,17 +1329,16 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) { auto d = args.host_shape[2]; auto h = args.host_shape[3]; auto w = args.host_shape[4]; - auto n0 = args.device_shape[1]; - auto ni = args.device_shape[2]; auto c0 = args.device_shape[3]; + auto c1 = DivCeil(c, kCubeSize); + auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize; + auto n1n0c0 = n1n0 * c0; + auto wn1n0c0 = w * n1n0c0; + auto hwn1n0c0 = h * wn1n0c0; + auto c1hwn1n0c0 = c1 * hwn1n0c0; auto hw = h * w; auto dhw = d * hw; auto cdhw = c * dhw; - auto nc = ni * n0; - auto ncc0 = nc * c0; - auto wncc0 = w * ncc0; - auto hwncc0 = h * wncc0; - auto dhwncc0 = d * hwncc0; for (size_t n_i = 0; n_i < n; n_i++) { size_t n_head = n_i * cdhw; @@ -1354,7 +1353,7 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) { size_t c1_i = c_i / c0; size_t c0_i = c_i % c0; size_t nc_i = n_i; - size_t src_i = c1_i * dhwncc0 + d_i * hwncc0 + h_i * wncc0 + w_i * ncc0 + nc_i * c0 + c0_i; + size_t src_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + nc_i * c0 + c0_i; SetData(size, false, src_i, dst_i, args, result); } }