|
|
|
@@ -19,6 +19,8 @@ |
|
|
|
#include "utils/ms_context.h" |
|
|
|
#include "utils/convert_utils.h" |
|
|
|
#include "ps/ps_cache/ps_cache_manager.h" |
|
|
|
#include "runtime/device/gpu/gpu_device_manager.h" |
|
|
|
#include "runtime/device/gpu/gpu_common.h" |
|
|
|
namespace mindspore { |
|
|
|
namespace device { |
|
|
|
namespace gpu { |
|
|
|
@@ -34,6 +36,39 @@ std::vector<void *> GPUMemoryManager::MallocContinuousMemFromMemPool(size_t tota |
|
|
|
return GPUMemoryAllocator::GetInstance().AllocContinuousTensorMem(total_size, size_list); |
|
|
|
} |
|
|
|
|
|
|
|
bool GPUMemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, |
|
|
|
std::vector<size_t> size_list) { |
|
|
|
auto device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list); |
|
|
|
if (device_ptr_list.size() == 0) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (addr_list.size() != device_ptr_list.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "The size of device list is not equal to the size of address list."; |
|
|
|
} |
|
|
|
auto &stream = GPUDeviceManager::GetInstance().default_stream(); |
|
|
|
MS_EXCEPTION_IF_NULL(stream); |
|
|
|
bool need_sync_stream = false; |
|
|
|
for (size_t i = 0; i < addr_list.size(); i++) { |
|
|
|
MS_EXCEPTION_IF_NULL(addr_list[i]); |
|
|
|
auto old_addr = addr_list[i]->ptr_; |
|
|
|
auto new_addr = device_ptr_list[i]; |
|
|
|
MS_EXCEPTION_IF_NULL(new_addr); |
|
|
|
if (old_addr != nullptr) { |
|
|
|
need_sync_stream = true; |
|
|
|
CHECK_OP_RET_WITH_EXCEPT( |
|
|
|
GPUDeviceManager::GetInstance().CopyDeviceMemToDeviceAsync(new_addr, old_addr, size_list[i], stream), |
|
|
|
"Failed to copyHostMemToDeviceAsync."); |
|
|
|
FreeMemFromMemPool(old_addr); |
|
|
|
} |
|
|
|
addr_list[i]->ptr_ = new_addr; |
|
|
|
addr_list[i]->from_mem_pool_ = true; |
|
|
|
} |
|
|
|
if (need_sync_stream) { |
|
|
|
return GPUDeviceManager::GetInstance().SyncStream(stream); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void GPUMemoryManager::MallocDeviceMemory() { |
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
|