// Review notes for this patch. Everything from "diff --git" onward is the patch text,
// reproduced exactly as found; these leading lines are commentary only.
//
// What the hunks show:
//  * memory_manager.h/.cc: the MallocMem(address, flag, size, node_with_index) overload
//    is removed. MallocMem(type, size, address) and MallocOutputMem(node, index, type,
//    size, address) now receive the DeviceAddressPtr, raise on a null address via
//    MS_EXCEPTION_IF_NULL, and store the result in address->ptr_ themselves. Both set
//    address->from_mem_pool_ = true on their kStaticMem paths.
//  * MallocOutputMem assigns address->communication_ptr_ = ptr - kMemAlignSize only in
//    its first kStaticMem branch, and only when communication_mem is true.
//  * kernel_runtime.cc: call sites move to the new argument order (address last);
//    AssignSingleOpLaunchMemory drops device_address->set_ptr(base_ptr) and null-checks
//    base_ptr instead.
//
// NOTE(review): the removed overload set communication_ptr_ for any flag whenever
//   IsCommunicationOp(node) && enable_hccl(); the new code does so only for kStaticMem.
//   Confirm the dynamic / reuse-dynamic paths do not rely on communication_ptr_.
// NOTE(review): the removed overload ran MS_EXCEPTION_IF_NULL(ptr) before computing
//   ptr - kMemAlignSize; the new branch subtracts before any null check. Consider
//   checking ptr first so a failed allocation never does arithmetic on a null pointer.
// NOTE(review): in AssignCommunicationNodeInputMem the visible hunk has no null check on
//   input_ptr, and the implicit check inside the removed overload is gone - verify.
// NOTE(review): MallocMem is virtual and its signature changes; overrides in derived
//   memory managers are not part of this patch - confirm they are updated elsewhere.
// NOTE(review): in the AssignValueNodeTensor / AssignStaticMemoryValueNode context lines,
//   the `if (enable_pynative_infer() && !MallocMemFromMemPool(...)) ... else if
//   (MallocMem(...) == nullptr)` shape looks like it still calls MallocMem after a
//   successful pool allocation. This predates the patch - TODO confirm intent.
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index d1c0392e1c..c06205a2df 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -411,7 +411,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { } auto tensor_size = CountNodeDeviceMemorySize(item, index); auto address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); - if (mem_manager_->MallocMem(address, kStaticMem, tensor_size) == nullptr) { + if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) { MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; } AnfAlgo::SetOutputAddr(address, index, item.get()); @@ -517,7 +517,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type); MS_EXCEPTION_IF_NULL(address); if (output_ptr == nullptr) { - output_ptr = mem_manager_->MallocMem(address, type, total_size, std::pair(node, 0)); + output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address); MS_EXCEPTION_IF_NULL(output_ptr); } else { address->set_ptr(output_ptr); @@ -565,8 +565,7 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP if (addr_size.empty()) { return; } - uint8_t *input_ptr = - mem_manager_->MallocMem(addr_size[0].first, type, total_size, std::pair(node, 0)); + uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, addr_size[0].first); for (const auto &iter : addr_size) { MS_EXCEPTION_IF_NULL(iter.first); iter.first->set_ptr(input_ptr); @@ -600,8 +599,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i); auto device_address = 
CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); MS_EXCEPTION_IF_NULL(device_address); - uint8_t *ptr = - mem_manager_->MallocMem(device_address, type, output_sizes[i], std::pair(node, i)); + uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address); MS_EXCEPTION_IF_NULL(ptr); device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i)); AnfAlgo::SetOutputAddr(device_address, i, node.get()); @@ -639,7 +637,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const MS_EXCEPTION_IF_NULL(address); if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, node_size)) { MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << node_size; - } else if (mem_manager_->MallocMem(address, kStaticMem, node_size) == nullptr) { + } else if (mem_manager_->MallocMem(kStaticMem, node_size, address) == nullptr) { MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size; } AnfAlgo::SetOutputAddr(address, output_idx, value_node.get()); @@ -675,7 +673,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(address); if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) { MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << tensor_size; - } else if (mem_manager_->MallocMem(address, kStaticMem, tensor_size) == nullptr) { + } else if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) { MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; } AnfAlgo::SetOutputAddr(address, 0, value_node.get()); @@ -859,8 +857,8 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st auto device_address = CreateDeviceAddress(nullptr, size, format, type); 
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(mem_manager_); - auto base_ptr = mem_manager_->MallocMem(kDynamicMem, size); - device_address->set_ptr(base_ptr); + auto base_ptr = mem_manager_->MallocMem(kDynamicMem, size, device_address); + MS_EXCEPTION_IF_NULL(base_ptr); return device_address; } diff --git a/mindspore/ccsrc/runtime/device/memory_manager.cc b/mindspore/ccsrc/runtime/device/memory_manager.cc index b00c1b719b..cd3cd620b5 100644 --- a/mindspore/ccsrc/runtime/device/memory_manager.cc +++ b/mindspore/ccsrc/runtime/device/memory_manager.cc @@ -45,8 +45,10 @@ void MemoryManager::MallocReusedDynamicMem(const session::KernelGraph *graph) { mem_reuse_util_ptr_->set_mem_base(base_ptr); } -uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size) { +uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size, + const DeviceAddressPtr &address) { MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(address); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); uint8_t *ptr = nullptr; @@ -57,23 +59,30 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, Me } if (type == kStaticMem) { ptr = MallocStaticMem(size, communication_mem); + address->from_mem_pool_ = true; + if (communication_mem) { + address->communication_ptr_ = ptr - kMemAlignSize; + } } else if (type == kReuseDynamicCommMem) { MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index); } else { ptr = MallocDynamicMem(size, communication_mem); } + address->ptr_ = ptr; return ptr; } if (type == kStaticMem) { ptr = MallocStaticMem(size, false); + address->from_mem_pool_ = true; } else if (type == kDynamicMem) { ptr = MallocDynamicMem(size, false); } else if (type == kReuseDynamicMem) { MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index); } + address->ptr_ = ptr; 
return ptr; } @@ -85,38 +94,16 @@ uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, return MallocDynamicMem(size, false); } -uint8_t *MemoryManager::MallocMem(MemType type, size_t size) { +uint8_t *MemoryManager::MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) { + MS_EXCEPTION_IF_NULL(address); uint8_t *ptr = nullptr; if (type == kStaticMem) { ptr = MallocStaticMem(size, false); + address->from_mem_pool_ = true; } else if (type == kDynamicMem) { ptr = MallocDynamicMem(size, false); } - return ptr; -} - -uint8_t *MemoryManager::MallocMem(const DeviceAddressPtr &address, MemType flag, size_t size, - const session::KernelWithIndex &node_with_index) { - MS_EXCEPTION_IF_NULL(address); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - - uint8_t *ptr = nullptr; - if (node_with_index.first != nullptr) { - ptr = MallocOutputMem(node_with_index.first, node_with_index.second, flag, size); - MS_EXCEPTION_IF_NULL(ptr); - if (AnfAlgo::IsCommunicationOp(node_with_index.first) && context_ptr->enable_hccl()) { - address->communication_ptr_ = ptr - kMemAlignSize; - } - } else { - ptr = MallocMem(flag, size); - MS_EXCEPTION_IF_NULL(ptr); - } address->ptr_ = ptr; - - if (flag == kStaticMem) { - address->from_mem_pool_ = true; - } return ptr; } diff --git a/mindspore/ccsrc/runtime/device/memory_manager.h b/mindspore/ccsrc/runtime/device/memory_manager.h index 8f31f759a4..cb045f8d27 100644 --- a/mindspore/ccsrc/runtime/device/memory_manager.h +++ b/mindspore/ccsrc/runtime/device/memory_manager.h @@ -41,11 +41,10 @@ class MemoryManager { } void MallocReusedDynamicMem(const session::KernelGraph *graph); - uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size); + uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size, + const DeviceAddressPtr &address); uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, 
size_t size); - uint8_t *MallocMem(const DeviceAddressPtr &address, MemType flag, size_t size, - const session::KernelWithIndex &node_with_index = std::pair(nullptr, 0)); - virtual uint8_t *MallocMem(MemType type, size_t size); + virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address); virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size); virtual void *MallocMemFromMemPool(size_t size);