Merge pull request !2966 from laiyongqiang/hcom_memreuse (tags/v0.6.0-beta)
@@ -25,7 +25,8 @@
 namespace mindspore {
 namespace memreuse {
 enum RefCountType { kDynamicRefCount, kStaticRefCount };
-enum NodeType { NORMAL, SPECIAL };
+enum NodeType { COMMON_NODE, COMMUNICATION_NODE };
+enum KernelRefType { COMMON, REFNODE_OUTPUT, COMM_NOTREUSE, COMM_REUSE, SUMMARY };
 static constexpr int kInitIndex = -1;
 class KernelRefCount {
  public:
@@ -36,6 +37,7 @@ class KernelRefCount {
   size_t offset_;
   size_t size_;
   int index_;
+  KernelRefType type_;
   // remember to reset offset
   KernelRefCount()
       : stream_id_(0),
@@ -44,6 +46,7 @@ class KernelRefCount {
         offset_(0),
         size_(0),
         index_(kInitIndex),
+        type_(COMMON),
         reftype_(kStaticRefCount) {}
   ~KernelRefCount() = default;
   void SetKernelRefCountInfo(int index, size_t size, RefCountType reftype);
@@ -65,7 +68,7 @@ class KernelDef {
   KernelMap inputs_;
   KernelMap outputs_;
   KernelMap wk_space_;
-  NodeType dirty = NORMAL;
+  NodeType type_ = COMMON_NODE;
   KernelDef() = default;
   ~KernelDef() = default;
   void set_input_refs(const KernelRefCountPtrList &kernelRefPtrList) { input_refs_ = kernelRefPtrList; }
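Taken together, the two new enums split the special-casing along two axes: each KernelDef is a COMMON_NODE or a COMMUNICATION_NODE, and each tensor carries a KernelRefType that tells the best-fit allocator how it may be placed. A hedged sketch condensing the classification rules the hunks below implement (the helper and its parameters are illustrative only, not part of the patch):

```cpp
// Illustrative helper only: summarizes the type_ assignments made in
// InitDynamicOutputKernelRef, SetInputMap and SetSummaryNodesRefCount below.
KernelRefType ClassifyTensor(bool is_summary_output, bool is_comm_output, bool is_comm_input,
                             size_t comm_input_num, bool is_ref_node_output) {
  if (is_summary_output) return SUMMARY;  // pinned with kMaxRefCount, never reused
  if (is_comm_output) return COMM_REUSE;  // reused as one block, with align borders
  if (is_comm_input) {
    // only a single input is reused in place; inputs of multi-input
    // communication kernels are tallied as COMM_NOTREUSE
    return comm_input_num == 1 ? COMM_REUSE : COMM_NOTREUSE;
  }
  if (is_ref_node_output) return REFNODE_OUTPUT;  // skipped by offset assignment
  return COMMON;
}
```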
@@ -46,6 +46,8 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() {
     if (iter == kernel_output_refs_.end()) {
       auto output_sizes = kernel_mod->GetOutputSizeList();
       KernelRefCountPtrList kernel_refs;
+      bool is_comm_op = AnfAlgo::IsCommunicationOp(kernel_cnode);
+      size_t output_index = 0;
       for (auto size : output_sizes) {
         total_dy_size_ += size;
         // do not MallocDynamicMem just record this
@@ -54,9 +56,20 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() {
         auto curr_stream_id = AnfAlgo::GetStreamId(kernel_cnode);
         kernel_ref->stream_id_ = curr_stream_id;
         kernel_ref->SetKernelRefCountInfo(index, size, kDynamicRefCount);
+        if (is_comm_op) {
+          kernel_ref->type_ = COMM_REUSE;
+        } else {
+          session::AnfWithOutIndex out_pair(kernel_cnode, output_index);
+          if (graph_->IsInRefOutputMap(out_pair)) {
+            kernel_ref->type_ = REFNODE_OUTPUT;
+          } else {
+            kernel_ref->type_ = COMMON;
+          }
+        }
         kernel_refs.push_back(kernel_ref);
         kernel_out_ref_num++;
         total_refs_list_.push_back(kernel_ref);
+        output_index++;
       }
       if (!kernel_refs.empty()) {
         kernel_output_refs_[key] = kernel_refs;

@@ -155,9 +168,19 @@ void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) {
   MS_EXCEPTION_IF_NULL(kernel);
   MS_EXCEPTION_IF_NULL(kernel_def_ptr);
   auto key = kernel.get();
-  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
+  bool is_comm_op = AnfAlgo::IsCommunicationOp(kernel);
+  size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
+  for (size_t i = 0; i < input_tensor_num; ++i) {
     auto ref_ptr = GetKernelInputRef(kernel, i);
     if (ref_ptr != nullptr) {
+      if (is_comm_op) {
+        if (input_tensor_num == 1) {
+          ref_ptr->type_ = COMM_REUSE;
+        } else {
+          ref_ptr->type_ = COMM_NOTREUSE;
+        }
+      }
       if (ref_ptr->reftype() == kStaticRefCount) {
         continue;
       } else if (ref_ptr->reftype() == kDynamicRefCount) {

@@ -258,6 +281,11 @@ void MemReuseUtil::SetKernelDefMap() {
     auto key = kernel.get();
     kernel_def_ptr->set_input_refs(kernel_def_ptr->inputs_[key]);
     kernel_def_ptr->set_output_refs(kernel_def_ptr->outputs_[key]);
+    if (AnfAlgo::IsCommunicationOp(kernel)) {
+      kernel_def_ptr->type_ = COMMUNICATION_NODE;
+    } else {
+      kernel_def_ptr->type_ = COMMON_NODE;
+    }
     kernel_def_ptr_list_.push_back(kernel_def_ptr);
     kernel_map_[key] = kernel_def_ptr;
   }

@@ -337,6 +365,7 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
       KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index];
       kernel_ref->ref_count_ = kMaxRefCount;
       kernel_ref->ref_count_dynamic_use_ = kMaxRefCount;
+      kernel_ref->type_ = SUMMARY;
       total_summary_size += kernel_ref->size_;
       MS_LOG(INFO) << "Set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index;
     } else {
@@ -33,11 +33,11 @@ void BestFitMemReuse::InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr) {
   set_op_ptr_list(mem_reuse_util_ptr->kernel_def_ptr_list());
   // check info Correctness
   for (auto &tensor : tensor_ptr_list_) {
-    tensor->size_ = AlignMemorySize(tensor->size_);
+    tensor->size_ = AlignCommonMemorySize(tensor->size_);
   }
   // align wk size to 512 && refcount == 1
   for (auto &wk : wk_tensor_list_) {
-    wk->size_ = AlignMemorySize(wk->size_);
+    wk->size_ = AlignCommonMemorySize(wk->size_);
     wk->ref_count_ = 1;
   }
 #ifdef ENABLE_D

@@ -135,11 +135,23 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr
   return false;
 }
 
-void BestFitMemReuse::AssignNodeOutputOffset() {
+void BestFitMemReuse::AssignCommonNodeOutputOffset() {
+  MS_EXCEPTION_IF_NULL(current_kernel_);
   for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
     size_t index = GetTensorIndex(tensor_idx);
     auto tensor_desc = tensor_ptr_list_[index];
     MS_EXCEPTION_IF_NULL(tensor_desc);
+    if (tensor_desc->type_ == REFNODE_OUTPUT) {
+      total_refoutput_size += tensor_desc->size_;
+      continue;
+    } else if (tensor_desc->type_ == COMM_NOTREUSE) {
+      total_comm_not_reuse_size += tensor_desc->size_;
+    } else if (tensor_desc->type_ == COMM_REUSE) {
+      // get the aligned size (with borders) for a communication op's single input
+      tensor_desc->size_ = AlignCommunicationMemorySize(tensor_desc->size_);
+      total_comm_reuse_size += tensor_desc->size_;
+    }
     auto reusable_membuf_map = GetReusableMembufMap(tensor_desc->size_);
     if (!reusable_membuf_map.empty()) {
       auto membuf_index = reusable_membuf_map.begin()->second;
@@ -150,11 +162,91 @@ void BestFitMemReuse::AssignNodeOutputOffset() {
       AddNewMembufPtr(tensor_desc.get(), kDynamicMem);
 #ifdef MEM_REUSE_DEBUG
       MemReuseChecker::GetInstance().IsAddNewMembuf_ = true;
+#endif
+    }
+    // skip the left align border, so the communication op's single input points at usable memory
+    if (tensor_desc->type_ == COMM_REUSE) {
+      tensor_desc->offset_ += kDefaultMemAlignSize;
+    }
+  }
+}
+
+void BestFitMemReuse::AssignCommunicationNodeOutputOffset() {
+  size_t total_kernel_output_size = 0;
+  size_t output_num = 0;
+  // get all output sizes
+  MS_EXCEPTION_IF_NULL(current_kernel_);
+  for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+    size_t index = GetTensorIndex(tensor_idx);
+    auto tensor_desc = tensor_ptr_list_[index];
+    MS_EXCEPTION_IF_NULL(tensor_desc);
+    output_num++;  // count outputs so the last one can be identified below
+    if (tensor_desc->type_ == COMM_REUSE) {
+      total_comm_reuse_size += tensor_desc->size_;
+      total_comm_output_reuse_size += tensor_desc->size_;
+      total_kernel_output_size += tensor_desc->size_;
+    } else {
+      MS_LOG(ERROR) << "All communication op outputs should be memory reused, kernel: "
+                    << current_kernel_->scope_full_name();
+      continue;
+    }
+  }
+  total_kernel_output_size = AlignCommunicationMemorySize(total_kernel_output_size);
+  // add the left align border to the first output and the right align border to the last output,
+  // so the border memory is allocated together with the outputs
+  size_t output_index = 0;
+  for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+    size_t index = GetTensorIndex(tensor_idx);
+    auto tensor_desc = tensor_ptr_list_[index];
+    MS_EXCEPTION_IF_NULL(tensor_desc);
+    if (output_index == 0 || output_index == output_num - 1) {
+      tensor_desc->size_ += kDefaultMemAlignSize;
+    }
+    output_index++;
+  }
+  auto reusable_membuf_map = GetReusableMembufMap(total_kernel_output_size);
+  if (!reusable_membuf_map.empty()) {
+    auto membuf_index = reusable_membuf_map.begin()->second;
+    output_index = 0;
+    for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+      size_t index = GetTensorIndex(tensor_idx);
+      auto tensor_desc = tensor_ptr_list_[index];
+      MS_EXCEPTION_IF_NULL(tensor_desc);
+      ReuseExistMembuf(tensor_desc.get(), membuf_index + output_index, kDynamicMem);
+      // skip the left align border, so the communication op's first output points at usable memory
+      if (output_index == 0) {
+        tensor_desc->offset_ += kDefaultMemAlignSize;
+      }
+      output_index++;
+    }
+  } else {
+    // no membuf can be reused, add a new membuf at the end of membuf_ptr_list_
+    output_index = 0;
+    for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+      size_t index = GetTensorIndex(tensor_idx);
+      auto tensor_desc = tensor_ptr_list_[index];
+      MS_EXCEPTION_IF_NULL(tensor_desc);
+      AddNewMembufPtr(tensor_desc.get(), kDynamicMem);
+      // skip the left align border for the first output
+      if (output_index == 0) {
+        tensor_desc->offset_ += kDefaultMemAlignSize;
+      }
+      output_index++;
+#ifdef MEM_REUSE_DEBUG
+      MemReuseChecker::GetInstance().IsAddNewMembuf_ = true;
 #endif
     }
   }
 }
+
+void BestFitMemReuse::AssignNodeOutputOffset() {
+  if (current_kernel_->type_ == COMMUNICATION_NODE) {
+    AssignCommunicationNodeOutputOffset();
+  } else {
+    AssignCommonNodeOutputOffset();
+  }
+}
 void BestFitMemReuse::AssignNodeWorkspaceOffset() {
   for (auto &wk_idx : current_kernel_->GetWorkspaceRefIndexs()) {
     size_t index = GetWorkspaceIndex(wk_idx);
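To make the border bookkeeping concrete, here is a standalone sketch (sizes are assumed for illustration, not taken from the patch) of the offsets AssignCommunicationNodeOutputOffset produces for a communication kernel with three 1024-byte outputs:

```cpp
// Hedged sketch, not MindSpore code: reproduces the contiguous layout
// [left border | out0 | out1 | out2 | right border] with 512-byte borders.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  constexpr size_t kDefaultMemAlignSize = 512;
  std::vector<size_t> output_sizes{1024, 1024, 1024};  // assumed sizes
  size_t total = 0;
  for (size_t s : output_sizes) total += s;
  // AlignCommunicationMemorySize: 512-align the payload, then add both borders
  total = kDefaultMemAlignSize +
          (total + kDefaultMemAlignSize - 1) / kDefaultMemAlignSize * kDefaultMemAlignSize +
          kDefaultMemAlignSize;  // 512 + 3072 + 512 = 4096
  // one membuf serves all outputs; the first offset skips the left border
  size_t offset = kDefaultMemAlignSize;
  for (size_t i = 0; i < output_sizes.size(); ++i) {
    std::printf("output %zu at offset %zu\n", i, offset);  // 512, 1536, 2560
    offset += output_sizes[i];
  }
  std::printf("membuf size: %zu\n", total);  // 4096, right border left unused
  return 0;
}
```

The right border stays unused by design; together with the skipped left border it mirrors the protect area that the Ascend memory manager creates around freshly malloc'd communication memory (see the MallocStaticMem hunk below).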
@@ -319,11 +411,17 @@ void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) {
   }
 }
 
-size_t BestFitMemReuse::AlignMemorySize(size_t size) const {
+size_t BestFitMemReuse::AlignCommonMemorySize(size_t size) const {
   // memory size 512 align
   return (size + kDefaultMemAlignSize + kAttAlignSize) / kDefaultMemAlignSize * kDefaultMemAlignSize;
 }
+
+size_t BestFitMemReuse::AlignCommunicationMemorySize(size_t size) const {
+  // 512-byte aligned size plus the communication borders: [left align border | data | right align border]
+  return kDefaultMemAlignSize + (size + kDefaultMemAlignSize - 1) / kDefaultMemAlignSize * kDefaultMemAlignSize +
+         kDefaultMemAlignSize;
+}
 
 size_t BestFitMemReuse::GetAllocatedSize() {
   size_t AllocatedSize = kTotalSize;
   if (membuf_ptr_list_.empty()) {
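As a sanity check on the two formulas, a sketch with assumed constants: kDefaultMemAlignSize = 512 (as elsewhere in this diff) and kAttAlignSize also taken to be 512, an assumption that is at least consistent with the AlignCommonMemorySize(510) == 1024 unit test at the end of this diff:

```cpp
// Sketch only: mirrors the two alignment formulas to show concrete values.
#include <cstddef>

constexpr size_t kDefaultMemAlignSize = 512;  // from the surrounding code
constexpr size_t kAttAlignSize = 512;         // assumption, see lead-in

constexpr size_t AlignCommon(size_t size) {
  return (size + kDefaultMemAlignSize + kAttAlignSize) / kDefaultMemAlignSize * kDefaultMemAlignSize;
}
constexpr size_t AlignCommunication(size_t size) {
  return kDefaultMemAlignSize +
         (size + kDefaultMemAlignSize - 1) / kDefaultMemAlignSize * kDefaultMemAlignSize +
         kDefaultMemAlignSize;
}

static_assert(AlignCommon(510) == 1024, "matches the updated unit test");
static_assert(AlignCommunication(510) == 1536, "512 border + 512 payload + 512 border");
static_assert(AlignCommunication(1024) == 2048, "aligned payload still gains both borders");
```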
@@ -412,6 +510,9 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) {
     ++op_num;
 #endif
   }
+  MS_LOG(INFO) << "Special Tensor total size: RefOutput: " << total_refoutput_size
+               << " CommReuse: " << total_comm_reuse_size << " CommOutputReuse: " << total_comm_output_reuse_size
+               << " CommNotReuse: " << total_comm_not_reuse_size;
 #ifdef MEM_REUSE_DEBUG
   MemReuseChecker::GetInstance().ExportMembufInfoIR();
   MemReuseChecker::GetInstance().ExportAddNewMmebufIR();
@@ -74,6 +74,14 @@ class BestFitMemReuse {
    * Assign output tensor memory offset of current kernel
    */
   void AssignNodeOutputOffset();
+  /**
+   * Assign output tensor memory offset of a common kernel
+   */
+  void AssignCommonNodeOutputOffset();
+  /**
+   * Assign output tensor memory offset of a communication kernel
+   */
+  void AssignCommunicationNodeOutputOffset();
   /**
    * Update input tensor's status of current kernel, and the status of membuf used by current kernel
    */

@@ -110,8 +118,10 @@ class BestFitMemReuse {
   void AddNewMembufPtr(KernelRefCount *tensor_desc, int flag);
   // Merge unused membuf
   void ReleaseMembuf(size_t tensor_index, int flag);
-  // Memory address alignment 512
-  size_t AlignMemorySize(size_t size) const;
+  // Memory address alignment for common memory
+  size_t AlignCommonMemorySize(size_t size) const;
+  // Memory address alignment for memory used by communication ops
+  size_t AlignCommunicationMemorySize(size_t size) const;
   int GetRealIndex(size_t index, int flag = kDynamicMem) const;
   size_t GetTensorIndex(int index) const;
   size_t GetWorkspaceIndex(int index) const;

@@ -153,6 +163,10 @@ class BestFitMemReuse {
   // kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def
   std::map<KernelDefPtr, std::set<KernelDefPtr>> kernel_front_map_;
   std::vector<std::vector<uint32_t>> stream_groups_;
+  size_t total_refoutput_size{0};
+  size_t total_comm_reuse_size{0};
+  size_t total_comm_output_reuse_size{0};
+  size_t total_comm_not_reuse_size{0};
 };
 } // namespace memreuse
 } // namespace mindspore
@@ -170,12 +170,14 @@ void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list
   ofs << "all_tensor_refs:\n";
   ofs << "index:"
       << "\tsize:"
-      << "\trefcount:\n";
+      << "\trefcount:"
+      << "\ttype:\n";
   for (auto &ref : total_refs_list) {
     ofs << "%" << ref->index_ << "T"
         << "\t"
        << "#" << ref->size_ << "S"
        << "\t" << ref->ref_count_ << "C"
+        << "\t" << ref->type_ << "t"
         << "\n";
   }
   ofs << "kernel_def exc_order:\n";
@@ -241,7 +243,7 @@ bool MemReuseChecker::CheckGraphOutputAssigned(const session::KernelGraph *graph
 void MemReuseChecker::ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx) {
   auto scope_name = def->scope_full_name();
   std::string split_name = GetSplitName(scope_name);
-  ofs << "$" << def_idx << "\t" << split_name << "\t";
+  ofs << "$" << def_idx << "\t" << split_name << "\t" << static_cast<int>(def->type_) << "\t";
   ofs << "inputs[";
   for (auto &in : def->inputs_) {
     for (auto &in_ref : in.second) {
@@ -95,6 +95,12 @@ uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_mem) {
   } else {
     align_size = GetCommonAlignSize(size);
   }
+  auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
+  MS_LOG(INFO) << "Malloc Memory: Static, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
+               << "] memory pool[" << device_mem_pool_offset << "])"
+               << " malloc [" << align_size << "]";
   if (communication_mem) {
     // create protect area [kMemAlignSize -- data -- kMemAlignSize]
     uint8_t *alloc_address = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
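The protect-area comment deserves a picture. Assuming the elided remainder of this branch hands the caller the interior pointer (which is what the comment implies), the layout is:

```cpp
// Sketch of the protect area (assumption: the elided tail of this branch
// returns the interior pointer, as the comment describes):
//
//   alloc_address                                  alloc_address + align_size
//   |<- kMemAlignSize ->|<---------- data ---------->|<- kMemAlignSize ->|
//
// i.e. the caller would receive alloc_address + kMemAlignSize, and the two
// borders presumably guard against the communication engine touching bytes
// just outside the tensor.
```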
@@ -111,12 +117,17 @@ uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
   } else {
     align_size = GetCommonAlignSize(size);
   }
+  auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
+  MS_LOG(INFO) << "Malloc Memory: Dynamic, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
+               << "] memory pool[" << device_mem_pool_offset << "])"
+               << " malloc [" << align_size << "]";
   if (dynamic_mem_offset_ < align_size) {
     MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
                       << "]) malloc [" << align_size << "] failed!";
   }
   auto new_offset = dynamic_mem_offset_ - align_size;
-  auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
   if (new_offset <= device_mem_pool_offset) {
     MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
                       << "] memory pool[" << device_mem_pool_offset << "])"
@@ -399,7 +399,7 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
 }
 
 void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) {
-  AssignCommunicationNodeInputMem(node);
+  AssignCommunicationNodeInputMem(flag, node);
   AssignCommunicationNodeOutputMem(flag, node);
 }

@@ -429,6 +429,11 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
     total_size += mem_size;
     align_size_list.emplace_back(mem_size);
   }
+  if (flag == kReuseDynamicMem) {
+    // reuse all of the communication op's outputs as one contiguous block
+    flag = kReuseDynamicCommMem;
+  }
   uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
   for (size_t j = 0; j < align_size_list.size(); ++j) {
     std::string output_format = AnfAlgo::GetOutputFormat(node, j);

@@ -457,7 +462,7 @@ DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node,
   return address;
 }
 
-void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
+void KernelRuntime::AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &node) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   MS_EXCEPTION_IF_NULL(node);

@@ -478,7 +483,7 @@ void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
     total_size += mem_size;
     addr_size.emplace_back(address.get(), mem_size);
   }
-  uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, kDynamicMem, total_size);
+  uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
   for (const auto &iter : addr_size) {
     MS_EXCEPTION_IF_NULL(iter.first);
     iter.first->set_ptr(input_ptr);
@@ -88,7 +88,7 @@
   void UpdateRefNodeOutputMem(const session::KernelGraph *graph);
   void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
-  void AssignCommunicationNodeInputMem(const AnfNodePtr &node);
+  void AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &node);
   void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node);
 #ifdef ENABLE_DUMP_E2E
   bool SetDumpConf();
@@ -57,6 +57,9 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size) {
   }
   if (flag == kStaticMem) {
     ptr = MallocStaticMem(size, communication_mem);
+  } else if (flag == kReuseDynamicCommMem) {
+    MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
+    ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
   } else {
     ptr = MallocDynamicMem(size, communication_mem);
   }
@@ -25,6 +25,7 @@ namespace device {
 const int kStaticMem = 0;
 const int kDynamicMem = 1;
 const int kReuseDynamicMem = 2;
+const int kReuseDynamicCommMem = 3;
 const int kGetAllOuts = -1;
 const uint64_t kMemAlignSize = 512;
 using MemReuseUtilPtr = mindspore::memreuse::MemReuseUtilPtr;
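Pulling the hunks above together, the flag routing for output memory now looks like this (a condensed restatement, limited to the paths visible in this diff):

```cpp
// As shown in the MallocOutputMem hunk above:
//   kStaticMem           -> MallocStaticMem(size, communication_mem)
//   kReuseDynamicCommMem -> mem_reuse_util_ptr_->GetNodeOutputPtr(node, index)
//                           (the offset pre-assigned by BestFitMemReuse)
//   otherwise            -> MallocDynamicMem(size, communication_mem)
// and AssignCommunicationNodeOutputMem upgrades kReuseDynamicMem to
// kReuseDynamicCommMem, so a communication node's outputs all resolve to the
// one contiguous block reserved for them.
```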
@@ -146,7 +146,7 @@ TEST_F(TestMemReuseAllocator, mem_reuse_allocator_split_membuf) {
 TEST_F(TestMemReuseAllocator, mem_reuse_allocator_align) {
   auto best_fit_mem_reuse = std::make_shared<BestFitMemReuse>();
-  auto size = best_fit_mem_reuse->AlignMemorySize(510);
+  auto size = best_fit_mem_reuse->AlignCommonMemorySize(510);
   ASSERT_EQ(size, 1024);
 }
 } // namespace memreuse
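The updated test only exercises the common path. A companion check for the communication variant could look like the sketch below (not part of the patch; the expected value is a 512-byte left border, plus the 510-byte payload rounded up to 512, plus a 512-byte right border):

```cpp
// Hypothetical companion test, assuming the same fixture:
TEST_F(TestMemReuseAllocator, mem_reuse_allocator_comm_align) {
  auto best_fit_mem_reuse = std::make_shared<BestFitMemReuse>();
  auto size = best_fit_mem_reuse->AlignCommunicationMemorySize(510);
  ASSERT_EQ(size, 512 + 512 + 512);  // left border + aligned payload + right border
}
```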
@@ -225,7 +225,6 @@ TEST_F(TestMemReuseWithPy, KernelRef) {
   ASSERT_EQ(kernel_ref_count_ptr->size_, 512);
   KernelDefPtr kernel_def_ptr = std::make_shared<KernelDef>();
   ASSERT_NE(kernel_def_ptr, nullptr);
-  ASSERT_EQ(kernel_def_ptr->dirty, false);
   MembufPtr membuf_ptr = std::make_shared<Membuf>();
   ASSERT_NE(membuf_ptr, nullptr);
 }