Browse Source

!1148 add format transfer for fracz_c04 and nc1hwc04

Merge pull request !1148 from liubuyu/master
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
ebdba96596
2 changed files with 70 additions and 2 deletions
  1. +67
    -2
      mindspore/ccsrc/common/trans.cc
  2. +3
    -0
      mindspore/ccsrc/common/trans.h

+ 67
- 2
mindspore/ccsrc/common/trans.cc View File

@@ -378,8 +378,8 @@ std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
}
std::vector<size_t> device_shape;
size_t c0 = 4;
size_t first_dim = DivCeil(c0 * shape[2] * shape[3], kCubeSize);
size_t no = DivCeil(DivCeil(shape[0], kCubeSize) * kCubeSize, kCubeSize);
auto first_dim = DivCeil(c0 * shape.at(2) * shape.at(3), kCubeSize);
auto no = DivCeil(shape.at(0), kCubeSize);
device_shape.push_back(first_dim);
device_shape.push_back(no);
device_shape.push_back(kCubeSize);
@@ -495,6 +495,10 @@ bool TransFormat(const FormatArgs &args, void *result) {
return NchwToNc1hwc0(args, result);
} else if (args.device_format == kOpFormat_C1HWNCoC0) {
return NchwToC1hwncoc0(args, result);
} else if (args.device_format == kOpFormat_FRACTAL_Z_C04) {
return NchwToFracZc04(args, result);
} else if (args.device_format == kOpFormat_NC1HWC0_C04) {
return NchwToNc1hwc04(args, result);
}
return true;
}
@@ -513,6 +517,8 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
return Nc1hwc0ToNchw(args, result);
} else if (args.device_format == kOpFormat_C1HWNCoC0) {
return C1hwncoc0ToNchw(args, result);
} else if (args.device_format == kOpFormat_NC1HWC0_C04) {
return Nc1hwc04ToNchw(args, result);
}
return true;
}
@@ -632,6 +638,65 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
return true;
}

bool NchwToFracZc04(const FormatArgs &args, void *result) {
// trans nchw to FracZc04
MS_LOG(DEBUG) << "Trans format from nchw to FracZc04.";
MS_EXCEPTION_IF_NULL(result);
size_t size = 0;
size_t total_size = 0;
if (!CheckArgs(args, &size, &total_size)) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
size_t cube = kCubeSize;
size_t n = args.host_shape[0];
size_t c = args.host_shape[1];
size_t h = args.host_shape[2];
size_t w = args.host_shape[3];

size_t c0 = 4;
size_t c1 = DivCeil(c, c0);
size_t hwc0 = h * w * c0;
size_t hwc = h * w * c;
size_t nhwc = n * h * w * c;

size_t n_cnt = DivCeil(n, cube);
size_t v_cnt = DivCeil(h * w * c0 * c1, cube);
size_t dst_idx = 0;

for (size_t vi = 0; vi < v_cnt; vi++) {
for (size_t ni = 0; ni < n_cnt; ni++) {
for (size_t col = 0; col < cube; col++) {
for (size_t row = 0; row < cube; row++) {
size_t cur_cube_n = cube * ni + col;
size_t cur_cube_c1hwc0 = cube * vi + row;
auto desc_g = cur_cube_n / n;
auto desc_n = cur_cube_n % n;
auto desc_c1 = cur_cube_c1hwc0 / hwc0;
auto desc_c0 = cur_cube_c1hwc0 % c0;
auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0);
auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0;
auto c_idx = desc_c1 * c0 + desc_c0;
auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w;
auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c;
SetDataBysize(size, pad_zero);
dst_idx++;
}
}
}
}
return true;
}

bool NchwToNc1hwc04(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04.";
return NchwToNc1hwc0(args, result);
}
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw.";
return Nc1hwc0ToNchw(args, result);
}

bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *hw_shape) {
MS_EXCEPTION_IF_NULL(hw_shape);
if (host_shape.empty()) {


+ 3
- 0
mindspore/ccsrc/common/trans.h View File

@@ -64,11 +64,14 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result);
bool NchwToFracZ(const FormatArgs &args, void *result);
bool NchwToFracNz(const FormatArgs &args, void *result);
bool NchwToNc1hwc0(const FormatArgs &args, void *result);
bool NchwToFracZc04(const FormatArgs &args, void *result);
bool NchwToNc1hwc04(const FormatArgs &args, void *result);
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
// device to host
bool FracZToNchw(const FormatArgs &args, void *result);
bool FracNzToNchw(const FormatArgs &args, void *result);
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result);
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
} // namespace trans
} // namespace mindspore


Loading…
Cancel
Save