| @@ -87,6 +87,7 @@ class DynamicMemPoolBestFit { | |||||
| void ReleaseDeviceRes(); | void ReleaseDeviceRes(); | ||||
| // Display the information of memory block and memory buf. | // Display the information of memory block and memory buf. | ||||
| void DumpDynamicMemPoolInfo(); | void DumpDynamicMemPoolInfo(); | ||||
| SizeMapMemBuf GetIdleMemBufMap() { return global_idle_mem_buf_map_; } | |||||
| // Get the related memory statistics information. | // Get the related memory statistics information. | ||||
| size_t total_mem_statistics() const { return total_mem_statistics_; } | size_t total_mem_statistics() const { return total_mem_statistics_; } | ||||
| @@ -29,6 +29,7 @@ | |||||
| #include "backend/kernel_compiler/kernel.h" | #include "backend/kernel_compiler/kernel.h" | ||||
| #include "backend/session/session_factory.h" | #include "backend/session/session_factory.h" | ||||
| #include "backend/session/ascend_control_parser.h" | #include "backend/session/ascend_control_parser.h" | ||||
| #include "runtime/device/ascend/ascend_memory_pool.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace session { | namespace session { | ||||
| @@ -37,7 +38,7 @@ enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, | |||||
| class AscendSession : public SessionBasic { | class AscendSession : public SessionBasic { | ||||
| public: | public: | ||||
| AscendSession() { final_graph_id_ = kInvalidGraphId; } | AscendSession() { final_graph_id_ = kInvalidGraphId; } | ||||
| ~AscendSession() override = default; | |||||
| ~AscendSession() override { mindspore::device::ascend::AscendMemoryPool::GetInstance().ResetIdleMemBuf(); } | |||||
| void Init(uint32_t device_id) override { | void Init(uint32_t device_id) override { | ||||
| SessionBasic::Init(device_id); | SessionBasic::Init(device_id); | ||||
| context_ = std::make_shared<Context>(kAscendDevice, device_id); | context_ = std::make_shared<Context>(kAscendDevice, device_id); | ||||
| @@ -43,6 +43,13 @@ bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| void AscendMemoryPool::ResetIdleMemBuf() { | |||||
| auto idle_mem_buf_map = DynamicMemPoolBestFit::GetIdleMemBufMap(); | |||||
| for (auto &it : idle_mem_buf_map) { | |||||
| rtMemset(it.second->device_addr_, it.first, 0, it.first); | |||||
| } | |||||
| } | |||||
| size_t AscendMemoryPool::AlignMemorySize(size_t size) const { | size_t AscendMemoryPool::AlignMemorySize(size_t size) const { | ||||
| if (size == 0) { | if (size == 0) { | ||||
| MS_LOG(EXCEPTION) << "The align memory size is a zero !"; | MS_LOG(EXCEPTION) << "The align memory size is a zero !"; | ||||
| @@ -31,6 +31,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { | |||||
| size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; | size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; | ||||
| bool FreeDeviceMem(const DeviceMemPtr &addr) override; | bool FreeDeviceMem(const DeviceMemPtr &addr) override; | ||||
| void ResetIdleMemBuf(); | |||||
| void set_device_mem_size(uint64_t device_mem_size); | void set_device_mem_size(uint64_t device_mem_size); | ||||
| void set_device_mem_pool_base(uint8_t *device_mem_pool_base); | void set_device_mem_pool_base(uint8_t *device_mem_pool_base); | ||||
| void set_device_mem_pool_offset(uint64_t device_mem_pool_offset); | void set_device_mem_pool_offset(uint64_t device_mem_pool_offset); | ||||