From acba03b191f48f0474455e8c7691833c1c106b0e Mon Sep 17 00:00:00 2001 From: laiyongqiang Date: Thu, 16 Jul 2020 15:27:23 +0800 Subject: [PATCH] not reuse ref node input's memory --- .../optimizer/mem_reuse/kernel_refcount.h | 8 +-- .../backend/optimizer/mem_reuse/mem_reuse.cc | 52 ++++++++++++++----- .../backend/optimizer/mem_reuse/mem_reuse.h | 3 +- .../mem_reuse/mem_reuse_allocator.cc | 43 ++++++++------- .../optimizer/mem_reuse/mem_reuse_allocator.h | 7 +-- .../optimizer/mem_reuse/mem_reuse_checker.cc | 6 +-- .../optimizer/mem_reuse/mem_reuse_checker.h | 6 +-- .../ccsrc/runtime/device/kernel_runtime.cc | 40 +++++++------- .../ccsrc/runtime/device/kernel_runtime.h | 10 ++-- .../ccsrc/runtime/device/memory_manager.cc | 24 ++++----- .../ccsrc/runtime/device/memory_manager.h | 13 ++--- 11 files changed, 121 insertions(+), 91 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.h index 58f7ef3672..36bf8f7180 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.h @@ -25,8 +25,8 @@ namespace mindspore { namespace memreuse { enum RefCountType { kDynamicRefCount, kStaticRefCount }; -enum NodeType { COMMON_NODE, COMMUNICATION_NODE }; -enum KernelRefType { COMMON, REFNODE_OUTPUT, COMM_NOTREUSE, COMM_REUSE, SUMMARY }; +enum NodeType { kCommonNode, kCommunicationNode }; +enum KernelRefType { kCommon, kRefNodeInput, kRefNodeOutput, kCommNotReuse, kCommReuse, kSummary }; static constexpr int kInitIndex = -1; class KernelRefCount { public: @@ -46,7 +46,7 @@ class KernelRefCount { offset_(0), size_(0), index_(kInitIndex), - type_(COMMON), + type_(kCommon), reftype_(kStaticRefCount) {} ~KernelRefCount() = default; void SetKernelRefCountInfo(int index, size_t size, RefCountType reftype); @@ -68,7 +68,7 @@ class KernelDef { KernelMap inputs_; KernelMap outputs_; KernelMap wk_space_; - NodeType type_ = COMMON_NODE; + NodeType type_ = kCommonNode; KernelDef() = default; ~KernelDef() = default; void set_input_refs(const KernelRefCountPtrList &kernelRefPtrList) { input_refs_ = kernelRefPtrList; } diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc index 8166a7bcc1..02a277f224 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc @@ -57,13 +57,22 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() { kernel_ref->stream_id_ = curr_stream_id; kernel_ref->SetKernelRefCountInfo(index, size, kDynamicRefCount); if (is_comm_op) { - kernel_ref->type_ = COMM_REUSE; + kernel_ref->type_ = kCommReuse; } else { session::AnfWithOutIndex out_pair(kernel_cnode, output_index); if (graph_->IsInRefOutputMap(out_pair)) { - kernel_ref->type_ = REFNODE_OUTPUT; + kernel_ref->type_ = kRefNodeOutput; + auto origin_pair = graph_->GetRefCorrespondOutput(out_pair); + MS_EXCEPTION_IF_NULL(origin_pair.first); + if (origin_pair.first->isa()) { + auto cnode = origin_pair.first->cast(); + auto ref_ptr = GetKernelInputRef(cnode, origin_pair.second); + if (ref_ptr != nullptr) { + kernel_ref->type_ = kRefNodeInput; + } + } } else { - kernel_ref->type_ = COMMON; + kernel_ref->type_ = kCommon; } } kernel_refs.push_back(kernel_ref); @@ -175,9 +184,9 @@ void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr if (ref_ptr != nullptr) { if (is_comm_op) { if (input_tensor_num == 1) { - ref_ptr->type_ = COMM_REUSE; + ref_ptr->type_ = kCommReuse; } else { - ref_ptr->type_ = COMM_NOTREUSE; + ref_ptr->type_ = kCommNotReuse; } } @@ -282,9 +291,9 @@ void MemReuseUtil::SetKernelDefMap() { kernel_def_ptr->set_input_refs(kernel_def_ptr->inputs_[key]); kernel_def_ptr->set_output_refs(kernel_def_ptr->outputs_[key]); if (AnfAlgo::IsCommunicationOp(kernel)) { - kernel_def_ptr->type_ = COMMUNICATION_NODE; + kernel_def_ptr->type_ = kCommunicationNode; } else { - kernel_def_ptr->type_ = COMMON_NODE; + kernel_def_ptr->type_ = kCommonNode; } kernel_def_ptr_list_.push_back(kernel_def_ptr); kernel_map_[key] = kernel_def_ptr; @@ -365,7 +374,7 @@ void MemReuseUtil::SetSummaryNodesRefCount() { KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index]; kernel_ref->ref_count_ = kMaxRefCount; kernel_ref->ref_count_dynamic_use_ = kMaxRefCount; - kernel_ref->type_ = SUMMARY; + kernel_ref->type_ = kSummary; total_summary_size += kernel_ref->size_; MS_LOG(INFO) << "Set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index; } else { @@ -373,12 +382,29 @@ void MemReuseUtil::SetSummaryNodesRefCount() { } } #ifdef MEM_REUSE_DEBUG - auto graph = *graph_; - MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph); + MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_); #endif MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size; } +void MemReuseUtil::SetRefNodesInputRefCount() { + size_t total_size = 0; + for (auto iter : kernel_output_refs_) { + for (auto &ref_count : iter.second) { + MS_EXCEPTION_IF_NULL(ref_count); + if (ref_count->type_ == kRefNodeInput) { + ref_count->ref_count_ = kMaxRefCount; + total_size += ref_count->size_; + } + } + } + + MS_LOG(INFO) << "Special Tensor total size: RefNodeInput: " << total_size; +#ifdef MEM_REUSE_DEBUG + MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_); +#endif +} + void MemReuseUtil::SetGraphOutputRefCount() { auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); for (const auto &node : nodes) { @@ -405,8 +431,7 @@ void MemReuseUtil::SetGraphOutputRefCount() { } } #ifdef MEM_REUSE_DEBUG - auto graph = *graph_; - MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph); + MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_); #endif } @@ -419,13 +444,14 @@ void MemReuseUtil::ResetDynamicUsedRefCount() { } } -void MemReuseUtil::SetAllInfo(KernelGraph *graph) { +void MemReuseUtil::SetAllInfo(const KernelGraph *graph) { if (!InitDynamicKernelRef(graph)) { MS_LOG(EXCEPTION) << "Init ReuseAssignDynamicMemory Fault"; } SetKernelDefMap(); SetReuseRefCount(); SetSummaryNodesRefCount(); + SetRefNodesInputRefCount(); SetWorkSpaceList(); #ifdef MEM_REUSE_DEBUG MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph); diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h index 011b20c4ab..b755e049b8 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h @@ -52,7 +52,7 @@ class MemReuseUtil { MS_LOG(INFO) << "Total Reused WorkSpafce Memory Size: " << total_reuseworkspace_size_; } - void SetAllInfo(KernelGraph *graph); + void SetAllInfo(const KernelGraph *graph); bool InitDynamicOutputKernelRef(); bool InitDynamicWorkspaceKernelRef(); bool InitDynamicKernelRef(const KernelGraph *graph); @@ -64,6 +64,7 @@ class MemReuseUtil { void SetKernelDefInputs(); void SetReuseRefCount(); void SetSummaryNodesRefCount(); + void SetRefNodesInputRefCount(); // Set the reference count of graph output specially. void SetGraphOutputRefCount(); // Reset the dynamic used reference count by ref_count_. diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.cc index f57a78863a..e791d318fa 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.cc @@ -90,7 +90,7 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr auto curr_stream_id = kernel_curr->stream_id(); auto prev_stream_id = kernel_prev->stream_id(); if (curr_stream_id == prev_stream_id) { - mem_buf->type_ = IN_STREAM_REUSE; + mem_buf->type_ = kInStreamReuse; return true; } @@ -117,7 +117,7 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr } if (reuse_between_streams) { - mem_buf->type_ = BETWEEN_STREAMS_REUSE; + mem_buf->type_ = kBetweenStreamReuse; return true; } @@ -128,7 +128,7 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr auto kernel_curr_front = iter->second; auto depend_count = kernel_curr_front.count(kernel_prev); if (depend_count) { - mem_buf->type_ = KERNEL_DEPENDENCE_REUSE; + mem_buf->type_ = kKernelDependenceReuse; return true; } @@ -137,16 +137,19 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr void BestFitMemReuse::AssignCommonNodeOutputOffset() { MS_EXCEPTION_IF_NULL(current_kernel_); - for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { + for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { size_t index = GetTensorIndex(tensor_idx); auto tensor_desc = tensor_ptr_list_[index]; MS_EXCEPTION_IF_NULL(tensor_desc); - if (tensor_desc->type_ == REFNODE_OUTPUT) { + if (tensor_desc->type_ == kRefNodeInput) { + total_refinput_size += tensor_desc->size_; + } else if (tensor_desc->type_ == kRefNodeOutput) { total_refoutput_size += tensor_desc->size_; + // no need to alloc refnode output's memory continue; - } else if (tensor_desc->type_ == COMM_NOTREUSE) { + } else if (tensor_desc->type_ == kCommNotReuse) { total_comm_not_reuse_size += tensor_desc->size_; - } else if (tensor_desc->type_ == COMM_REUSE) { + } else if (tensor_desc->type_ == kCommReuse) { // get align size for communication op's single input tensor_desc->size_ = AlignCommunicationMemorySize(tensor_desc->size_); total_comm_reuse_size += tensor_desc->size_; @@ -165,7 +168,7 @@ void BestFitMemReuse::AssignCommonNodeOutputOffset() { #endif } // skip left align border for communication op single input to used - if (tensor_desc->type_ == COMM_REUSE) { + if (tensor_desc->type_ == kCommReuse) { tensor_desc->offset_ += kDefaultMemAlignSize; } } @@ -176,17 +179,18 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() { size_t output_num = 0; // get all output size MS_EXCEPTION_IF_NULL(current_kernel_); - for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { + for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { size_t index = GetTensorIndex(tensor_idx); auto tensor_desc = tensor_ptr_list_[index]; MS_EXCEPTION_IF_NULL(tensor_desc); - if (tensor_desc->type_ == COMM_REUSE) { + if (tensor_desc->type_ == kCommReuse) { total_comm_reuse_size += tensor_desc->size_; total_comm_output_reuse_size += tensor_desc->size_; total_kernel_output_size += tensor_desc->size_; } else { MS_LOG(ERROR) << "All communication op's outputs should be memory reuse, Kernel:" - << current_kernel_->scope_full_name(); + << current_kernel_->scope_full_name() << " output index:" << tensor_idx + << " tensor_type:" << tensor_desc->type_; continue; } } @@ -195,7 +199,7 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() { // add left align border for the first output and right align border for the last output to alloc align border memory size_t output_index = 0; auto output_ref_indexes = current_kernel_->GetOutputRefIndexs(); - for (auto &tensor_idx : output_ref_indexes) { + for (const auto &tensor_idx : output_ref_indexes) { size_t index = GetTensorIndex(tensor_idx); auto tensor_desc = tensor_ptr_list_[index]; MS_EXCEPTION_IF_NULL(tensor_desc); @@ -215,7 +219,7 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() { if (!reusable_membuf_map.empty()) { auto membuf_index = reusable_membuf_map.begin()->second; output_index = 0; - for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { + for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { size_t index = GetTensorIndex(tensor_idx); auto tensor_desc = tensor_ptr_list_[index]; MS_EXCEPTION_IF_NULL(tensor_desc); @@ -229,7 +233,7 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() { } else { // no membuf can reuse, add new membuf after the membuf_ptr_list output_index = 0; - for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { + for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { size_t index = GetTensorIndex(tensor_idx); auto tensor_desc = tensor_ptr_list_[index]; MS_EXCEPTION_IF_NULL(tensor_desc); @@ -247,7 +251,7 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() { } void BestFitMemReuse::AssignNodeOutputOffset() { - if (current_kernel_->type_ == COMMUNICATION_NODE) { + if (current_kernel_->type_ == kCommunicationNode) { AssignCommunicationNodeOutputOffset(); } else { AssignCommonNodeOutputOffset(); @@ -330,7 +334,7 @@ void BestFitMemReuse::AddNewMembufPtr(KernelRefCount *tensor_desc, int flag) { } auto membuf_size = tensor_desc->size_; auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag); - auto membuf = std::make_shared(kReused, membuf_size, membuf_offset, real_index, NEW, current_kernel_); + auto membuf = std::make_shared(kReused, membuf_size, membuf_offset, real_index, kNew, current_kernel_); membuf_ptr_list_.push_back(membuf); tensor_desc->offset_ = membuf_offset; } @@ -352,7 +356,7 @@ void BestFitMemReuse::UpdateNodeInputAndMembuf() { } void BestFitMemReuse::ReleaseNodeUnusedOutput() { - for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { + for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { size_t tensor_index = GetTensorIndex(tensor_idx); auto tensor_desc = tensor_ptr_list_[tensor_index]; MS_EXCEPTION_IF_NULL(tensor_desc); @@ -517,8 +521,9 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) { ++op_num; #endif } - MS_LOG(INFO) << "Special Tensor total size: RefOutput: " << total_refoutput_size - << " CommReuse: " << total_comm_reuse_size << " CommOutputReuse: " << total_comm_output_reuse_size + MS_LOG(INFO) << "Special Tensor total size: RefInput: " << total_refinput_size + << " RefOutput: " << total_refoutput_size << " CommReuse: " << total_comm_reuse_size + << " CommOutputReuse: " << total_comm_output_reuse_size << " CommNotReuse: " << total_comm_not_reuse_size; #ifdef MEM_REUSE_DEBUG MemReuseChecker::GetInstance().ExportMembufInfoIR(); diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.h index 322c7b940c..b5f1f5b9c5 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.h @@ -40,11 +40,11 @@ static constexpr int kDynamicMem = -1; static constexpr int kWorkspaceMem = 1; static constexpr size_t kTotalSize = 0; enum Status { kUnused, kReused }; -enum MEMTYPE { NEW, IN_STREAM_REUSE, BETWEEN_STREAMS_REUSE, KERNEL_DEPENDENCE_REUSE }; +enum MemType { kNew, kInStreamReuse, kBetweenStreamReuse, kKernelDependenceReuse }; class Membuf { public: Membuf() = default; - Membuf(Status status, size_t size, size_t offset, int index, MEMTYPE type, const KernelDefPtr &used_kernel) + Membuf(Status status, size_t size, size_t offset, int index, MemType type, const KernelDefPtr &used_kernel) : status_(status), size_(size), offset_(offset), index_(index), type_(type), used_kernel_(used_kernel) {} ~Membuf() = default; // Memory block status flags @@ -53,7 +53,7 @@ class Membuf { size_t offset_{0}; // Store the tensor index stored in this memory block at a certain moment int index_{0}; - MEMTYPE type_{NEW}; + MemType type_{kNew}; KernelDefPtr used_kernel_; }; using MembufPtr = std::shared_ptr; @@ -163,6 +163,7 @@ class BestFitMemReuse { // kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def std::map> kernel_front_map_; std::vector> stream_groups_; + size_t total_refinput_size{0}; size_t total_refoutput_size{0}; size_t total_comm_reuse_size{0}; size_t total_comm_output_reuse_size{0}; diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc index eca595cead..81dc3f8ba0 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc @@ -83,7 +83,7 @@ int64_t MemReuseChecker::CalculOriInput(const KernelGraph *graph) const { return static_input_size; } -int64_t MemReuseChecker::CalculOriValue(KernelGraph *graph) const { +int64_t MemReuseChecker::CalculOriValue(const KernelGraph *graph) const { MS_EXCEPTION_IF_NULL(graph); int64_t static_value_size = 0; for (auto &value_node : graph->graph_value_nodes()) { @@ -101,7 +101,7 @@ int64_t MemReuseChecker::CalculOriValue(KernelGraph *graph) const { return static_value_size; } -int64_t MemReuseChecker::CalculOriStatic(KernelGraph *graph) const { +int64_t MemReuseChecker::CalculOriStatic(const KernelGraph *graph) const { // cal static inputs auto static_input_size = CalculOriInput(graph); // do not calcul outpput size @@ -154,7 +154,7 @@ std::string MemReuseChecker::GetSplitName(const std::string &scope_name) const { } void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list, - const KernelDefPtrMaps &kernel_def_ptr_list, KernelGraph *graph) { + const KernelDefPtrMaps &kernel_def_ptr_list, const KernelGraph *graph) { total_ori_static_size_ = CalculOriStatic(graph); total_ori_input_size_ = CalculOriInput(graph); total_ori_value_size_ = CalculOriValue(graph); diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h index 3c4a00a3ca..9b4d1215ee 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h @@ -43,10 +43,10 @@ class MemReuseChecker { void CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr &c_node, size_t output_idx); bool CheckGraphOutputAssigned(const session::KernelGraph *graph); void CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list, const KernelDefPtrMaps &kernel_def_ptr_list, - KernelGraph *graph); - int64_t CalculOriStatic(KernelGraph *graph) const; + const KernelGraph *graph); + int64_t CalculOriStatic(const KernelGraph *graph) const; int64_t CalculOriInput(const KernelGraph *graph) const; - int64_t CalculOriValue(KernelGraph *graph) const; + int64_t CalculOriValue(const KernelGraph *graph) const; int64_t CalculOriDy(const KernelGraph *graph) const; int64_t CalculOriWk(const KernelGraph *graph) const; std::string GetSplitName(const std::string &scope_name) const; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index d6cce971c2..418a75fd6a 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -398,12 +398,12 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) { } } -void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) { - AssignCommunicationNodeInputMem(flag, node); - AssignCommunicationNodeOutputMem(flag, node); +void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) { + AssignCommunicationNodeInputMem(type, node); + AssignCommunicationNodeOutputMem(type, node); } -void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) { +void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(mem_manager_); auto kernel_mod = AnfAlgo::GetKernelMod(node); @@ -430,11 +430,11 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr align_size_list.emplace_back(mem_size); } - if (flag == kReuseDynamicMem) { + if (type == kReuseDynamicMem) { // reuse communication op's all outputs' memory - flag = kReuseDynamicCommMem; + type = kReuseDynamicCommMem; } - uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size); + uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size); for (size_t j = 0; j < align_size_list.size(); ++j) { std::string output_format = AnfAlgo::GetOutputFormat(node, j); auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j); @@ -458,7 +458,7 @@ DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, return address; } -void KernelRuntime::AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &node) { +void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(node); @@ -479,7 +479,7 @@ void KernelRuntime::AssignCommunicationNodeInputMem(int flag, const AnfNodePtr & total_size += mem_size; addr_size.emplace_back(address.get(), mem_size); } - uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size); + uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size); for (const auto &iter : addr_size) { MS_EXCEPTION_IF_NULL(iter.first); iter.first->set_ptr(input_ptr); @@ -487,12 +487,12 @@ void KernelRuntime::AssignCommunicationNodeInputMem(int flag, const AnfNodePtr & } } -void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) { +void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(mem_manager_); - if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) { + if (AnfAlgo::IsGetNext(NOT_NULL(node)) && type == kReuseDynamicMem) { MS_LOG(INFO) << "GetNext disable mem_reuse"; - flag = kDynamicMem; + type = kDynamicMem; } auto kernel_mod = AnfAlgo::GetKernelMod(node); MS_EXCEPTION_IF_NULL(kernel_mod); @@ -509,7 +509,7 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in MS_LOG(INFO) << "Already malloc index:" << i; continue; } - auto ptr = mem_manager_->MallocOutputMem(node, i, flag, output_sizes[i]); + auto ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i]); if (ptr == nullptr) { // reused ptr, no need alloc, continue; continue; @@ -608,10 +608,10 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); bool is_enable_mem_reuse = context_ptr->enable_mem_reuse(); - auto mem_flag = kDynamicMem; + auto mem_type = kDynamicMem; if (is_enable_mem_reuse) { mem_manager_->MallocReusedDynamicMem(graph); - mem_flag = kReuseDynamicMem; + mem_type = kReuseDynamicMem; } auto &execution_nodes = graph->execution_order(); std::vector compute_nodes; @@ -619,7 +619,7 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { for (auto &node : execution_nodes) { if (AnfAlgo::IsCommunicationOp(node)) { // skip if the memory is already alocated - AssignCommunicationNodeMem(mem_flag, node); + AssignCommunicationNodeMem(mem_type, node); } else { compute_nodes.emplace_back(node); } @@ -627,19 +627,19 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { // then compute nodes for (auto &node : compute_nodes) { - AssignNodeOutputMem(mem_flag, node, kGetAllOuts); - AssignWorkSpaceMem(mem_flag, node); + AssignNodeOutputMem(mem_type, node, kGetAllOuts); + AssignWorkSpaceMem(mem_type, node); } } -void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) { +void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(mem_manager_); auto kernel_mod = AnfAlgo::GetKernelMod(node); MS_EXCEPTION_IF_NULL(kernel_mod); size_t index = 0; for (auto &size : kernel_mod->GetWorkspaceSizeList()) { - auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, flag, size); + auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size); AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get()); index++; } diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index 41cbd6f4e4..3f441cb897 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -81,15 +81,15 @@ class KernelRuntime { void AssignStaticMemory(session::KernelGraph *graph); void AssignDynamicMemory(session::KernelGraph *graph); void ReuseAssignDynamicMemory(session::KernelGraph *graph); - void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index); - void AssignWorkSpaceMem(int flag, const AnfNodePtr &node); + void AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index); + void AssignWorkSpaceMem(MemType type, const AnfNodePtr &node); void AssignReuseWorkSpaceMem(const AnfNodePtr &node); void UpdateRefNodeOutputMem(const session::KernelGraph *graph); - void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node); - void AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &node); - void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node); + void AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node); + void AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node); + void AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node); #ifdef ENABLE_DUMP_E2E bool SetDumpConf(); #endif diff --git a/mindspore/ccsrc/runtime/device/memory_manager.cc b/mindspore/ccsrc/runtime/device/memory_manager.cc index 0199f8ee18..46a624922d 100644 --- a/mindspore/ccsrc/runtime/device/memory_manager.cc +++ b/mindspore/ccsrc/runtime/device/memory_manager.cc @@ -29,7 +29,7 @@ size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) const { return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + 2 * kMemAlignSize; } -void MemoryManager::MallocReusedDynamicMem(session::KernelGraph *graph) { +void MemoryManager::MallocReusedDynamicMem(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared(); MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); @@ -45,7 +45,7 @@ void MemoryManager::MallocReusedDynamicMem(session::KernelGraph *graph) { mem_reuse_util_ptr_->set_mem_base(base_ptr); } -uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size) { +uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size) { MS_EXCEPTION_IF_NULL(node); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); @@ -55,9 +55,9 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, in if (context_ptr->enable_hccl()) { communication_mem = true; } - if (flag == kStaticMem) { + if (type == kStaticMem) { ptr = MallocStaticMem(size, communication_mem); - } else if (flag == kReuseDynamicCommMem) { + } else if (type == kReuseDynamicCommMem) { MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index); } else { @@ -66,30 +66,30 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, in return ptr; } - if (flag == kStaticMem) { + if (type == kStaticMem) { ptr = MallocStaticMem(size, false); - } else if (flag == kDynamicMem) { + } else if (type == kDynamicMem) { ptr = MallocDynamicMem(size, false); - } else if (flag == kReuseDynamicMem) { + } else if (type == kReuseDynamicMem) { MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index); } return ptr; } -uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size) { - if (flag == kReuseDynamicMem) { +uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size) { + if (type == kReuseDynamicMem) { MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); return mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index); } return MallocDynamicMem(size, false); } -uint8_t *MemoryManager::MallocMem(int flag, size_t size) { +uint8_t *MemoryManager::MallocMem(MemType type, size_t size) { uint8_t *ptr = nullptr; - if (flag == kStaticMem) { + if (type == kStaticMem) { ptr = MallocStaticMem(size, false); - } else if (flag == kDynamicMem) { + } else if (type == kDynamicMem) { ptr = MallocDynamicMem(size, false); } return ptr; diff --git a/mindspore/ccsrc/runtime/device/memory_manager.h b/mindspore/ccsrc/runtime/device/memory_manager.h index 83a7e90d20..4ce870df08 100644 --- a/mindspore/ccsrc/runtime/device/memory_manager.h +++ b/mindspore/ccsrc/runtime/device/memory_manager.h @@ -22,10 +22,7 @@ #include "backend/optimizer/mem_reuse/mem_reuse_allocator.h" namespace mindspore { namespace device { -const int kStaticMem = 0; -const int kDynamicMem = 1; -const int kReuseDynamicMem = 2; -const int kReuseDynamicCommMem = 3; +enum MemType { kStaticMem, kDynamicMem, kReuseDynamicMem, kReuseDynamicCommMem }; const int kGetAllOuts = -1; const uint64_t kMemAlignSize = 512; using MemReuseUtilPtr = mindspore::memreuse::MemReuseUtilPtr; @@ -42,10 +39,10 @@ class MemoryManager { dynamic_mem_offset_ = 0; } - void MallocReusedDynamicMem(session::KernelGraph *graph); - uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size); - uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size); - virtual uint8_t *MallocMem(int flag, size_t size); + void MallocReusedDynamicMem(const session::KernelGraph *graph); + uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size); + uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size); + virtual uint8_t *MallocMem(MemType type, size_t size); virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size); virtual void *MallocMemFromMemPool(size_t size);