| @@ -38,7 +38,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace pynative { | namespace pynative { | ||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| using ResourcePtr = std::shared_ptr<pipeline::Resource>; | using ResourcePtr = std::shared_ptr<pipeline::Resource>; | ||||
| using GradOperationPtr = std::shared_ptr<prim::GradOperation>; | using GradOperationPtr = std::shared_ptr<prim::GradOperation>; | ||||
| @@ -168,7 +167,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||||
| }; | }; | ||||
| using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>; | using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>; | ||||
| } // namespace pynative | } // namespace pynative | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -34,10 +34,7 @@ void AscendMemoryManager::MallocDeviceMemory() { | |||||
| MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]"; | MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]"; | ||||
| } | } | ||||
| AscendMemoryPool::GetInstance().set_device_mem_size(device_mem_size_); | |||||
| AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_base_); | |||||
| AscendMemoryPool::GetInstance().set_device_mem_pool_offset(device_mem_size_); | |||||
| AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); | |||||
| AscendMemoryPool::GetInstance().Init(device_mem_base_, device_mem_size_, dynamic_mem_offset_); | |||||
| } | } | ||||
| uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() { | uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() { | ||||
| @@ -21,6 +21,25 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace ascend { | namespace ascend { | ||||
| void AscendMemoryPool::Init(uint8_t *device_mem_base, uint64_t device_mem_size, uint64_t dynamic_mem_offset) { | |||||
| static bool initialized = false; | |||||
| if (initialized) { | |||||
| return; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(device_mem_base); | |||||
| set_device_mem_pool_base(device_mem_base); | |||||
| if (dynamic_mem_offset > device_mem_size) { | |||||
| MS_LOG(EXCEPTION) << "Dynamic memory offset: " << dynamic_mem_offset | |||||
| << " exceed the device memory size: " << device_mem_size; | |||||
| } | |||||
| set_device_mem_size(device_mem_size); | |||||
| set_device_mem_pool_offset(device_mem_size); | |||||
| set_graph_dynamic_mem_offset(dynamic_mem_offset); | |||||
| initialized = true; | |||||
| } | |||||
| size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { | size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { | ||||
| if (size == 0) { | if (size == 0) { | ||||
| MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero!"; | MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero!"; | ||||
| @@ -29,6 +29,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { | |||||
| AscendMemoryPool(const AscendMemoryPool &) = delete; | AscendMemoryPool(const AscendMemoryPool &) = delete; | ||||
| AscendMemoryPool &operator=(const AscendMemoryPool &) = delete; | AscendMemoryPool &operator=(const AscendMemoryPool &) = delete; | ||||
| void Init(uint8_t *device_mem_base, uint64_t device_mem_size, uint64_t dynamic_mem_offset); | |||||
| size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; | size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; | ||||
| bool FreeDeviceMem(const DeviceMemPtr &addr) override; | bool FreeDeviceMem(const DeviceMemPtr &addr) override; | ||||
| void ResetIdleMemBuf(); | void ResetIdleMemBuf(); | ||||