|
|
|
@@ -415,9 +415,9 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const |
|
|
|
MS_LOG(ERROR) << "Illegal dtype."; |
|
|
|
} |
|
|
|
auto shape_size = trans::ShapeSize(host_shape); |
|
|
|
auto size_tmp = device_dtype_size * shape_size; |
|
|
|
size = GetCommonAlignSize(size_tmp); |
|
|
|
size = device_dtype_size * shape_size; |
|
|
|
} |
|
|
|
size = GetCommonAlignSize(size); |
|
|
|
void *output_address_ptr = nullptr; |
|
|
|
auto ret_malloc = rtMalloc(&output_address_ptr, size, RT_MEMORY_HBM); |
|
|
|
if (ret_malloc != RT_ERROR_NONE) { |
|
|
|
@@ -427,7 +427,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const |
|
|
|
// launch |
|
|
|
LaunchTransData(kernel_mod_ptr, output_address_ptr, size, workspace_size_list); |
|
|
|
if (type_id_ == type) { |
|
|
|
SyncMemory(host_ptr, output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST); |
|
|
|
SyncMemory(host_ptr, output_address_ptr, host_size, RT_MEMCPY_DEVICE_TO_HOST); |
|
|
|
} else { |
|
|
|
auto host = std::vector<uint8_t>(size); |
|
|
|
SyncMemory(host.data(), output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST); |
|
|
|
|