Browse Source

update frac_z trans func compute, change the order c and d

tags/v1.2.0-rc1
liubuyu 4 years ago
parent
commit
6f4b1880df
1 changed files with 11 additions and 12 deletions
  1. +11
    -12
      mindspore/ccsrc/common/trans.cc

+ 11
- 12
mindspore/ccsrc/common/trans.cc View File

@@ -1284,15 +1284,15 @@ bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
auto n1n0c0 = n1n0 * c0;
auto wn1n0c0 = w * n1n0c0;
auto hwn1n0c0 = h * wn1n0c0;
auto dhwn1n0c0 = d * hwn1n0c0;
auto c1hwn1n0c0 = c1 * hwn1n0c0;

for (size_t c1_i = 0; c1_i < c1; c1_i++) {
for (size_t d_i = 0; d_i < d; d_i++) {
for (size_t d_i = 0; d_i < d; d_i++) {
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
for (size_t h_i = 0; h_i < h; h_i++) {
for (size_t w_i = 0; w_i < w; w_i++) {
for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) {
for (size_t c0_i = 0; c0_i < c0; c0_i++) {
size_t dst_i = c1_i * dhwn1n0c0 + d_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
auto dst_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
// ncdhw
size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i;
auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n);
@@ -1329,17 +1329,16 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
auto d = args.host_shape[2];
auto h = args.host_shape[3];
auto w = args.host_shape[4];
auto n0 = args.device_shape[1];
auto ni = args.device_shape[2];
auto c0 = args.device_shape[3];
auto c1 = DivCeil(c, kCubeSize);
auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
auto n1n0c0 = n1n0 * c0;
auto wn1n0c0 = w * n1n0c0;
auto hwn1n0c0 = h * wn1n0c0;
auto c1hwn1n0c0 = c1 * hwn1n0c0;
auto hw = h * w;
auto dhw = d * hw;
auto cdhw = c * dhw;
auto nc = ni * n0;
auto ncc0 = nc * c0;
auto wncc0 = w * ncc0;
auto hwncc0 = h * wncc0;
auto dhwncc0 = d * hwncc0;

for (size_t n_i = 0; n_i < n; n_i++) {
size_t n_head = n_i * cdhw;
@@ -1354,7 +1353,7 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
size_t c1_i = c_i / c0;
size_t c0_i = c_i % c0;
size_t nc_i = n_i;
size_t src_i = c1_i * dhwncc0 + d_i * hwncc0 + h_i * wncc0 + w_i * ncc0 + nc_i * c0 + c0_i;
size_t src_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + nc_i * c0 + c0_i;
SetData(size, false, src_i, dst_i, args, result);
}
}


Loading…
Cancel
Save