From 8d0691aaf97c3ebfc7e180c32842a809165d58a2 Mon Sep 17 00:00:00 2001 From: laiyongqiang Date: Tue, 16 Jun 2020 16:21:46 +0800 Subject: [PATCH] fix summary nodes memory reuse refcount --- .../ccsrc/pre_activate/mem_reuse/mem_reuse.cc | 26 +++++++++++++++++++ .../ccsrc/pre_activate/mem_reuse/mem_reuse.h | 1 + 2 files changed, 27 insertions(+) diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc index 3a631daa15..0b349d02d8 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc @@ -267,6 +267,31 @@ void MemReuseUtil::SetReuseRefCount() { } } +void MemReuseUtil::SetSummaryNodesRefCount() { + bool summary_exist = graph_->summary_node_exist(); + if (!summary_exist) { + return; + } + + auto summary_nodes = graph_->summary_nodes(); + if (summary_nodes.empty()) { + return; + } + + for (auto &node_item : summary_nodes) { + auto node = node_item.second.first; + size_t index = IntToSize(node_item.second.second); + MS_LOG(INFO) << "set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index; + if (kernel_output_refs_.find(node.get()) != kernel_output_refs_.end()) { + KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index]; + kernel_ref->ref_count_ = kMaxRefCount; + kernel_ref->ref_count_dynamic_use_ = kMaxRefCount; + } else { + MS_LOG(WARNING) << "can't find summary node's kernel_def " << node->fullname_with_scope(); + } + } +} + void MemReuseUtil::SetGraphOutputRefCount() { auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); for (const auto &node : nodes) { @@ -305,6 +330,7 @@ void MemReuseUtil::SetAllInfo(KernelGraph *graph) { } SetKernelDefMap(); SetReuseRefCount(); + SetSummaryNodesRefCount(); SetWorkSpaceList(); #ifdef MEM_REUSE_DEBUG MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph); diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h index 20a362e76f..999990b094 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h @@ -62,6 +62,7 @@ class MemReuseUtil { void SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); void SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); void SetReuseRefCount(); + void SetSummaryNodesRefCount(); // Set the reference count of graph output specially. void SetGraphOutputRefCount(); // Reset the dynamic used reference count by ref_count_.