| @@ -1284,15 +1284,15 @@ bool NcdhwToFracZ3D(const FormatArgs &args, void *result) { | |||||
| auto n1n0c0 = n1n0 * c0; | auto n1n0c0 = n1n0 * c0; | ||||
| auto wn1n0c0 = w * n1n0c0; | auto wn1n0c0 = w * n1n0c0; | ||||
| auto hwn1n0c0 = h * wn1n0c0; | auto hwn1n0c0 = h * wn1n0c0; | ||||
| auto dhwn1n0c0 = d * hwn1n0c0; | |||||
| auto c1hwn1n0c0 = c1 * hwn1n0c0; | |||||
| for (size_t c1_i = 0; c1_i < c1; c1_i++) { | |||||
| for (size_t d_i = 0; d_i < d; d_i++) { | |||||
| for (size_t d_i = 0; d_i < d; d_i++) { | |||||
| for (size_t c1_i = 0; c1_i < c1; c1_i++) { | |||||
| for (size_t h_i = 0; h_i < h; h_i++) { | for (size_t h_i = 0; h_i < h; h_i++) { | ||||
| for (size_t w_i = 0; w_i < w; w_i++) { | for (size_t w_i = 0; w_i < w; w_i++) { | ||||
| for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) { | for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) { | ||||
| for (size_t c0_i = 0; c0_i < c0; c0_i++) { | for (size_t c0_i = 0; c0_i < c0; c0_i++) { | ||||
| size_t dst_i = c1_i * dhwn1n0c0 + d_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i; | |||||
| auto dst_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i; | |||||
| // ncdhw | // ncdhw | ||||
| size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i; | size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i; | ||||
| auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n); | auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n); | ||||
| @@ -1329,17 +1329,16 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) { | |||||
| auto d = args.host_shape[2]; | auto d = args.host_shape[2]; | ||||
| auto h = args.host_shape[3]; | auto h = args.host_shape[3]; | ||||
| auto w = args.host_shape[4]; | auto w = args.host_shape[4]; | ||||
| auto n0 = args.device_shape[1]; | |||||
| auto ni = args.device_shape[2]; | |||||
| auto c0 = args.device_shape[3]; | auto c0 = args.device_shape[3]; | ||||
| auto c1 = DivCeil(c, kCubeSize); | |||||
| auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize; | |||||
| auto n1n0c0 = n1n0 * c0; | |||||
| auto wn1n0c0 = w * n1n0c0; | |||||
| auto hwn1n0c0 = h * wn1n0c0; | |||||
| auto c1hwn1n0c0 = c1 * hwn1n0c0; | |||||
| auto hw = h * w; | auto hw = h * w; | ||||
| auto dhw = d * hw; | auto dhw = d * hw; | ||||
| auto cdhw = c * dhw; | auto cdhw = c * dhw; | ||||
| auto nc = ni * n0; | |||||
| auto ncc0 = nc * c0; | |||||
| auto wncc0 = w * ncc0; | |||||
| auto hwncc0 = h * wncc0; | |||||
| auto dhwncc0 = d * hwncc0; | |||||
| for (size_t n_i = 0; n_i < n; n_i++) { | for (size_t n_i = 0; n_i < n; n_i++) { | ||||
| size_t n_head = n_i * cdhw; | size_t n_head = n_i * cdhw; | ||||
| @@ -1354,7 +1353,7 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) { | |||||
| size_t c1_i = c_i / c0; | size_t c1_i = c_i / c0; | ||||
| size_t c0_i = c_i % c0; | size_t c0_i = c_i % c0; | ||||
| size_t nc_i = n_i; | size_t nc_i = n_i; | ||||
| size_t src_i = c1_i * dhwncc0 + d_i * hwncc0 + h_i * wncc0 + w_i * ncc0 + nc_i * c0 + c0_i; | |||||
| size_t src_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + nc_i * c0 + c0_i; | |||||
| SetData(size, false, src_i, dst_i, args, result); | SetData(size, false, src_i, dst_i, args, result); | ||||
| } | } | ||||
| } | } | ||||