Merge pull request !2173 from limingqi107/gpu_memreuse_support_summary_node tags/v0.5.0-beta
| @@ -190,6 +190,8 @@ void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) { | |||||
| mem_reuse_util_ptr->SetReuseRefCount(); | mem_reuse_util_ptr->SetReuseRefCount(); | ||||
| // Can't free the device address of graph output, so set the reference count of graph output specially. | // Can't free the device address of graph output, so set the reference count of graph output specially. | ||||
| mem_reuse_util_ptr->SetGraphOutputRefCount(); | mem_reuse_util_ptr->SetGraphOutputRefCount(); | ||||
| // Can't free the device address of summary nodes, so set the reference count of summary nodes specially. | |||||
| mem_reuse_util_ptr->SetSummaryNodesRefCount(); | |||||
| auto graph_id = graph->graph_id(); | auto graph_id = graph->graph_id(); | ||||
| mem_reuse_util_map_[graph_id] = mem_reuse_util_ptr; | mem_reuse_util_map_[graph_id] = mem_reuse_util_ptr; | ||||
| } | } | ||||
| @@ -323,6 +323,10 @@ void MemReuseUtil::SetSummaryNodesRefCount() { | |||||
| MS_LOG(WARNING) << "can't find summary node's kernel_def " << node->fullname_with_scope(); | MS_LOG(WARNING) << "can't find summary node's kernel_def " << node->fullname_with_scope(); | ||||
| } | } | ||||
| } | } | ||||
| #ifdef MEM_REUSE_DEBUG | |||||
| auto graph = *graph_; | |||||
| MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph); | |||||
| #endif | |||||
| } | } | ||||
| void MemReuseUtil::SetGraphOutputRefCount() { | void MemReuseUtil::SetGraphOutputRefCount() { | ||||
| @@ -162,6 +162,8 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList | |||||
| auto execution_order = graph->execution_order(); | auto execution_order = graph->execution_order(); | ||||
| Reorder(&execution_order); | Reorder(&execution_order); | ||||
| graph->set_execution_order(execution_order); | graph->set_execution_order(execution_order); | ||||
| // Get summary nodes. | |||||
| GetSummaryNodes(graph.get()); | |||||
| // Remove NoOp from execution graph | // Remove NoOp from execution graph | ||||
| opt::RemoveNopNode(graph.get()); | opt::RemoveNopNode(graph.get()); | ||||
| // Alloc memory, including static memory and dynamic memory | // Alloc memory, including static memory and dynamic memory | ||||