|
|
|
@@ -153,6 +153,16 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
DeviceAddressPtr AssignLaunchMemory(size_t size, const std::string &format, TypeId type) { |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(ms_context); |
|
|
|
auto device_id = ms_context->device_id(); |
|
|
|
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); |
|
|
|
MS_EXCEPTION_IF_NULL(runtime_instance); |
|
|
|
auto address_ptr = runtime_instance->AssignSingleOpLaunchMemory(size, format, type); |
|
|
|
return address_ptr; |
|
|
|
} |
|
|
|
|
|
|
|
size_t GetCommonAlignSize(size_t input_size) { |
|
|
|
return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; |
|
|
|
} |
|
|
|
@@ -325,18 +335,15 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v |
|
|
|
AddressPtrList kernel_inputs = {input_address}; |
|
|
|
AddressPtrList kernel_outputs = {output_address}; |
|
|
|
AddressPtrList kernel_workspaces; |
|
|
|
std::vector<void *> workspaces_address_ptr(workspace_size_list.size(), nullptr); |
|
|
|
if (!workspace_size_list.empty()) { |
|
|
|
for (size_t i = 0; i < workspace_size_list.size(); ++i) { |
|
|
|
auto workspace_size = GetCommonAlignSize(workspace_size_list[i]); |
|
|
|
auto ret_malloc = rtMalloc(&workspaces_address_ptr[i], workspace_size, RT_MEMORY_HBM); |
|
|
|
if (ret_malloc != RT_ERROR_NONE) { |
|
|
|
MS_LOG(ERROR) << "Failed to rtMalloc memory"; |
|
|
|
} |
|
|
|
auto workspace_address_ptr = AssignLaunchMemory(workspace_size, "", kTypeUnknown); |
|
|
|
MS_EXCEPTION_IF_NULL(workspace_address_ptr); |
|
|
|
auto workspace_address = std::make_shared<kernel::Address>(); |
|
|
|
MS_EXCEPTION_IF_NULL(workspace_address); |
|
|
|
workspace_address->addr = workspaces_address_ptr[i]; |
|
|
|
workspace_address->size = workspace_size; |
|
|
|
workspace_address->addr = workspace_address_ptr->GetMutablePtr(); |
|
|
|
workspace_address->size = workspace_address_ptr->GetSize(); |
|
|
|
kernel_workspaces.push_back(workspace_address); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -350,15 +357,6 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v |
|
|
|
if (!ret) { |
|
|
|
MS_LOG(ERROR) << "Launch kernel failed."; |
|
|
|
} |
|
|
|
SyncStream(); |
|
|
|
if (!workspace_size_list.empty()) { |
|
|
|
for (size_t i = 0; i < workspace_size_list.size(); ++i) { |
|
|
|
auto ret_free = rtFree(workspaces_address_ptr[i]); |
|
|
|
if (ret_free != RT_ERROR_NONE) { |
|
|
|
MS_LOG(ERROR) << "Failed to rtFree memory"; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(const nlohmann::json &kernel_json) const { |
|
|
|
@@ -418,19 +416,17 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const |
|
|
|
size = device_dtype_size * shape_size; |
|
|
|
} |
|
|
|
size = GetCommonAlignSize(size); |
|
|
|
void *output_address_ptr = nullptr; |
|
|
|
auto ret_malloc = rtMalloc(&output_address_ptr, size, RT_MEMORY_HBM); |
|
|
|
if (ret_malloc != RT_ERROR_NONE) { |
|
|
|
MS_LOG(ERROR) << "Failed to rtMalloc memory"; |
|
|
|
} |
|
|
|
auto output_address = AssignLaunchMemory(size, kOpFormat_NCHW, type_id_); |
|
|
|
MS_EXCEPTION_IF_NULL(output_address); |
|
|
|
auto workspace_size_list = GetWorkspaceSizeList(kernel_json); |
|
|
|
// launch |
|
|
|
LaunchTransData(kernel_mod_ptr, output_address_ptr, size, workspace_size_list); |
|
|
|
LaunchTransData(kernel_mod_ptr, output_address->GetMutablePtr(), output_address->GetSize(), workspace_size_list); |
|
|
|
SyncStream(); |
|
|
|
if (type_id_ == type) { |
|
|
|
SyncMemory(host_ptr, output_address_ptr, host_size, RT_MEMCPY_DEVICE_TO_HOST); |
|
|
|
SyncMemory(host_ptr, output_address->GetPtr(), host_size, RT_MEMCPY_DEVICE_TO_HOST); |
|
|
|
} else { |
|
|
|
auto host = std::vector<uint8_t>(size); |
|
|
|
SyncMemory(host.data(), output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST); |
|
|
|
SyncMemory(host.data(), output_address->GetPtr(), size, RT_MEMCPY_DEVICE_TO_HOST); |
|
|
|
auto shape_size = trans::ShapeSize(host_shape); |
|
|
|
const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, host_size}; |
|
|
|
sync_ok = trans::TransDataType(type_args, host_ptr); |
|
|
|
@@ -439,10 +435,6 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
auto ret_free = rtFree(output_address_ptr); |
|
|
|
if (ret_free != RT_ERROR_NONE) { |
|
|
|
MS_LOG(ERROR) << "Failed to rtFree memory"; |
|
|
|
} |
|
|
|
return sync_ok; |
|
|
|
} |
|
|
|
|
|
|
|
|