|
|
|
@@ -63,26 +63,24 @@ const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberType |
|
|
|
{kNumberTypeUInt32, 4}, {kNumberTypeUInt64, 8}, {kNumberTypeFloat, 4}, |
|
|
|
{kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}}; |
|
|
|
|
|
|
|
// Copies one element of `size` bytes (1, 2, 4 or 8) from args.data[src_idx] to result[dst_idx], or stores zero
// when `pad_zero` evaluates to true.
//
// This macro is not hygienic: the expansion reads `args`, `result`, `src_idx` and `dst_idx` from the enclosing
// scope, and for an unsupported `size` it logs an error and executes `return false;` in the enclosing function.
// Both macro arguments are parenthesized where operator precedence matters, so an expression such as
// `cond ? true : false` can be passed as `pad_zero` without being re-associated with the internal `?:`.
#define SetDataBysize(size, pad_zero) \
  do { \
    switch (size) { \
      case 1: \
        static_cast<uint8_t *>(result)[dst_idx] = \
          (pad_zero) ? 0 : static_cast<const uint8_t *>(args.data)[src_idx]; \
        break; \
      case 2: \
        static_cast<uint16_t *>(result)[dst_idx] = \
          (pad_zero) ? 0 : static_cast<const uint16_t *>(args.data)[src_idx]; \
        break; \
      case 4: \
        static_cast<uint32_t *>(result)[dst_idx] = \
          (pad_zero) ? 0 : static_cast<const uint32_t *>(args.data)[src_idx]; \
        break; \
      case 8: \
        static_cast<uint64_t *>(result)[dst_idx] = \
          (pad_zero) ? 0 : static_cast<const uint64_t *>(args.data)[src_idx]; \
        break; \
      default: \
        MS_LOG(ERROR) << "Trans data not support size " << (size); \
        return false; \
    } \
  } while (0)
|
|
|
inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) { |
|
|
|
switch (size) { |
|
|
|
case 1: |
|
|
|
static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx]; |
|
|
|
break; |
|
|
|
case 2: |
|
|
|
static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx]; |
|
|
|
break; |
|
|
|
case 4: |
|
|
|
static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx]; |
|
|
|
break; |
|
|
|
case 8: |
|
|
|
static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx]; |
|
|
|
break; |
|
|
|
default: |
|
|
|
MS_LOG(EXCEPTION) << "Trans data not support size " << size; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
T DivCeil(T n1, T n2) { |
|
|
|
@@ -401,6 +399,13 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) { |
|
|
|
device_shape.push_back(C0); |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) { |
|
|
|
if (shape.size() < 5) { |
|
|
|
MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc."; |
|
|
|
} |
|
|
|
return shape; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) { |
|
|
|
@@ -412,7 +417,8 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s |
|
|
|
{kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, |
|
|
|
{kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, |
|
|
|
{kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape}, |
|
|
|
{kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}}; |
|
|
|
{kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}, |
|
|
|
{kOpFormat_NDHWC, NdhwcDeviceShape}}; |
|
|
|
|
|
|
|
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { |
|
|
|
return shape; |
|
|
|
@@ -482,43 +488,109 @@ bool TransDataType(const TypeIdArgs &args, void *result) { |
|
|
|
} |
|
|
|
|
|
|
|
bool TransFormat(const FormatArgs &args, void *result) { |
|
|
|
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>; |
|
|
|
const std::map<std::string, FormatTransfer> format_trans_map{ |
|
|
|
{kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz}, |
|
|
|
{kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0}, |
|
|
|
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}}; |
|
|
|
MS_LOG(DEBUG) << "Start trans format."; |
|
|
|
if (TypeIdSize(args.src_data_type) < 1) { |
|
|
|
MS_LOG(ERROR) << "Invalid datatype.."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (args.device_format == kOpFormat_FRAC_Z) { |
|
|
|
return NchwToFracZ(args, result); |
|
|
|
} else if (args.device_format == kOpFormat_FRAC_NZ) { |
|
|
|
return NchwToFracNz(args, result); |
|
|
|
} else if (args.device_format == kOpFormat_NC1HWC0) { |
|
|
|
return NchwToNc1hwc0(args, result); |
|
|
|
} else if (args.device_format == kOpFormat_C1HWNCoC0) { |
|
|
|
return NchwToC1hwncoc0(args, result); |
|
|
|
} else if (args.device_format == kOpFormat_FRACTAL_Z_C04) { |
|
|
|
return NchwToFracZc04(args, result); |
|
|
|
} else if (args.device_format == kOpFormat_NC1HWC0_C04) { |
|
|
|
return NchwToNc1hwc04(args, result); |
|
|
|
if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) { |
|
|
|
return NchwTo4D(args, result); |
|
|
|
} |
|
|
|
return true; |
|
|
|
auto iter = format_trans_map.find(args.device_format); |
|
|
|
if (iter == format_trans_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]"; |
|
|
|
} |
|
|
|
return iter->second(args, result); |
|
|
|
} |
|
|
|
|
|
|
|
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { |
|
|
|
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>; |
|
|
|
const std::map<std::string, FormatTransfer> format_trans_map{{kOpFormat_FRAC_Z, FracZToNchw}, |
|
|
|
{kOpFormat_FRAC_NZ, FracNzToNchw}, |
|
|
|
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, |
|
|
|
{kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw}, |
|
|
|
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}}; |
|
|
|
MS_LOG(DEBUG) << "Start trans format."; |
|
|
|
if (TypeIdSize(args.src_data_type) < 1) { |
|
|
|
MS_LOG(ERROR) << "Invalid datatype.."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (args.device_format == kOpFormat_FRAC_Z) { |
|
|
|
return FracZToNchw(args, result); |
|
|
|
} else if (args.device_format == kOpFormat_FRAC_NZ) { |
|
|
|
return FracNzToNchw(args, result); |
|
|
|
} else if (args.device_format == kOpFormat_NC1HWC0) { |
|
|
|
return Nc1hwc0ToNchw(args, result); |
|
|
|
} else if (args.device_format == kOpFormat_C1HWNCoC0) { |
|
|
|
return C1hwncoc0ToNchw(args, result); |
|
|
|
} else if (args.device_format == kOpFormat_NC1HWC0_C04) { |
|
|
|
return Nc1hwc04ToNchw(args, result); |
|
|
|
if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) { |
|
|
|
return ToNchw(args, result); |
|
|
|
} |
|
|
|
auto iter = format_trans_map.find(args.device_format); |
|
|
|
if (iter == format_trans_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]"; |
|
|
|
} |
|
|
|
return iter->second(args, result); |
|
|
|
} |
|
|
|
|
|
|
|
bool NchwTo4D(const FormatArgs &args, void *result) { |
|
|
|
// trans nchw to 4d |
|
|
|
MS_LOG(DEBUG) << "Trans format from nchw to 4d."; |
|
|
|
MS_EXCEPTION_IF_NULL(result); |
|
|
|
size_t size = 0; |
|
|
|
size_t total_size = 0; |
|
|
|
if (!CheckArgs(args, &size, &total_size)) { |
|
|
|
MS_LOG(ERROR) << "Check args failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
size_t n = args.host_shape[0]; |
|
|
|
size_t c = args.host_shape[1]; |
|
|
|
size_t h = args.host_shape[2]; |
|
|
|
size_t w = args.host_shape[3]; |
|
|
|
for (size_t ni = 0; ni < n; ni++) { |
|
|
|
for (size_t ci = 0; ci < c; ci++) { |
|
|
|
for (size_t hi = 0; hi < h; hi++) { |
|
|
|
for (size_t wi = 0; wi < w; wi++) { |
|
|
|
auto src_idx = ni * c * h * w + ci * h * w + hi * w + wi; |
|
|
|
auto dst_idx = 0; |
|
|
|
if (args.device_format == kOpFormat_NHWC) { |
|
|
|
dst_idx = ni * h * w * c + hi * w * c + wi * c + ci; |
|
|
|
} else if (args.device_format == kOpFormat_HWCN) { |
|
|
|
dst_idx = hi * w * c * n + wi * c * n + ci * n + ni; |
|
|
|
} |
|
|
|
SetData(size, false, src_idx, dst_idx, args, result); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool ToNchw(const FormatArgs &args, void *result) { |
|
|
|
MS_LOG(DEBUG) << "Trans format to nchw from 4d."; |
|
|
|
MS_EXCEPTION_IF_NULL(result); |
|
|
|
size_t size = 0; |
|
|
|
size_t total_size = 0; |
|
|
|
if (!CheckArgs(args, &size, &total_size)) { |
|
|
|
MS_LOG(ERROR) << "Check args failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
size_t n = args.host_shape[0]; |
|
|
|
size_t c = args.host_shape[1]; |
|
|
|
size_t h = args.host_shape[2]; |
|
|
|
size_t w = args.host_shape[3]; |
|
|
|
for (size_t ni = 0; ni < n; ni++) { |
|
|
|
for (size_t ci = 0; ci < c; ci++) { |
|
|
|
for (size_t hi = 0; hi < h; hi++) { |
|
|
|
for (size_t wi = 0; wi < w; wi++) { |
|
|
|
auto dst_idx = ni * c * h * w + ci * h * w + hi * w + wi; |
|
|
|
auto src_idx = 0; |
|
|
|
if (args.device_format == kOpFormat_NHWC) { |
|
|
|
src_idx = ni * h * w * c + hi * w * c + wi * c + ci; |
|
|
|
} else if (args.device_format == kOpFormat_HWCN) { |
|
|
|
src_idx = hi * w * c * n + wi * c * n + ci * n + ni; |
|
|
|
} |
|
|
|
SetData(size, false, src_idx, dst_idx, args, result); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -575,8 +647,8 @@ bool NchwToFracZ(const FormatArgs &args, void *result) { |
|
|
|
auto src_ni = hfi * kCubeSize + col; |
|
|
|
auto src_idx = src_row_offset + chw * col; |
|
|
|
auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row; |
|
|
|
auto pad_zero = (src_ni >= n || src_idx >= nchw || src_ci >= c) ? 1 : 0; |
|
|
|
SetDataBysize(size, pad_zero); |
|
|
|
auto pad_zero = (src_ni >= n || src_idx >= nchw || src_ci >= c) ? true : false; |
|
|
|
SetData(size, pad_zero, src_idx, dst_idx, args, result); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -630,7 +702,7 @@ bool FracZToNchw(const FormatArgs &args, void *result) { |
|
|
|
size_t c0_idx = c_idx % c0; |
|
|
|
size_t nc_idx = n_idx; |
|
|
|
size_t src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx; |
|
|
|
SetDataBysize(size, 0); |
|
|
|
SetData(size, false, src_idx, dst_idx, args, result); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -679,7 +751,7 @@ bool NchwToFracZc04(const FormatArgs &args, void *result) { |
|
|
|
auto c_idx = desc_c1 * c0 + desc_c0; |
|
|
|
auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w; |
|
|
|
auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c; |
|
|
|
SetDataBysize(size, pad_zero); |
|
|
|
SetData(size, pad_zero, src_idx, dst_idx, args, result); |
|
|
|
dst_idx++; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -773,7 +845,7 @@ bool NchwToFracNz(const FormatArgs &args, void *result) { |
|
|
|
for (size_t i = 0; i < w0; ++i) { |
|
|
|
size_t src_idx = src_h_head + w1_idx * w0 + i; |
|
|
|
size_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i; |
|
|
|
SetDataBysize(size, 0); |
|
|
|
SetData(size, false, src_idx, dst_idx, args, result); |
|
|
|
} |
|
|
|
} |
|
|
|
auto w1_head = num_w1 * w0; |
|
|
|
@@ -781,7 +853,7 @@ bool NchwToFracNz(const FormatArgs &args, void *result) { |
|
|
|
auto src_w_idx = w1_head + w0_idx; |
|
|
|
size_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx; |
|
|
|
size_t src_idx = src_h_head + src_w_idx; |
|
|
|
SetDataBysize(size, 0); |
|
|
|
SetData(size, false, src_idx, dst_idx, args, result); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -835,7 +907,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result) { |
|
|
|
for (size_t i = 0; i < w0; ++i) { |
|
|
|
size_t src_idx = h1h0_head + w1_idx * h1h0w0 + i; |
|
|
|
size_t dst_idx = src_h_head + w1_idx * w0 + i; |
|
|
|
SetDataBysize(size, 0); |
|
|
|
SetData(size, false, src_idx, dst_idx, args, result); |
|
|
|
} |
|
|
|
} |
|
|
|
auto w1_head = num_w1 * w0; |
|
|
|
@@ -843,7 +915,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result) { |
|
|
|
auto src_w_idx = w1_head + w0_idx; |
|
|
|
size_t src_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx; |
|
|
|
size_t dst_idx = src_h_head + src_w_idx; |
|
|
|
SetDataBysize(size, 0); |
|
|
|
SetData(size, false, src_idx, dst_idx, args, result); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -895,8 +967,8 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) { |
|
|
|
size_t dst_idx = c0_idx + w_head_addr; |
|
|
|
size_t c_idx = c0_idx + c1_idx * c0; |
|
|
|
size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx; |
|
|
|
auto pad_zero = (c_idx < c) ? 0 : 1; |
|
|
|
SetDataBysize(size, pad_zero); |
|
|
|
auto pad_zero = (c_idx < c) ? false : true; |
|
|
|
SetData(size, pad_zero, src_idx, dst_idx, args, result); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -947,7 +1019,7 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) { |
|
|
|
size_t c1_idx = c_idx / c0; |
|
|
|
size_t c0_idx = c_idx % c0; |
|
|
|
size_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx; |
|
|
|
SetDataBysize(size, 0); |
|
|
|
SetData(size, false, src_idx, dst_idx, args, result); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -983,8 +1055,8 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) { |
|
|
|
co_i * c0 + c0_i; |
|
|
|
size_t c_i = c0_i + c1_i * c0; |
|
|
|
size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i; |
|
|
|
auto pad_zero = (c_i < c && c0_i == co_i) ? 0 : 1; |
|
|
|
SetDataBysize(size, pad_zero); |
|
|
|
auto pad_zero = (c_i < c && c0_i == co_i) ? false : true; |
|
|
|
SetData(size, pad_zero, src_idx, dst_idx, args, result); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -1020,7 +1092,7 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) { |
|
|
|
size_t co_i = c0_i; |
|
|
|
size_t src_idx = |
|
|
|
c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + co_i * c0 + c0_i; |
|
|
|
SetDataBysize(size, 0); |
|
|
|
SetData(size, false, src_idx, dst_idx, args, result); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|