Merge pull request !3117 from laiyongqiang/refnode_input
@@ -25,8 +25,8 @@
 namespace mindspore {
 namespace memreuse {
 enum RefCountType { kDynamicRefCount, kStaticRefCount };
-enum NodeType { COMMON_NODE, COMMUNICATION_NODE };
-enum KernelRefType { COMMON, REFNODE_OUTPUT, COMM_NOTREUSE, COMM_REUSE, SUMMARY };
+enum NodeType { kCommonNode, kCommunicationNode };
+enum KernelRefType { kCommon, kRefNodeInput, kRefNodeOutput, kCommNotReuse, kCommReuse, kSummary };
 static constexpr int kInitIndex = -1;
 class KernelRefCount {
  public:
@@ -46,7 +46,7 @@ class KernelRefCount {
         offset_(0),
         size_(0),
         index_(kInitIndex),
-        type_(COMMON),
+        type_(kCommon),
         reftype_(kStaticRefCount) {}
   ~KernelRefCount() = default;
   void SetKernelRefCountInfo(int index, size_t size, RefCountType reftype);
@@ -68,7 +68,7 @@ class KernelDef {
   KernelMap inputs_;
   KernelMap outputs_;
   KernelMap wk_space_;
-  NodeType type_ = COMMON_NODE;
+  NodeType type_ = kCommonNode;
   KernelDef() = default;
   ~KernelDef() = default;
   void set_input_refs(const KernelRefCountPtrList &kernelRefPtrList) { input_refs_ = kernelRefPtrList; }
@@ -57,13 +57,22 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() {
       kernel_ref->stream_id_ = curr_stream_id;
       kernel_ref->SetKernelRefCountInfo(index, size, kDynamicRefCount);
       if (is_comm_op) {
-        kernel_ref->type_ = COMM_REUSE;
+        kernel_ref->type_ = kCommReuse;
       } else {
         session::AnfWithOutIndex out_pair(kernel_cnode, output_index);
         if (graph_->IsInRefOutputMap(out_pair)) {
-          kernel_ref->type_ = REFNODE_OUTPUT;
+          kernel_ref->type_ = kRefNodeOutput;
+          auto origin_pair = graph_->GetRefCorrespondOutput(out_pair);
+          MS_EXCEPTION_IF_NULL(origin_pair.first);
+          if (origin_pair.first->isa<CNode>()) {
+            auto cnode = origin_pair.first->cast<CNodePtr>();
+            auto ref_ptr = GetKernelInputRef(cnode, origin_pair.second);
+            if (ref_ptr != nullptr) {
+              kernel_ref->type_ = kRefNodeInput;
+            }
+          }
         } else {
-          kernel_ref->type_ = COMMON;
+          kernel_ref->type_ = kCommon;
         }
       }
       kernel_refs.push_back(kernel_ref);
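Review note: a ref node's output aliases one of its inputs, so handing it a fresh buffer would break the aliasing. The hunk above keeps the existing kRefNodeOutput tag and now upgrades it to kRefNodeInput when the aliased origin is a CNode whose input already carries a tracked ref count. A minimal sketch of the resulting classification, with the graph lookups folded into booleans (the helper name and parameters are illustrative, not the real API):

    // Sketch only: the real code queries the KernelGraph instead of taking booleans.
    enum KernelRefType { kCommon, kRefNodeInput, kRefNodeOutput, kCommNotReuse, kCommReuse, kSummary };

    KernelRefType ClassifyOutput(bool is_comm_op, bool is_ref_output, bool origin_input_has_ref) {
      if (is_comm_op) return kCommReuse;   // communication outputs are always tagged for reuse here
      if (!is_ref_output) return kCommon;  // ordinary output: plain best-fit reuse
      // Ref node: the output aliases an origin tensor; if that origin is a
      // tracked kernel input, the stronger kRefNodeInput tag wins.
      return origin_input_has_ref ? kRefNodeInput : kRefNodeOutput;
    }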
@@ -175,9 +184,9 @@ void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr
     if (ref_ptr != nullptr) {
       if (is_comm_op) {
         if (input_tensor_num == 1) {
-          ref_ptr->type_ = COMM_REUSE;
+          ref_ptr->type_ = kCommReuse;
         } else {
-          ref_ptr->type_ = COMM_NOTREUSE;
+          ref_ptr->type_ = kCommNotReuse;
         }
       }
@@ -282,9 +291,9 @@ void MemReuseUtil::SetKernelDefMap() {
     kernel_def_ptr->set_input_refs(kernel_def_ptr->inputs_[key]);
     kernel_def_ptr->set_output_refs(kernel_def_ptr->outputs_[key]);
     if (AnfAlgo::IsCommunicationOp(kernel)) {
-      kernel_def_ptr->type_ = COMMUNICATION_NODE;
+      kernel_def_ptr->type_ = kCommunicationNode;
     } else {
-      kernel_def_ptr->type_ = COMMON_NODE;
+      kernel_def_ptr->type_ = kCommonNode;
     }
     kernel_def_ptr_list_.push_back(kernel_def_ptr);
     kernel_map_[key] = kernel_def_ptr;
@@ -365,7 +374,7 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
       KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index];
       kernel_ref->ref_count_ = kMaxRefCount;
       kernel_ref->ref_count_dynamic_use_ = kMaxRefCount;
-      kernel_ref->type_ = SUMMARY;
+      kernel_ref->type_ = kSummary;
       total_summary_size += kernel_ref->size_;
       MS_LOG(INFO) << "Set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index;
     } else {
@@ -373,12 +382,29 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
     }
   }
 #ifdef MEM_REUSE_DEBUG
-  auto graph = *graph_;
-  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph);
+  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_);
 #endif
   MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size;
 }
+
+void MemReuseUtil::SetRefNodesInputRefCount() {
+  size_t total_size = 0;
+  for (auto iter : kernel_output_refs_) {
+    for (auto &ref_count : iter.second) {
+      MS_EXCEPTION_IF_NULL(ref_count);
+      if (ref_count->type_ == kRefNodeInput) {
+        ref_count->ref_count_ = kMaxRefCount;
+        total_size += ref_count->size_;
+      }
+    }
+  }
+  MS_LOG(INFO) << "Special Tensor total size: RefNodeInput: " << total_size;
+#ifdef MEM_REUSE_DEBUG
+  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_);
+#endif
+}
+
 void MemReuseUtil::SetGraphOutputRefCount() {
   auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem});
   for (const auto &node : nodes) {
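Review note: the new SetRefNodesInputRefCount pass pins every kRefNodeInput tensor at kMaxRefCount, so the buffer that the ref node's output aliases is never recycled while the graph runs. A hedged sketch of the effect on release logic (the constant's value and the release helper are assumptions for illustration):

    #include <cstddef>

    constexpr int kMaxRefCount = 9999;  // assumed sentinel; the real constant lives in the mem_reuse headers

    struct RefCount {
      int ref_count_;
      std::size_t size_;
    };

    // Illustrative release step: a pinned tensor never counts down to zero,
    // so its memory block is never returned to the best-fit free list.
    void ReleaseOnce(RefCount *rc) {
      if (rc->ref_count_ != kMaxRefCount) {
        --rc->ref_count_;
      }
    }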
@@ -405,8 +431,7 @@ void MemReuseUtil::SetGraphOutputRefCount() {
     }
   }
 #ifdef MEM_REUSE_DEBUG
-  auto graph = *graph_;
-  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph);
+  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_);
 #endif
 }
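Review note: both debug hooks above previously copied the whole KernelGraph (auto graph = *graph_;) just to obtain a non-const pointer for CheckMemReuseIR. With the checker's graph parameters const-qualified (see the mem_reuse_checker hunks below), the member pointer is passed directly and the copy disappears. The pattern in miniature (types reduced to a stub):

    struct KernelGraph { /* large, expensive-to-copy graph state */ };

    // Before: a non-const parameter forced read-only callers to copy.
    //   void CheckMemReuseIR(KernelGraph *graph);
    //   KernelGraph copy = *graph_;  CheckMemReuseIR(&copy);
    // After: the checker only reads the graph, so it takes a const pointer.
    void CheckMemReuseIR(const KernelGraph *graph);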
@@ -419,13 +444,14 @@ void MemReuseUtil::ResetDynamicUsedRefCount() {
   }
 }
 
-void MemReuseUtil::SetAllInfo(KernelGraph *graph) {
+void MemReuseUtil::SetAllInfo(const KernelGraph *graph) {
   if (!InitDynamicKernelRef(graph)) {
     MS_LOG(EXCEPTION) << "Init ReuseAssignDynamicMemory Fault";
   }
   SetKernelDefMap();
   SetReuseRefCount();
   SetSummaryNodesRefCount();
+  SetRefNodesInputRefCount();
   SetWorkSpaceList();
 #ifdef MEM_REUSE_DEBUG
   MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
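Review note: the ordering in SetAllInfo matters. The kRefNodeInput tags are assigned inside InitDynamicKernelRef, so by the time the new pass runs every tensor that needs pinning is already tagged, and the pins land before SetWorkSpaceList consumes the ref counts. A stub outline of the sequence (placeholder bodies; names mirror the hunk above):

    struct MemReuseUtilSketch {
      void InitDynamicKernelRef() {}      // tags outputs, including the new kRefNodeInput
      void SetKernelDefMap() {}
      void SetReuseRefCount() {}
      void SetSummaryNodesRefCount() {}   // pins summary outputs
      void SetRefNodesInputRefCount() {}  // new pass: pins kRefNodeInput tensors
      void SetWorkSpaceList() {}          // must run after all pins are in place
      void SetAllInfo() {
        InitDynamicKernelRef();
        SetKernelDefMap();
        SetReuseRefCount();
        SetSummaryNodesRefCount();
        SetRefNodesInputRefCount();
        SetWorkSpaceList();
      }
    };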
@@ -52,7 +52,7 @@ class MemReuseUtil {
     MS_LOG(INFO) << "Total Reused WorkSpace Memory Size: " << total_reuseworkspace_size_;
   }
-  void SetAllInfo(KernelGraph *graph);
+  void SetAllInfo(const KernelGraph *graph);
   bool InitDynamicOutputKernelRef();
   bool InitDynamicWorkspaceKernelRef();
   bool InitDynamicKernelRef(const KernelGraph *graph);
@@ -64,6 +64,7 @@ class MemReuseUtil {
   void SetKernelDefInputs();
   void SetReuseRefCount();
   void SetSummaryNodesRefCount();
+  void SetRefNodesInputRefCount();
   // Set the reference count of graph output specially.
   void SetGraphOutputRefCount();
   // Reset the dynamic used reference count by ref_count_.
@@ -90,7 +90,7 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr
   auto curr_stream_id = kernel_curr->stream_id();
   auto prev_stream_id = kernel_prev->stream_id();
   if (curr_stream_id == prev_stream_id) {
-    mem_buf->type_ = IN_STREAM_REUSE;
+    mem_buf->type_ = kInStreamReuse;
     return true;
   }
@@ -117,7 +117,7 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr
   }
   if (reuse_between_streams) {
-    mem_buf->type_ = BETWEEN_STREAMS_REUSE;
+    mem_buf->type_ = kBetweenStreamReuse;
     return true;
   }
@@ -128,7 +128,7 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr
   auto kernel_curr_front = iter->second;
   auto depend_count = kernel_curr_front.count(kernel_prev);
   if (depend_count) {
-    mem_buf->type_ = KERNEL_DEPENDENCE_REUSE;
+    mem_buf->type_ = kKernelDependenceReuse;
     return true;
   }
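Review note: the three renamed tags mark the three exits of BestFitMemReuse::IsUsable, which decide when a freed membuf from kernel_prev may serve kernel_curr: same stream (program order), synchronized stream groups, or an explicit dependence edge. Condensed into a standalone sketch, with the graph checks reduced to booleans (parameter names are illustrative):

    #include <cstdint>

    enum MemType { kNew, kInStreamReuse, kBetweenStreamReuse, kKernelDependenceReuse };

    bool IsUsableSketch(uint32_t curr_stream, uint32_t prev_stream,
                        bool streams_synced, bool prev_precedes_curr, MemType *why) {
      if (curr_stream == prev_stream) {  // same stream: execution order is guaranteed
        *why = kInStreamReuse;
        return true;
      }
      if (streams_synced) {  // the stream groups provide an ordering guarantee
        *why = kBetweenStreamReuse;
        return true;
      }
      if (prev_precedes_curr) {  // an explicit dependence edge orders the two kernels
        *why = kKernelDependenceReuse;
        return true;
      }
      return false;  // no ordering proof: the membuf cannot be reused safely
    }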
@@ -137,16 +137,19 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr
 void BestFitMemReuse::AssignCommonNodeOutputOffset() {
   MS_EXCEPTION_IF_NULL(current_kernel_);
-  for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+  for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
     size_t index = GetTensorIndex(tensor_idx);
     auto tensor_desc = tensor_ptr_list_[index];
     MS_EXCEPTION_IF_NULL(tensor_desc);
-    if (tensor_desc->type_ == REFNODE_OUTPUT) {
+    if (tensor_desc->type_ == kRefNodeInput) {
+      total_refinput_size += tensor_desc->size_;
+    } else if (tensor_desc->type_ == kRefNodeOutput) {
       total_refoutput_size += tensor_desc->size_;
+      // no need to alloc refnode output's memory
       continue;
-    } else if (tensor_desc->type_ == COMM_NOTREUSE) {
+    } else if (tensor_desc->type_ == kCommNotReuse) {
       total_comm_not_reuse_size += tensor_desc->size_;
-    } else if (tensor_desc->type_ == COMM_REUSE) {
+    } else if (tensor_desc->type_ == kCommReuse) {
       // get align size for communication op's single input
       tensor_desc->size_ = AlignCommunicationMemorySize(tensor_desc->size_);
       total_comm_reuse_size += tensor_desc->size_;
@@ -165,7 +168,7 @@ void BestFitMemReuse::AssignCommonNodeOutputOffset() {
 #endif
     }
     // skip left align border for communication op's single input to be used
-    if (tensor_desc->type_ == COMM_REUSE) {
+    if (tensor_desc->type_ == kCommReuse) {
      tensor_desc->offset_ += kDefaultMemAlignSize;
     }
   }
@@ -176,17 +179,18 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() {
   size_t output_num = 0;
   // get all output size
   MS_EXCEPTION_IF_NULL(current_kernel_);
-  for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+  for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
     size_t index = GetTensorIndex(tensor_idx);
     auto tensor_desc = tensor_ptr_list_[index];
     MS_EXCEPTION_IF_NULL(tensor_desc);
-    if (tensor_desc->type_ == COMM_REUSE) {
+    if (tensor_desc->type_ == kCommReuse) {
       total_comm_reuse_size += tensor_desc->size_;
       total_comm_output_reuse_size += tensor_desc->size_;
       total_kernel_output_size += tensor_desc->size_;
     } else {
       MS_LOG(ERROR) << "All communication op's outputs should be memory reuse, Kernel:"
-                    << current_kernel_->scope_full_name();
+                    << current_kernel_->scope_full_name() << " output index:" << tensor_idx
+                    << " tensor_type:" << tensor_desc->type_;
       continue;
     }
   }
@@ -195,7 +199,7 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() {
   // add left align border for the first output and right align border for the last output to alloc align border memory
   size_t output_index = 0;
   auto output_ref_indexes = current_kernel_->GetOutputRefIndexs();
-  for (auto &tensor_idx : output_ref_indexes) {
+  for (const auto &tensor_idx : output_ref_indexes) {
     size_t index = GetTensorIndex(tensor_idx);
     auto tensor_desc = tensor_ptr_list_[index];
     MS_EXCEPTION_IF_NULL(tensor_desc);
@@ -215,7 +219,7 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() {
   if (!reusable_membuf_map.empty()) {
     auto membuf_index = reusable_membuf_map.begin()->second;
     output_index = 0;
-    for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+    for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
       size_t index = GetTensorIndex(tensor_idx);
       auto tensor_desc = tensor_ptr_list_[index];
       MS_EXCEPTION_IF_NULL(tensor_desc);
@@ -229,7 +233,7 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() {
   } else {
     // no membuf can reuse, add new membuf after the membuf_ptr_list
     output_index = 0;
-    for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+    for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
       size_t index = GetTensorIndex(tensor_idx);
       auto tensor_desc = tensor_ptr_list_[index];
       MS_EXCEPTION_IF_NULL(tensor_desc);
@@ -247,7 +251,7 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() {
 }
 void BestFitMemReuse::AssignNodeOutputOffset() {
-  if (current_kernel_->type_ == COMMUNICATION_NODE) {
+  if (current_kernel_->type_ == kCommunicationNode) {
     AssignCommunicationNodeOutputOffset();
   } else {
     AssignCommonNodeOutputOffset();
@@ -330,7 +334,7 @@ void BestFitMemReuse::AddNewMembufPtr(KernelRefCount *tensor_desc, int flag) {
   }
   auto membuf_size = tensor_desc->size_;
   auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag);
-  auto membuf = std::make_shared<Membuf>(kReused, membuf_size, membuf_offset, real_index, NEW, current_kernel_);
+  auto membuf = std::make_shared<Membuf>(kReused, membuf_size, membuf_offset, real_index, kNew, current_kernel_);
   membuf_ptr_list_.push_back(membuf);
   tensor_desc->offset_ = membuf_offset;
 }
@@ -352,7 +356,7 @@ void BestFitMemReuse::UpdateNodeInputAndMembuf() {
 }
 void BestFitMemReuse::ReleaseNodeUnusedOutput() {
-  for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+  for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
     size_t tensor_index = GetTensorIndex(tensor_idx);
     auto tensor_desc = tensor_ptr_list_[tensor_index];
     MS_EXCEPTION_IF_NULL(tensor_desc);
@@ -517,8 +521,9 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) {
     ++op_num;
 #endif
   }
-  MS_LOG(INFO) << "Special Tensor total size: RefOutput: " << total_refoutput_size
-               << " CommReuse: " << total_comm_reuse_size << " CommOutputReuse: " << total_comm_output_reuse_size
+  MS_LOG(INFO) << "Special Tensor total size: RefInput: " << total_refinput_size
+               << " RefOutput: " << total_refoutput_size << " CommReuse: " << total_comm_reuse_size
+               << " CommOutputReuse: " << total_comm_output_reuse_size
                << " CommNotReuse: " << total_comm_not_reuse_size;
 #ifdef MEM_REUSE_DEBUG
   MemReuseChecker::GetInstance().ExportMembufInfoIR();
@@ -40,11 +40,11 @@ static constexpr int kDynamicMem = -1;
 static constexpr int kWorkspaceMem = 1;
 static constexpr size_t kTotalSize = 0;
 enum Status { kUnused, kReused };
-enum MEMTYPE { NEW, IN_STREAM_REUSE, BETWEEN_STREAMS_REUSE, KERNEL_DEPENDENCE_REUSE };
+enum MemType { kNew, kInStreamReuse, kBetweenStreamReuse, kKernelDependenceReuse };
 class Membuf {
  public:
   Membuf() = default;
-  Membuf(Status status, size_t size, size_t offset, int index, MEMTYPE type, const KernelDefPtr &used_kernel)
+  Membuf(Status status, size_t size, size_t offset, int index, MemType type, const KernelDefPtr &used_kernel)
       : status_(status), size_(size), offset_(offset), index_(index), type_(type), used_kernel_(used_kernel) {}
   ~Membuf() = default;
   // Memory block status flags
@@ -53,7 +53,7 @@ class Membuf {
   size_t offset_{0};
   // Store the tensor index stored in this memory block at a certain moment
   int index_{0};
-  MEMTYPE type_{NEW};
+  MemType type_{kNew};
   KernelDefPtr used_kernel_;
 };
 using MembufPtr = std::shared_ptr<Membuf>;
@@ -163,6 +163,7 @@ class BestFitMemReuse {
   // kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def
   std::map<KernelDefPtr, std::set<KernelDefPtr>> kernel_front_map_;
   std::vector<std::vector<uint32_t>> stream_groups_;
+  size_t total_refinput_size{0};
   size_t total_refoutput_size{0};
   size_t total_comm_reuse_size{0};
   size_t total_comm_output_reuse_size{0};
@@ -83,7 +83,7 @@ int64_t MemReuseChecker::CalculOriInput(const KernelGraph *graph) const {
   return static_input_size;
 }
-int64_t MemReuseChecker::CalculOriValue(KernelGraph *graph) const {
+int64_t MemReuseChecker::CalculOriValue(const KernelGraph *graph) const {
   MS_EXCEPTION_IF_NULL(graph);
   int64_t static_value_size = 0;
   for (auto &value_node : graph->graph_value_nodes()) {
@@ -101,7 +101,7 @@ int64_t MemReuseChecker::CalculOriValue(KernelGraph *graph) const {
   return static_value_size;
 }
-int64_t MemReuseChecker::CalculOriStatic(KernelGraph *graph) const {
+int64_t MemReuseChecker::CalculOriStatic(const KernelGraph *graph) const {
   // calculate static inputs
   auto static_input_size = CalculOriInput(graph);
   // do not calculate output size
@@ -154,7 +154,7 @@ std::string MemReuseChecker::GetSplitName(const std::string &scope_name) const {
 }
 void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list,
-                                      const KernelDefPtrMaps &kernel_def_ptr_list, KernelGraph *graph) {
+                                      const KernelDefPtrMaps &kernel_def_ptr_list, const KernelGraph *graph) {
   total_ori_static_size_ = CalculOriStatic(graph);
   total_ori_input_size_ = CalculOriInput(graph);
   total_ori_value_size_ = CalculOriValue(graph);
@@ -43,10 +43,10 @@ class MemReuseChecker {
   void CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr &c_node, size_t output_idx);
   bool CheckGraphOutputAssigned(const session::KernelGraph *graph);
   void CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list, const KernelDefPtrMaps &kernel_def_ptr_list,
-                       KernelGraph *graph);
-  int64_t CalculOriStatic(KernelGraph *graph) const;
+                       const KernelGraph *graph);
+  int64_t CalculOriStatic(const KernelGraph *graph) const;
   int64_t CalculOriInput(const KernelGraph *graph) const;
-  int64_t CalculOriValue(KernelGraph *graph) const;
+  int64_t CalculOriValue(const KernelGraph *graph) const;
   int64_t CalculOriDy(const KernelGraph *graph) const;
   int64_t CalculOriWk(const KernelGraph *graph) const;
   std::string GetSplitName(const std::string &scope_name) const;
@@ -398,12 +398,12 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
   }
 }
-void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) {
-  AssignCommunicationNodeInputMem(flag, node);
-  AssignCommunicationNodeOutputMem(flag, node);
+void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) {
+  AssignCommunicationNodeInputMem(type, node);
+  AssignCommunicationNodeOutputMem(type, node);
 }
-void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
+void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
   auto kernel_mod = AnfAlgo::GetKernelMod(node);
@@ -430,11 +430,11 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
     align_size_list.emplace_back(mem_size);
   }
-  if (flag == kReuseDynamicMem) {
+  if (type == kReuseDynamicMem) {
     // reuse communication op's all outputs' memory
-    flag = kReuseDynamicCommMem;
+    type = kReuseDynamicCommMem;
   }
-  uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
+  uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size);
   for (size_t j = 0; j < align_size_list.size(); ++j) {
     std::string output_format = AnfAlgo::GetOutputFormat(node, j);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
@@ -458,7 +458,7 @@ DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node,
   return address;
 }
-void KernelRuntime::AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &node) {
+void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   MS_EXCEPTION_IF_NULL(node);
@@ -479,7 +479,7 @@ void KernelRuntime::AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &
     total_size += mem_size;
     addr_size.emplace_back(address.get(), mem_size);
   }
-  uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
+  uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size);
   for (const auto &iter : addr_size) {
     MS_EXCEPTION_IF_NULL(iter.first);
     iter.first->set_ptr(input_ptr);
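Review note: a communication op's inputs are allocated as one contiguous block (a single MallocOutputMem(node, 0, type, total_size) call), and each device address is then pointed at its slice. The hunk ends at set_ptr; the advance by the aligned size is assumed from context in this sketch:

    #include <cstddef>
    #include <cstdint>
    #include <utility>
    #include <vector>

    struct DeviceAddress {
      uint8_t *ptr_ = nullptr;
      void set_ptr(uint8_t *p) { ptr_ = p; }
    };

    // Sketch: carve one contiguous block into per-tensor slices.
    void SliceBlock(uint8_t *block, const std::vector<std::pair<DeviceAddress *, std::size_t>> &addr_size) {
      for (const auto &iter : addr_size) {
        iter.first->set_ptr(block);  // this tensor starts at the current offset
        block += iter.second;        // assumed: advance by the tensor's aligned size
      }
    }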
@@ -487,12 +487,12 @@ void KernelRuntime::AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &
   }
 }
-void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) {
+void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
-  if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
+  if (AnfAlgo::IsGetNext(NOT_NULL(node)) && type == kReuseDynamicMem) {
     MS_LOG(INFO) << "GetNext disable mem_reuse";
-    flag = kDynamicMem;
+    type = kDynamicMem;
   }
   auto kernel_mod = AnfAlgo::GetKernelMod(node);
   MS_EXCEPTION_IF_NULL(kernel_mod);
@@ -509,7 +509,7 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in
       MS_LOG(INFO) << "Already malloc index:" << i;
       continue;
     }
-    auto ptr = mem_manager_->MallocOutputMem(node, i, flag, output_sizes[i]);
+    auto ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i]);
     if (ptr == nullptr) {
       // reused ptr, no need alloc, continue;
       continue;
@@ -608,10 +608,10 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   bool is_enable_mem_reuse = context_ptr->enable_mem_reuse();
-  auto mem_flag = kDynamicMem;
+  auto mem_type = kDynamicMem;
   if (is_enable_mem_reuse) {
     mem_manager_->MallocReusedDynamicMem(graph);
-    mem_flag = kReuseDynamicMem;
+    mem_type = kReuseDynamicMem;
   }
   auto &execution_nodes = graph->execution_order();
   std::vector<CNodePtr> compute_nodes;
@@ -619,7 +619,7 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
   for (auto &node : execution_nodes) {
     if (AnfAlgo::IsCommunicationOp(node)) {
       // skip if the memory is already allocated
-      AssignCommunicationNodeMem(mem_flag, node);
+      AssignCommunicationNodeMem(mem_type, node);
     } else {
       compute_nodes.emplace_back(node);
     }
@@ -627,19 +627,19 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
   // then compute nodes
   for (auto &node : compute_nodes) {
-    AssignNodeOutputMem(mem_flag, node, kGetAllOuts);
-    AssignWorkSpaceMem(mem_flag, node);
+    AssignNodeOutputMem(mem_type, node, kGetAllOuts);
+    AssignWorkSpaceMem(mem_type, node);
   }
 }
-void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
+void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
   auto kernel_mod = AnfAlgo::GetKernelMod(node);
   MS_EXCEPTION_IF_NULL(kernel_mod);
   size_t index = 0;
   for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
-    auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, flag, size);
+    auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size);
     AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
     index++;
   }
@@ -83,15 +83,15 @@ class KernelRuntime {
   void AssignStaticMemory(session::KernelGraph *graph);
   void AssignDynamicMemory(session::KernelGraph *graph);
   void ReuseAssignDynamicMemory(session::KernelGraph *graph);
-  void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index);
-  void AssignWorkSpaceMem(int flag, const AnfNodePtr &node);
+  void AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index);
+  void AssignWorkSpaceMem(MemType type, const AnfNodePtr &node);
   void AssignReuseWorkSpaceMem(const AnfNodePtr &node);
   void UpdateRefNodeOutputMem(const session::KernelGraph *graph);
-  void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
-  void AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &node);
-  void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node);
+  void AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node);
+  void AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node);
+  void AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node);
 #ifdef ENABLE_DUMP_E2E
   bool SetDumpConf();
 #endif
@@ -29,7 +29,7 @@ size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) const {
   return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + 2 * kMemAlignSize;
 }
-void MemoryManager::MallocReusedDynamicMem(session::KernelGraph *graph) {
+void MemoryManager::MallocReusedDynamicMem(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared<memreuse::MemReuseUtil>();
   MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
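For reference, GetCommunicationAlignSize (context line above) rounds the payload up to kMemAlignSize and reserves one extra alignment block on each side; those are the left/right borders that the offset-assignment code accounts for. A worked example using the kMemAlignSize = 512 defined in the memory_manager.h hunk below:

    #include <cstdint>
    #include <cstdio>

    constexpr uint64_t kMemAlignSize = 512;  // matches the header constant in this diff

    uint64_t GetCommunicationAlignSize(uint64_t input_size) {
      return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + 2 * kMemAlignSize;
    }

    int main() {
      // 1000 bytes -> rounded up to 1024, plus a 512-byte border on each side = 2048.
      std::printf("%llu\n", static_cast<unsigned long long>(GetCommunicationAlignSize(1000)));
      return 0;
    }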
@@ -45,7 +45,7 @@ void MemoryManager::MallocReusedDynamicMem(session::KernelGraph *graph) {
   mem_reuse_util_ptr_->set_mem_base(base_ptr);
 }
-uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size) {
+uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size) {
   MS_EXCEPTION_IF_NULL(node);
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
@@ -55,9 +55,9 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, in
   if (context_ptr->enable_hccl()) {
     communication_mem = true;
   }
-  if (flag == kStaticMem) {
+  if (type == kStaticMem) {
     ptr = MallocStaticMem(size, communication_mem);
-  } else if (flag == kReuseDynamicCommMem) {
+  } else if (type == kReuseDynamicCommMem) {
     MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
     ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
   } else {
@@ -66,30 +66,30 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, in
     return ptr;
   }
-  if (flag == kStaticMem) {
+  if (type == kStaticMem) {
     ptr = MallocStaticMem(size, false);
-  } else if (flag == kDynamicMem) {
+  } else if (type == kDynamicMem) {
     ptr = MallocDynamicMem(size, false);
-  } else if (flag == kReuseDynamicMem) {
+  } else if (type == kReuseDynamicMem) {
     MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
     ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
   }
   return ptr;
 }
-uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size) {
-  if (flag == kReuseDynamicMem) {
+uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size) {
+  if (type == kReuseDynamicMem) {
     MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
     return mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index);
   }
   return MallocDynamicMem(size, false);
 }
-uint8_t *MemoryManager::MallocMem(int flag, size_t size) {
+uint8_t *MemoryManager::MallocMem(MemType type, size_t size) {
   uint8_t *ptr = nullptr;
-  if (flag == kStaticMem) {
+  if (type == kStaticMem) {
     ptr = MallocStaticMem(size, false);
-  } else if (flag == kDynamicMem) {
+  } else if (type == kDynamicMem) {
     ptr = MallocDynamicMem(size, false);
   }
   return ptr;
@@ -22,10 +22,7 @@
 #include "backend/optimizer/mem_reuse/mem_reuse_allocator.h"
 namespace mindspore {
 namespace device {
-const int kStaticMem = 0;
-const int kDynamicMem = 1;
-const int kReuseDynamicMem = 2;
-const int kReuseDynamicCommMem = 3;
+enum MemType { kStaticMem, kDynamicMem, kReuseDynamicMem, kReuseDynamicCommMem };
 const int kGetAllOuts = -1;
 const uint64_t kMemAlignSize = 512;
 using MemReuseUtilPtr = mindspore::memreuse::MemReuseUtilPtr;
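Review note: folding the four const int flags into a MemType enum is more than a rename. An unscoped enum still converts to int where old call sites need it, but nothing converts back implicitly, so a stray integer can no longer be passed where a memory kind is expected. A small illustration:

    enum MemType { kStaticMem, kDynamicMem, kReuseDynamicMem, kReuseDynamicCommMem };

    void MallocMemSketch(MemType type) { /* dispatch on the memory kind */ }

    int main() {
      MallocMemSketch(kReuseDynamicMem);  // intent is explicit at the call site
      // MallocMemSketch(2);              // error: no implicit int -> MemType conversion
      return 0;
    }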
@@ -42,10 +39,10 @@ class MemoryManager {
     dynamic_mem_offset_ = 0;
   }
-  void MallocReusedDynamicMem(session::KernelGraph *graph);
-  uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size);
-  uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size);
-  virtual uint8_t *MallocMem(int flag, size_t size);
+  void MallocReusedDynamicMem(const session::KernelGraph *graph);
+  uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size);
+  uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size);
+  virtual uint8_t *MallocMem(MemType type, size_t size);
   virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size);
   virtual void *MallocMemFromMemPool(size_t size);