|
|
@@ -231,7 +231,98 @@ std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std |
|
|
return shape_4d; |
|
|
return shape_4d; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
namespace { |
|
|
|
|
|
bool CheckDims(const std::vector<size_t> &shape) { |
|
|
|
|
|
if (shape.size() != 4) { |
|
|
|
|
|
MS_LOG(ERROR) << "Host shape dims shoud be 4"; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) { |
|
|
|
|
|
if (!CheckDims(shape)) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed."; |
|
|
|
|
|
} |
|
|
|
|
|
return shape; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) { |
|
|
|
|
|
if (!CheckDims(shape)) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Ccheck dims failed."; |
|
|
|
|
|
} |
|
|
|
|
|
std::vector<size_t> device_shape; |
|
|
|
|
|
device_shape.push_back(shape[0]); |
|
|
|
|
|
device_shape.push_back(shape[2]); |
|
|
|
|
|
device_shape.push_back(shape[3]); |
|
|
|
|
|
device_shape.push_back(shape[1]); |
|
|
|
|
|
return device_shape; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) { |
|
|
|
|
|
if (!CheckDims(shape)) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed."; |
|
|
|
|
|
} |
|
|
|
|
|
std::vector<size_t> device_shape; |
|
|
|
|
|
device_shape.push_back(shape[2]); |
|
|
|
|
|
device_shape.push_back(shape[3]); |
|
|
|
|
|
device_shape.push_back(shape[1]); |
|
|
|
|
|
device_shape.push_back(shape[0]); |
|
|
|
|
|
return device_shape; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) { |
|
|
|
|
|
if (!CheckDims(shape)) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed."; |
|
|
|
|
|
} |
|
|
|
|
|
std::vector<size_t> device_shape; |
|
|
|
|
|
size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize; |
|
|
|
|
|
size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize; |
|
|
|
|
|
device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize); |
|
|
|
|
|
device_shape.push_back(cout16 / kCubeSize); |
|
|
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
|
|
return device_shape; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) { |
|
|
|
|
|
if (!CheckDims(shape)) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed."; |
|
|
|
|
|
} |
|
|
|
|
|
std::vector<size_t> device_shape; |
|
|
|
|
|
size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize; |
|
|
|
|
|
size_t C0 = kCubeSize; |
|
|
|
|
|
device_shape.push_back(shape[0]); |
|
|
|
|
|
device_shape.push_back(C1); |
|
|
|
|
|
device_shape.push_back(shape[2]); |
|
|
|
|
|
device_shape.push_back(shape[3]); |
|
|
|
|
|
device_shape.push_back(C0); |
|
|
|
|
|
return device_shape; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) { |
|
|
|
|
|
if (!CheckDims(shape)) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed."; |
|
|
|
|
|
} |
|
|
|
|
|
std::vector<size_t> device_shape; |
|
|
|
|
|
device_shape.push_back((shape[1] - 1) / kCubeSize + 1); |
|
|
|
|
|
device_shape.push_back(shape[2]); |
|
|
|
|
|
device_shape.push_back(shape[3]); |
|
|
|
|
|
device_shape.push_back(shape[0]); |
|
|
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
|
|
return device_shape; |
|
|
|
|
|
} |
|
|
|
|
|
} // namespace |
|
|
|
|
|
|
|
|
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) { |
|
|
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) { |
|
|
|
|
|
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>; |
|
|
|
|
|
const std::map<std::string, DeviceShapeTransfer> device_shape_map{ |
|
|
|
|
|
{kOpFormat_NCHW, NchwDeviceShape}, {kOpFormat_NHWC, NhwcDeviceShape}, |
|
|
|
|
|
{kOpFormat_HWCN, HwchDeviceShape}, {kOpFormat_FRAC_Z, FracZDeviceShape}, |
|
|
|
|
|
{kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { |
|
|
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { |
|
|
return shape; |
|
|
return shape; |
|
|
} |
|
|
} |
|
|
@@ -255,37 +346,31 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s |
|
|
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; |
|
|
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; |
|
|
temp_shape = PaddingShapeTo4dByDefault(shape); |
|
|
temp_shape = PaddingShapeTo4dByDefault(shape); |
|
|
} |
|
|
} |
|
|
if (format == kOpFormat_NC1HWC0) { |
|
|
|
|
|
size_t C1 = (temp_shape[1] + kCubeSize - 1) / kCubeSize; |
|
|
|
|
|
size_t C0 = kCubeSize; |
|
|
|
|
|
device_shape.push_back(temp_shape[0]); |
|
|
|
|
|
device_shape.push_back(C1); |
|
|
|
|
|
device_shape.push_back(temp_shape[2]); |
|
|
|
|
|
device_shape.push_back(temp_shape[3]); |
|
|
|
|
|
device_shape.push_back(C0); |
|
|
|
|
|
return device_shape; |
|
|
|
|
|
} else if (format == kOpFormat_FRAC_Z) { |
|
|
|
|
|
size_t cout16 = ((temp_shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize; |
|
|
|
|
|
size_t cin16 = ((temp_shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize; |
|
|
|
|
|
device_shape.push_back(temp_shape[2] * temp_shape[3] * cin16 / kCubeSize); |
|
|
|
|
|
device_shape.push_back(cout16 / kCubeSize); |
|
|
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
|
|
return device_shape; |
|
|
|
|
|
} else if (format == kOpFormat_NHWC) { |
|
|
|
|
|
device_shape.push_back(temp_shape[0]); |
|
|
|
|
|
device_shape.push_back(temp_shape[2]); |
|
|
|
|
|
device_shape.push_back(temp_shape[3]); |
|
|
|
|
|
device_shape.push_back(temp_shape[1]); |
|
|
|
|
|
return device_shape; |
|
|
|
|
|
} else if (format == kOpFormat_HWCN) { |
|
|
|
|
|
return {temp_shape[2], temp_shape[3], temp_shape[1], temp_shape[0]}; |
|
|
|
|
|
} else if (format == kOpFormat_NCHW) { |
|
|
|
|
|
return temp_shape; |
|
|
|
|
|
|
|
|
auto iter = device_shape_map.find(format); |
|
|
|
|
|
if (iter != device_shape_map.end()) { |
|
|
|
|
|
return iter->second(temp_shape); |
|
|
} |
|
|
} |
|
|
MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]"; |
|
|
MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]"; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) { |
|
|
|
|
|
if (args.host_shape.size() != kNchwDims) { |
|
|
|
|
|
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
*size = TypeIdSize(args.src_data_type); |
|
|
|
|
|
if (*size < 1) { |
|
|
|
|
|
MS_LOG(ERROR) << "Illegal dtype."; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
*total_size = ShapeSize(args.device_shape) * (*size); |
|
|
|
|
|
if (*total_size != args.device_size) { |
|
|
|
|
|
MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
bool TransDataType(const TypeIdArgs &args, void *result) { |
|
|
bool TransDataType(const TypeIdArgs &args, void *result) { |
|
|
MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to " |
|
|
MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to " |
|
|
<< TypeIdLabel(args.device_data_type); |
|
|
<< TypeIdLabel(args.device_data_type); |
|
|
@@ -320,13 +405,14 @@ bool TransFormat(const FormatArgs &args, void *result) { |
|
|
MS_LOG(ERROR) << "Invalid datatype.."; |
|
|
MS_LOG(ERROR) << "Invalid datatype.."; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
if ((args.host_format == kOpFormat_NCHW || args.host_format == kOpFormat_ND) && |
|
|
|
|
|
args.device_format == kOpFormat_FRAC_Z) { |
|
|
|
|
|
|
|
|
if (args.device_format == kOpFormat_FRAC_Z) { |
|
|
return NchwToFracZ(args, result); |
|
|
return NchwToFracZ(args, result); |
|
|
} else if (args.device_format == kOpFormat_FRAC_NZ) { |
|
|
} else if (args.device_format == kOpFormat_FRAC_NZ) { |
|
|
return NchwToFracNz(args, result); |
|
|
return NchwToFracNz(args, result); |
|
|
} else if (args.device_format == kOpFormat_NC1HWC0) { |
|
|
} else if (args.device_format == kOpFormat_NC1HWC0) { |
|
|
return NchwToNc1hwc0(args, result); |
|
|
return NchwToNc1hwc0(args, result); |
|
|
|
|
|
} else if (args.device_format == kOpFormat_C1HWNCoC0) { |
|
|
|
|
|
return NchwToC1hwncoc0(args, result); |
|
|
} |
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
@@ -337,13 +423,14 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { |
|
|
MS_LOG(ERROR) << "Invalid datatype.."; |
|
|
MS_LOG(ERROR) << "Invalid datatype.."; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
if ((args.host_format == kOpFormat_NCHW || args.host_format == kOpFormat_ND) && |
|
|
|
|
|
args.device_format == kOpFormat_FRAC_Z) { |
|
|
|
|
|
|
|
|
if (args.device_format == kOpFormat_FRAC_Z) { |
|
|
return FracZToNchw(args, result); |
|
|
return FracZToNchw(args, result); |
|
|
} else if (args.device_format == kOpFormat_FRAC_NZ) { |
|
|
} else if (args.device_format == kOpFormat_FRAC_NZ) { |
|
|
return FracNzToNchw(args, result); |
|
|
return FracNzToNchw(args, result); |
|
|
} else if (args.device_format == kOpFormat_NC1HWC0) { |
|
|
} else if (args.device_format == kOpFormat_NC1HWC0) { |
|
|
return Nc1hwc0ToNchw(args, result); |
|
|
return Nc1hwc0ToNchw(args, result); |
|
|
|
|
|
} else if (args.device_format == kOpFormat_C1HWNCoC0) { |
|
|
|
|
|
return C1hwncoc0ToNchw(args, result); |
|
|
} |
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
@@ -801,5 +888,99 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) { |
|
|
} |
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
bool NchwToC1hwncoc0(const FormatArgs &args, void *result) { |
|
|
|
|
|
// trans nchw to c1hwncoc0 |
|
|
|
|
|
MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0."; |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(result); |
|
|
|
|
|
size_t size = 0; |
|
|
|
|
|
size_t total_size = 0; |
|
|
|
|
|
if (!CheckArgs(args, &size, &total_size)) { |
|
|
|
|
|
MS_LOG(ERROR) << "Check args failed."; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
auto n = args.host_shape[0]; |
|
|
|
|
|
auto c = args.host_shape[1]; |
|
|
|
|
|
auto h = args.host_shape[2]; |
|
|
|
|
|
auto w = args.host_shape[3]; |
|
|
|
|
|
auto c1 = args.device_shape[0]; |
|
|
|
|
|
auto co = args.device_shape[4]; |
|
|
|
|
|
auto c0 = args.device_shape[5]; |
|
|
|
|
|
for (size_t c1_i = 0; c1_i < c1; c1_i++) { |
|
|
|
|
|
for (size_t h_i = 0; h_i < h; h_i++) { |
|
|
|
|
|
for (size_t w_i = 0; w_i < w; w_i++) { |
|
|
|
|
|
for (size_t n_i = 0; n_i < n; n_i++) { |
|
|
|
|
|
for (size_t co_i = 0; co_i < co; co_i++) { |
|
|
|
|
|
for (size_t c0_i = 0; c0_i < c0; c0_i++) { |
|
|
|
|
|
size_t dst_offset = (c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + |
|
|
|
|
|
n_i * co * c0 + co_i * c0 + c0_i) * |
|
|
|
|
|
size; |
|
|
|
|
|
size_t protected_size = total_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN) |
|
|
|
|
|
? total_size - dst_offset |
|
|
|
|
|
: static_cast<size_t>(SECUREC_MEM_MAX_LEN); |
|
|
|
|
|
size_t c_i = c0_i + c1_i * c0; |
|
|
|
|
|
size_t src_offset = (n_i * c * h * w + c_i * h * w + h_i * w + w_i) * size; |
|
|
|
|
|
error_t ret; |
|
|
|
|
|
if (c_i < c && c0_i == co_i) { |
|
|
|
|
|
ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size, |
|
|
|
|
|
static_cast<uint8_t const *>(args.data) + src_offset, size); |
|
|
|
|
|
} else { |
|
|
|
|
|
ret = memset_s(static_cast<uint8_t *>(result) + dst_offset, protected_size, 0, size); |
|
|
|
|
|
} |
|
|
|
|
|
if (ret != EOK) { |
|
|
|
|
|
MS_LOG(ERROR) << "Failed to operate the dst memory, error-code:" << ret; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) { |
|
|
|
|
|
// trans c1hwncoc0 to nchw |
|
|
|
|
|
MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw"; |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(result); |
|
|
|
|
|
size_t size = 0; |
|
|
|
|
|
size_t total_size = 0; |
|
|
|
|
|
if (!CheckArgs(args, &size, &total_size)) { |
|
|
|
|
|
MS_LOG(ERROR) << "Check args failed."; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
auto n = args.host_shape[0]; |
|
|
|
|
|
auto c = args.host_shape[1]; |
|
|
|
|
|
auto h = args.host_shape[2]; |
|
|
|
|
|
auto w = args.host_shape[3]; |
|
|
|
|
|
auto co = args.device_shape[4]; |
|
|
|
|
|
auto c0 = args.device_shape[5]; |
|
|
|
|
|
for (size_t n_i = 0; n_i < n; n_i++) { |
|
|
|
|
|
for (size_t c_i = 0; c_i < c; c_i++) { |
|
|
|
|
|
for (size_t h_i = 0; h_i < h; h_i++) { |
|
|
|
|
|
for (size_t w_i = 0; w_i < w; w_i++) { |
|
|
|
|
|
size_t dst_offset = (n_i * c * h * w + c_i * h * w + h_i * w + w_i) * size; |
|
|
|
|
|
size_t c1_i = c_i / kCubeSize; |
|
|
|
|
|
size_t c0_i = c_i % kCubeSize; |
|
|
|
|
|
size_t co_i = c0_i; |
|
|
|
|
|
size_t src_offset = (c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + |
|
|
|
|
|
co_i * c0 + c0_i) * |
|
|
|
|
|
size; |
|
|
|
|
|
size_t protected_size = total_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN) |
|
|
|
|
|
? total_size - dst_offset |
|
|
|
|
|
: static_cast<size_t>(SECUREC_MEM_MAX_LEN); |
|
|
|
|
|
auto ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size, |
|
|
|
|
|
static_cast<uint8_t const *>(args.data) + src_offset, size); |
|
|
|
|
|
if (ret != EOK) { |
|
|
|
|
|
MS_LOG(ERROR) << "Failed to operate the dst memory, error-code:" << ret; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
} // namespace trans |
|
|
} // namespace trans |
|
|
} // namespace mindspore |
|
|
} // namespace mindspore |