Merge pull request !306 from limingqi107/mastertags/v0.2.0-alpha
| @@ -127,9 +127,10 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); | |||
| bool is_enable_pynative_infer = context_ptr->enable_pynative_infer(); | |||
| struct timeval start_time, end_time; | |||
| (void)gettimeofday(&start_time, nullptr); | |||
| if (is_enable_dynamic_mem) { | |||
| if (is_enable_dynamic_mem && !is_enable_pynative_infer) { | |||
| ret = LaunchKernelDynamic(graph); | |||
| } else { | |||
| ret = LaunchKernel(graph); | |||
| @@ -152,7 +153,7 @@ void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) { | |||
| } | |||
| mem_reuse_util_ptr->SetKernelDefMap(); | |||
| mem_reuse_util_ptr->SetReuseRefCount(); | |||
| // Can't free the device address of graph output, so set the reference count of graph output specially, | |||
| // Can't free the device address of graph output, so set the reference count of graph output specially. | |||
| mem_reuse_util_ptr->SetGraphOutputRefCount(); | |||
| mem_reuse_util_ptr_ = mem_reuse_util_ptr; | |||
| } | |||
| @@ -351,6 +352,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, | |||
| if (kernel_ref_count_ptr == nullptr) { | |||
| continue; | |||
| } | |||
| // Can't free the output of graph. | |||
| if (kernel_ref_count_ptr->ref_count_dynamic_use_ == memreuse::kMaxRefCount) { | |||
| continue; | |||
| } | |||
| kernel_ref_count_ptr->ref_count_dynamic_use_--; | |||
| if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { | |||
| // Reset the reference count. | |||
| @@ -360,14 +365,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, | |||
| FreeCommunicationOpDynamicRes(kernel, i, &is_communication_op); | |||
| if (!is_communication_op) { | |||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| MS_EXCEPTION_IF_NULL(device_address->ptr_); | |||
| mem_manager_->FreeMemFromMemPool(device_address->ptr_); | |||
| device_address->ptr_ = nullptr; | |||
| mem_manager_->FreeMemFromMemPool(device_address); | |||
| } | |||
| } | |||
| } | |||
| // Free the workspace of kernel. | |||
| for (size_t i = 0; i < kernel_workspaces.size(); ++i) { | |||
| auto workspace = kernel_workspaces[i]; | |||
| @@ -388,10 +389,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr | |||
| communication_op_input_ref_count_--; | |||
| if (communication_op_input_ref_count_ == 0) { | |||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0); | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| MS_EXCEPTION_IF_NULL(device_address->ptr_); | |||
| mem_manager_->FreeMemFromMemPool(device_address->ptr_); | |||
| device_address->ptr_ = nullptr; | |||
| mem_manager_->FreeMemFromMemPool(device_address); | |||
| } | |||
| *is_communication_op = true; | |||
| return; | |||
| @@ -410,10 +408,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr | |||
| communication_op_output_ref_count_--; | |||
| if (communication_op_output_ref_count_ == 0) { | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_input.first, 0); | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| MS_EXCEPTION_IF_NULL(device_address->ptr_); | |||
| mem_manager_->FreeMemFromMemPool(device_address->ptr_); | |||
| device_address->ptr_ = nullptr; | |||
| mem_manager_->FreeMemFromMemPool(device_address); | |||
| } | |||
| *is_communication_op = true; | |||
| } | |||
| @@ -155,6 +155,13 @@ void *MemoryManager::MallocMemFromMemPool(size_t size) { | |||
| return nullptr; | |||
| } | |||
| void MemoryManager::FreeMemFromMemPool(const DeviceAddressPtr address) { | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| MS_EXCEPTION_IF_NULL(address->ptr_); | |||
| FreeMemFromMemPool(address->ptr_); | |||
| address->ptr_ = nullptr; | |||
| } | |||
| void MemoryManager::FreeMemFromMemPool(void *device_ptr) { | |||
| if (device_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null."; | |||
| @@ -47,6 +47,7 @@ class MemoryManager { | |||
| virtual void MallocMemFromMemPool(const DeviceAddressPtr address, size_t size); | |||
| virtual void *MallocMemFromMemPool(size_t size); | |||
| virtual void FreeMemFromMemPool(const DeviceAddressPtr address); | |||
| virtual void FreeMemFromMemPool(void *device_ptr); | |||
| size_t GetCommonAlignSize(size_t input_size) const; | |||
| @@ -273,30 +273,21 @@ void MemReuseUtil::SetReuseRefCount() { | |||
| } | |||
| void MemReuseUtil::SetGraphOutputRefCount() { | |||
| for (const auto &output : graph_->outputs()) { | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(output); ++i) { | |||
| if (!(output->isa<CNode>())) { | |||
| continue; | |||
| } | |||
| auto cnode = output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto input_node = cnode->input(i + 1); | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| auto kernel_input = AnfAlgo::VisitKernel(input_node, 0); | |||
| MS_EXCEPTION_IF_NULL(kernel_input.first); | |||
| if (!(kernel_input.first->isa<CNode>())) { | |||
| continue; | |||
| } | |||
| auto ak_node = kernel_input.first->cast<CNodePtr>(); | |||
| auto key = ak_node.get(); | |||
| auto iter = kernel_output_refs_.find(key); | |||
| if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) { | |||
| auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second]; | |||
| MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr); | |||
| kernel_ref_count_ptr->ref_count_ = kMaxRefCount; | |||
| kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount; | |||
| } | |||
| auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); | |||
| for (const auto &node : nodes) { | |||
| auto kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0); | |||
| MS_EXCEPTION_IF_NULL(kernel_input.first); | |||
| if (!kernel_input.first->isa<CNode>() || !AnfAlgo::IsRealKernel(kernel_input.first)) { | |||
| continue; | |||
| } | |||
| auto ak_node = kernel_input.first->cast<CNodePtr>(); | |||
| auto key = ak_node.get(); | |||
| auto iter = kernel_output_refs_.find(key); | |||
| if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) { | |||
| auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second]; | |||
| MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr); | |||
| kernel_ref_count_ptr->ref_count_ = kMaxRefCount; | |||
| kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount; | |||
| } | |||
| } | |||
| #ifdef MEM_REUSE_DEBUG | |||
| @@ -75,7 +75,7 @@ MsContext::MsContext(const std::string& policy, const std::string& target) { | |||
| precompile_only_ = false; | |||
| auto_mixed_precision_flag_ = true; | |||
| enable_pynative_infer_ = false; | |||
| enable_dynamic_mem_pool_ = false; | |||
| enable_dynamic_mem_pool_ = true; | |||
| graph_memory_max_size_ = "0"; | |||
| variable_memory_max_size_ = "0"; | |||
| MS_LOG(INFO) << "Create context with backend policy:" << policy << ", device target:" << target << "."; | |||