|
|
|
@@ -378,8 +378,8 @@ std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) { |
|
|
|
} |
|
|
|
std::vector<size_t> device_shape; |
|
|
|
size_t c0 = 4; |
|
|
|
size_t first_dim = DivCeil(c0 * shape[2] * shape[3], kCubeSize); |
|
|
|
size_t no = DivCeil(DivCeil(shape[0], kCubeSize) * kCubeSize, kCubeSize); |
|
|
|
auto first_dim = DivCeil(c0 * shape.at(2) * shape.at(3), kCubeSize); |
|
|
|
auto no = DivCeil(shape.at(0), kCubeSize); |
|
|
|
device_shape.push_back(first_dim); |
|
|
|
device_shape.push_back(no); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
@@ -495,6 +495,10 @@ bool TransFormat(const FormatArgs &args, void *result) { |
|
|
|
return NchwToNc1hwc0(args, result); |
|
|
|
} else if (args.device_format == kOpFormat_C1HWNCoC0) { |
|
|
|
return NchwToC1hwncoc0(args, result); |
|
|
|
} else if (args.device_format == kOpFormat_FRACTAL_Z_C04) { |
|
|
|
return NchwToFracZc04(args, result); |
|
|
|
} else if (args.device_format == kOpFormat_NC1HWC0_C04) { |
|
|
|
return NchwToNc1hwc04(args, result); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -513,6 +517,8 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { |
|
|
|
return Nc1hwc0ToNchw(args, result); |
|
|
|
} else if (args.device_format == kOpFormat_C1HWNCoC0) { |
|
|
|
return C1hwncoc0ToNchw(args, result); |
|
|
|
} else if (args.device_format == kOpFormat_NC1HWC0_C04) { |
|
|
|
return Nc1hwc04ToNchw(args, result); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -632,6 +638,65 @@ bool FracZToNchw(const FormatArgs &args, void *result) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool NchwToFracZc04(const FormatArgs &args, void *result) { |
|
|
|
// trans nchw to FracZc04 |
|
|
|
MS_LOG(DEBUG) << "Trans format from nchw to FracZc04."; |
|
|
|
MS_EXCEPTION_IF_NULL(result); |
|
|
|
size_t size = 0; |
|
|
|
size_t total_size = 0; |
|
|
|
if (!CheckArgs(args, &size, &total_size)) { |
|
|
|
MS_LOG(ERROR) << "Check args failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
size_t cube = kCubeSize; |
|
|
|
size_t n = args.host_shape[0]; |
|
|
|
size_t c = args.host_shape[1]; |
|
|
|
size_t h = args.host_shape[2]; |
|
|
|
size_t w = args.host_shape[3]; |
|
|
|
|
|
|
|
size_t c0 = 4; |
|
|
|
size_t c1 = DivCeil(c, c0); |
|
|
|
size_t hwc0 = h * w * c0; |
|
|
|
size_t hwc = h * w * c; |
|
|
|
size_t nhwc = n * h * w * c; |
|
|
|
|
|
|
|
size_t n_cnt = DivCeil(n, cube); |
|
|
|
size_t v_cnt = DivCeil(h * w * c0 * c1, cube); |
|
|
|
size_t dst_idx = 0; |
|
|
|
|
|
|
|
for (size_t vi = 0; vi < v_cnt; vi++) { |
|
|
|
for (size_t ni = 0; ni < n_cnt; ni++) { |
|
|
|
for (size_t col = 0; col < cube; col++) { |
|
|
|
for (size_t row = 0; row < cube; row++) { |
|
|
|
size_t cur_cube_n = cube * ni + col; |
|
|
|
size_t cur_cube_c1hwc0 = cube * vi + row; |
|
|
|
auto desc_g = cur_cube_n / n; |
|
|
|
auto desc_n = cur_cube_n % n; |
|
|
|
auto desc_c1 = cur_cube_c1hwc0 / hwc0; |
|
|
|
auto desc_c0 = cur_cube_c1hwc0 % c0; |
|
|
|
auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0); |
|
|
|
auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0; |
|
|
|
auto c_idx = desc_c1 * c0 + desc_c0; |
|
|
|
auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w; |
|
|
|
auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c; |
|
|
|
SetDataBysize(size, pad_zero); |
|
|
|
dst_idx++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool NchwToNc1hwc04(const FormatArgs &args, void *result) { |
|
|
|
MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04."; |
|
|
|
return NchwToNc1hwc0(args, result); |
|
|
|
} |
|
|
|
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result) { |
|
|
|
MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw."; |
|
|
|
return Nc1hwc0ToNchw(args, result); |
|
|
|
} |
|
|
|
|
|
|
|
bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *hw_shape) { |
|
|
|
MS_EXCEPTION_IF_NULL(hw_shape); |
|
|
|
if (host_shape.empty()) { |
|
|
|
|