|
|
|
@@ -104,10 +104,10 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t |
|
|
|
} else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { |
|
|
|
sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_); |
|
|
|
} else { |
|
|
|
auto shape_size = trans::ShapeSize(host_shape); |
|
|
|
auto host_size = trans::ShapeSize(host_shape); |
|
|
|
auto host = std::vector<uint8_t>(size_); |
|
|
|
SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); |
|
|
|
const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type}; |
|
|
|
const trans::TypeIdArgs type_args{host.data(), size_, size, type_id_, type, host_size, host_size}; |
|
|
|
sync_ok = trans::TransDataType(type_args, host_ptr); |
|
|
|
if (!sync_ok) { |
|
|
|
MS_LOG(ERROR) << "trans data type failed."; |
|
|
|
@@ -153,14 +153,15 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int |
|
|
|
auto host = std::vector<uint8_t>(size_); |
|
|
|
sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data()); |
|
|
|
if (!sync_ok) { |
|
|
|
MS_LOG(ERROR) << "trans format failed."; |
|
|
|
MS_LOG(ERROR) << "Trans format failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto shape_size = trans::ShapeSize(host_shape); |
|
|
|
const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type}; |
|
|
|
auto host_size = trans::ShapeSize(host_shape); |
|
|
|
auto device_size = trans::ShapeSize(device_shape); |
|
|
|
const trans::TypeIdArgs type_args{host.data(), size_, size, type_id_, type, device_size, host_size}; |
|
|
|
sync_ok = trans::TransDataType(type_args, host_ptr); |
|
|
|
if (!sync_ok) { |
|
|
|
MS_LOG(ERROR) << "trans format failed."; |
|
|
|
MS_LOG(ERROR) << "Trans format failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} else { |
|
|
|
@@ -168,7 +169,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int |
|
|
|
host_shape, device_shape, type_id_}; |
|
|
|
sync_ok = trans::TransFormatFromDeviceToHost(format_args, host_ptr); |
|
|
|
if (!sync_ok) { |
|
|
|
MS_LOG(ERROR) << "trans format failed."; |
|
|
|
MS_LOG(ERROR) << "Trans format failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -192,12 +193,12 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t |
|
|
|
} else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { |
|
|
|
sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size); |
|
|
|
} else { |
|
|
|
auto shape_size = trans::ShapeSize(host_shape); |
|
|
|
const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_}; |
|
|
|
auto host_size = trans::ShapeSize(host_shape); |
|
|
|
const trans::TypeIdArgs type_args{host_ptr, size, size_, type, type_id_, host_size, host_size}; |
|
|
|
auto host_tmp = std::vector<uint8_t>(size_); |
|
|
|
sync_ok = trans::TransDataType(type_args, host_tmp.data()); |
|
|
|
if (!sync_ok) { |
|
|
|
MS_LOG(ERROR) << "trans data type failed."; |
|
|
|
MS_LOG(ERROR) << "Trans data type failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); |
|
|
|
@@ -234,12 +235,13 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int |
|
|
|
device_shape = trans::TransShapeToDevice(host_shape, format_); |
|
|
|
} |
|
|
|
if (type_id_ != type) { |
|
|
|
auto shape_size = trans::ShapeSize(host_shape); |
|
|
|
const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_}; |
|
|
|
auto host_size = trans::ShapeSize(host_shape); |
|
|
|
auto device_size = trans::ShapeSize(device_shape); |
|
|
|
const trans::TypeIdArgs type_args{host_ptr, size, size_, type, type_id_, host_size, device_size}; |
|
|
|
auto host_tmp = std::vector<uint8_t>(size_); |
|
|
|
sync_ok = trans::TransDataType(type_args, host_tmp.data()); |
|
|
|
if (!sync_ok) { |
|
|
|
MS_LOG(ERROR) << "trans datatype failed."; |
|
|
|
MS_LOG(ERROR) << "Trans datatype failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, |
|
|
|
@@ -247,7 +249,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int |
|
|
|
auto dst_tmp = std::vector<uint8_t>(size_); |
|
|
|
sync_ok = trans::TransFormat(format_args, dst_tmp.data()); |
|
|
|
if (!sync_ok) { |
|
|
|
MS_LOG(ERROR) << "trans format failed."; |
|
|
|
MS_LOG(ERROR) << "Trans format failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
SyncMemory(ptr_, dst_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); |
|
|
|
@@ -256,7 +258,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int |
|
|
|
auto host_tmp = std::vector<uint8_t>(size_); |
|
|
|
sync_ok = trans::TransFormat(format_args, host_tmp.data()); |
|
|
|
if (!sync_ok) { |
|
|
|
MS_LOG(ERROR) << "trans format failed."; |
|
|
|
MS_LOG(ERROR) << "Trans format failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); |
|
|
|
|