Merge pull request !2754 from laiyongqiang/mem_opt
@@ -99,6 +99,11 @@ uint8_t *MemoryManager::MallocStaticMem(size_t size, bool communication_mem) {
   } else {
     align_size = GetCommonAlignSize(size);
   }
+  MS_LOG(INFO) << "Malloc Memory for Static: total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
+               << "] static[" << total_static_size_ << "])"
+               << " malloc [" << align_size << "] communication_mem: " << communication_mem;
   if (static_mem_offset_ < align_size) {
     MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
                       << "] static[" << total_static_size_ << "])"
@@ -126,6 +131,11 @@ uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
   } else {
     align_size = GetCommonAlignSize(size);
   }
+  MS_LOG(INFO) << "Malloc Memory for Dynamic: total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
+               << "] static[" << total_static_size_ << "])"
+               << " malloc [" << align_size << "] communication_mem: " << communication_mem;
   uint64_t offset = dynamic_mem_offset_;
   auto new_offset = dynamic_mem_offset_ + align_size;
   if (new_offset > static_mem_offset_) {
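The two hunks above only add INFO logging, but they make the allocator's layout visible: judging from the checks, static memory is carved downward from the top of the device region (so `static_mem_offset_ < align_size` means exhaustion) while dynamic memory grows upward from offset 0 and fails once `new_offset > static_mem_offset_`. A minimal standalone sketch of that two-ended scheme, with hypothetical names and sizes rather than the MindSpore implementation:

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>

// Toy two-ended arena: dynamic allocations grow up from offset 0, static
// allocations are carved down from the top; they fail when the offsets cross.
// Hypothetical sketch only -- names and layout are assumptions, not MindSpore code.
class TwoEndedArena {
 public:
  explicit TwoEndedArena(uint64_t total) : static_offset_(total), dynamic_offset_(0) {}

  uint64_t MallocStatic(uint64_t size) {
    if (static_offset_ < size || static_offset_ - size < dynamic_offset_) {
      throw std::runtime_error("out of memory (static)");
    }
    static_offset_ -= size;  // carve from the high end
    return static_offset_;   // offset of the new static block
  }

  uint64_t MallocDynamic(uint64_t size) {
    uint64_t offset = dynamic_offset_;
    uint64_t new_offset = dynamic_offset_ + size;
    if (new_offset > static_offset_) {  // would run into the static region
      throw std::runtime_error("out of memory (dynamic)");
    }
    dynamic_offset_ = new_offset;
    return offset;
  }

 private:
  uint64_t static_offset_;
  uint64_t dynamic_offset_;
};

int main() {
  TwoEndedArena arena(1024);
  std::cout << "static block at offset " << arena.MallocStatic(256) << "\n";   // 768
  std::cout << "dynamic block at offset " << arena.MallocDynamic(512) << "\n"; // 0
  return 0;
}
```

Logging both totals on every request, as the PR does, makes it easy to see which side of the region ran out first when the EXCEPTION path is hit.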
@@ -329,22 +329,25 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
     return;
   }
+  size_t total_summary_size = 0;
   for (auto &node_item : summary_nodes) {
     auto node = node_item.second.first;
     size_t index = IntToSize(node_item.second.second);
-    MS_LOG(INFO) << "set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index;
     if (kernel_output_refs_.find(node.get()) != kernel_output_refs_.end()) {
       KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index];
       kernel_ref->ref_count_ = kMaxRefCount;
       kernel_ref->ref_count_dynamic_use_ = kMaxRefCount;
+      total_summary_size += kernel_ref->size_;
+      MS_LOG(INFO) << "Set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index;
     } else {
-      MS_LOG(WARNING) << "can't find summary node's kernel_def " << node->fullname_with_scope();
+      MS_LOG(WARNING) << "Can't find summary node's kernel_def " << node->fullname_with_scope() << " index: " << index;
     }
   }
 #ifdef MEM_REUSE_DEBUG
   auto graph = *graph_;
   MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph);
 #endif
+  MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size;
 }
 void MemReuseUtil::SetGraphOutputRefCount() {
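The accounting added above rides on the existing pinning trick: setting `ref_count_` (and `ref_count_dynamic_use_`) to `kMaxRefCount` presumably keeps a summary output's buffer from ever being handed back to the reuse pool, and the new `total_summary_size` simply sums the bytes held back that way. A stripped-down sketch of the pattern, with hypothetical types in place of `KernelRefCount`:

```cpp
#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>

// Hypothetical stand-in for KernelRefCount: a buffer whose ref count must reach
// zero before its memory can be released back into the reuse pool.
struct RefCountedBuffer {
  int ref_count;
  size_t size;
};

int main() {
  const int kMaxRefCount = std::numeric_limits<int>::max();  // assumed sentinel value
  std::vector<RefCountedBuffer> summary_outputs = {{1, 4096}, {1, 512}};

  size_t total_summary_size = 0;
  for (auto &buf : summary_outputs) {
    buf.ref_count = kMaxRefCount;    // pinned: never drops to zero, so never reused
    total_summary_size += buf.size;  // bytes permanently held for summary outputs
  }
  std::cout << "Special Tensor total size: SummaryNodes: " << total_summary_size << "\n";
  return 0;
}
```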
@@ -17,6 +17,9 @@
 #include "pre_activate/mem_reuse/mem_reuse_allocator.h"
 #include "pre_activate/mem_reuse/mem_reuse.h"
 #include "pre_activate/mem_reuse/mem_reuse_checker.h"
+#ifdef ENABLE_D
+#include "device/ascend/ascend_stream_assign.h"
+#endif
 namespace mindspore {
 namespace memreuse {
@@ -34,6 +37,9 @@ void BestFitMemReuse::InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr) {
     wk->size_ = AlignMemorySize(wk->size_);
     wk->ref_count_ = 1;
   }
+#ifdef ENABLE_D
+  stream_groups_ = device::ascend::AscendStreamAssign::GetInstance().get_stream_group();
+#endif
 }
 void BestFitMemReuse::InitKernelDependence() {
@@ -63,21 +69,58 @@ void BestFitMemReuse::InitKernelDependence() {
   }
 }
-bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const KernelDefPtr &kernel_prev) {
+bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr &mem_buf) {
   // determine whether the kernel_curr can reuse kernel_prev's output tensor membuf
   MS_EXCEPTION_IF_NULL(kernel_curr);
+  MS_EXCEPTION_IF_NULL(mem_buf);
+  auto kernel_prev = mem_buf->used_kernel_;
   MS_EXCEPTION_IF_NULL(kernel_prev);
   auto curr_stream_id = kernel_curr->stream_id();
   auto prev_stream_id = kernel_prev->stream_id();
   if (curr_stream_id == prev_stream_id) {
+    mem_buf->type_ = IN_STREAM_REUSE;
     return true;
   }
+  bool reuse_between_streams = true;
+  for (auto &stream_group : stream_groups_) {
+    size_t cur_index = UINT32_MAX;
+    size_t prev_index = UINT32_MAX;
+    for (size_t index = 0; index < stream_group.size(); index++) {
+      if (curr_stream_id == stream_group[index]) {
+        cur_index = index;
+        continue;
+      }
+      if (prev_stream_id == stream_group[index]) {
+        prev_index = index;
+        continue;
+      }
+    }
+    if ((prev_index != UINT32_MAX) && (cur_index == UINT32_MAX || (prev_index > cur_index))) {
+      // can't reuse if the previous stream and the current stream are not in the same group
+      // can't reuse if the previous stream comes after the current stream within the group
+      reuse_between_streams = false;
+      break;
+    }
+  }
+  if (reuse_between_streams) {
+    mem_buf->type_ = BETWEEN_STREAMS_REUSE;
+    return true;
+  }
   auto iter = kernel_front_map_.find(kernel_curr);
   if (iter == kernel_front_map_.end()) {
     MS_LOG(EXCEPTION) << kernel_curr->scope_full_name() << " is not init.";
   }
   auto kernel_curr_front = iter->second;
-  return kernel_curr_front.count(kernel_prev);
+  auto depend_count = kernel_curr_front.count(kernel_prev);
+  if (depend_count) {
+    mem_buf->type_ = KERNEL_DEPENDENCE_REUSE;
+    return true;
+  }
+  return false;
 }
 void BestFitMemReuse::AssignNodeOutputOffset() {
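For readers skimming the diff, here is a self-contained sketch of the decision order the new `IsUsable` encodes: reuse freely inside one stream, reuse across streams when the producing stream sits in the same group as, and ahead of, the current stream in every group that contains it, and otherwise fall back to an explicit kernel dependence. Names below are illustrative stand-ins, not the real MindSpore API; `depends_on_prev` replaces the `kernel_front_map_` lookup:

```cpp
#include <cstdint>
#include <vector>

// Illustrative copy of the MEMTYPE labels from the diff (not the real header).
enum MemType { NEW, IN_STREAM_REUSE, BETWEEN_STREAMS_REUSE, KERNEL_DEPENDENCE_REUSE };

// Decide whether a membuf produced on prev_stream may be reused by a kernel on
// curr_stream, given ordered stream groups. depends_on_prev stands in for the
// kernel_front_map_ dependence lookup (a deliberate simplification).
MemType ClassifyReuse(uint32_t curr_stream, uint32_t prev_stream,
                      const std::vector<std::vector<uint32_t>> &stream_groups,
                      bool depends_on_prev) {
  if (curr_stream == prev_stream) {
    return IN_STREAM_REUSE;  // same stream: execution order already serializes the kernels
  }
  bool reuse_between_streams = true;
  for (const auto &group : stream_groups) {
    size_t cur_index = SIZE_MAX;
    size_t prev_index = SIZE_MAX;
    for (size_t i = 0; i < group.size(); ++i) {
      if (group[i] == curr_stream) cur_index = i;
      if (group[i] == prev_stream) prev_index = i;
    }
    // The producing stream must be in the same group as, and ahead of, the current stream.
    if (prev_index != SIZE_MAX && (cur_index == SIZE_MAX || prev_index > cur_index)) {
      reuse_between_streams = false;
      break;
    }
  }
  if (reuse_between_streams) {
    return BETWEEN_STREAMS_REUSE;
  }
  return depends_on_prev ? KERNEL_DEPENDENCE_REUSE : NEW;  // NEW here meaning "no reuse"
}

int main() {
  // Stream 1 precedes stream 3 in the only group, so a buffer produced on stream 1
  // can be reused by a kernel on stream 3 even without an explicit dependence.
  std::vector<std::vector<uint32_t>> groups = {{1, 2, 3}};
  MemType t = ClassifyReuse(/*curr_stream=*/3, /*prev_stream=*/1, groups, /*depends_on_prev=*/false);
  return t == BETWEEN_STREAMS_REUSE ? 0 : 1;
}
```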
@@ -135,7 +178,7 @@ std::map<size_t, size_t> BestFitMemReuse::GetReusableMembufMap(size_t tensor_siz
     auto membuf = membuf_ptr_list_[i];
     auto index = i;
     bool is_membuf_ok = membuf->status_ == kUnused && membuf->size_ >= tensor_size;
-    if (is_membuf_ok && IsUsable(current_kernel_, membuf->used_kernel_)) {
+    if (is_membuf_ok && IsUsable(current_kernel_, membuf)) {
       (void)size_map.insert(std::make_pair(membuf->size_, index));
       break;
     }
@@ -163,8 +206,8 @@ void BestFitMemReuse::SplitMembuf(const KernelRefCount *tensor_desc, size_t memb
   auto bias = membuf->size_ - tensor_desc->size_;
   membuf->size_ = tensor_desc->size_;
   // to check if the split membuf can be merged
-  auto new_membuf =
-    std::make_shared<Membuf>(kUnused, bias, membuf->offset_ + membuf->size_, kInvalidIndex, current_kernel_);
+  auto new_membuf = std::make_shared<Membuf>(kUnused, bias, membuf->offset_ + membuf->size_, kInvalidIndex,
+                                             membuf->type_, current_kernel_);
   (void)membuf_ptr_list_.insert(membuf_ptr_list_.begin() + SizeToInt(membuf_index + 1), new_membuf);
 }
@@ -176,7 +219,7 @@ void BestFitMemReuse::AddNewMembufPtr(KernelRefCount *tensor_desc, int flag) {
   }
   auto membuf_size = tensor_desc->size_;
   auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag);
-  auto membuf = std::make_shared<Membuf>(kReused, membuf_size, membuf_offset, real_index, current_kernel_);
+  auto membuf = std::make_shared<Membuf>(kReused, membuf_size, membuf_offset, real_index, NEW, current_kernel_);
   membuf_ptr_list_.push_back(membuf);
   tensor_desc->offset_ = membuf_offset;
 }
@@ -242,7 +285,7 @@ void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) {
     auto membuf_next = (*next_iter);
     MS_EXCEPTION_IF_NULL(membuf_next);
     if (membuf_next->status_ == kUnused) {
-      bool is_merge = IsUsable(current_kernel_, membuf_next->used_kernel_);
+      bool is_merge = IsUsable(current_kernel_, membuf_next);
       if (is_merge) {
         membuf->size_ += membuf_next->size_;
         (void)membuf_ptr_list_.erase(next_iter);
@@ -254,7 +297,7 @@ void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) {
     auto membuf_prev = (*prev_iter);
     MS_EXCEPTION_IF_NULL(membuf_prev);
     if (membuf_prev->status_ == kUnused) {
-      bool is_merge = IsUsable(current_kernel_, membuf_prev->used_kernel_);
+      bool is_merge = IsUsable(current_kernel_, membuf_prev);
       if (is_merge) {
         membuf->size_ += membuf_prev->size_;
         membuf->offset_ = membuf_prev->offset_;
@@ -40,11 +40,12 @@ static constexpr int kDynamicMem = -1;
 static constexpr int kWorkspaceMem = 1;
 static constexpr size_t kTotalSize = 0;
 enum Status { kUnused, kReused };
+enum MEMTYPE { NEW, IN_STREAM_REUSE, BETWEEN_STREAMS_REUSE, KERNEL_DEPENDENCE_REUSE };
 class Membuf {
  public:
   Membuf() = default;
-  Membuf(Status status, size_t size, size_t offset, int index, const KernelDefPtr &used_kernel)
-      : status_(status), size_(size), offset_(offset), index_(index), used_kernel_(used_kernel) {}
+  Membuf(Status status, size_t size, size_t offset, int index, MEMTYPE type, const KernelDefPtr &used_kernel)
+      : status_(status), size_(size), offset_(offset), index_(index), type_(type), used_kernel_(used_kernel) {}
   ~Membuf() = default;
   // Memory block status flags
   Status status_ = kUnused;
@@ -52,6 +53,7 @@ class Membuf {
   size_t offset_{0};
   // Store the tensor index stored in this memory block at a certain moment
   int index_{0};
+  MEMTYPE type_{NEW};
   KernelDefPtr used_kernel_;
 };
 using MembufPtr = std::shared_ptr<Membuf>;
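A quick illustration of the widened `Membuf` interface: every freshly allocated block starts as `NEW`, `IsUsable` retags a block with the kind of reuse that justified handing it out, and a split keeps the parent's tag (as the `SplitMembuf` hunk shows). The sketch below uses a simplified stand-in struct so it compiles on its own; the real `Membuf` also carries `status_`, `index_`, and the producing kernel:

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Illustrative copy of the MEMTYPE labels (not the real header).
enum MemType { NEW, IN_STREAM_REUSE, BETWEEN_STREAMS_REUSE, KERNEL_DEPENDENCE_REUSE };

// Simplified stand-in for Membuf, for illustration only.
struct MiniMembuf {
  size_t size = 0;
  size_t offset = 0;
  MemType type = NEW;  // every freshly allocated block starts as NEW
};

int main() {
  std::vector<std::shared_ptr<MiniMembuf>> membuf_list;

  // AddNewMembufPtr-style: a brand new block is tagged NEW.
  auto buf = std::make_shared<MiniMembuf>();
  buf->size = 1024;
  membuf_list.push_back(buf);

  // IsUsable-style: once a kernel is allowed to take the block over, the tag
  // records which rule justified the reuse (here: same stream).
  buf->type = IN_STREAM_REUSE;

  // SplitMembuf-style: the leftover tail inherits the parent's tag.
  auto tail = std::make_shared<MiniMembuf>();
  tail->offset = buf->offset + 256;
  tail->size = buf->size - 256;
  tail->type = buf->type;
  buf->size = 256;
  membuf_list.push_back(tail);

  return membuf_list.size() == 2 ? 0 : 1;
}
```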
@@ -122,10 +124,10 @@ class BestFitMemReuse {
   /**
    * determine if the kernel_curr can reuse the output tensor address of kernel_prev
    * @param kernel_curr, current kernel
-   * @param kernel_prev, the membuf used by this kernel
+   * @param mem_buf, the membuf
    * @return bool
    */
-  bool IsUsable(const KernelDefPtr &kernel_curr, const KernelDefPtr &kernel_prev);
+  bool IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr &mem_buf);
   /**
    * init the dependence of all kernels in the graph
    */
@@ -150,6 +152,7 @@ class BestFitMemReuse {
   std::vector<MembufPtr> membuf_ptr_list_;
   // kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def
   std::map<KernelDefPtr, std::set<KernelDefPtr>> kernel_front_map_;
+  std::vector<std::vector<uint32_t>> stream_groups_;
 };
 }  // namespace memreuse
 }  // namespace mindspore
@@ -413,7 +413,8 @@ void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) {
 void MemReuseChecker::SetMembuInfos(const KernelDef *op_def, const std::vector<MembufPtr> &membuf_ptr_list) {
   std::vector<MembufPtr> curr_mem_infos;
   for (const auto &mem : membuf_ptr_list) {
-    auto mem_checker = std::make_shared<Membuf>(mem->status_, mem->size_, mem->offset_, mem->index_, mem->used_kernel_);
+    auto mem_checker =
+      std::make_shared<Membuf>(mem->status_, mem->size_, mem->offset_, mem->index_, mem->type_, mem->used_kernel_);
     curr_mem_infos.push_back(mem_checker);
   }
   membuf_all_infos_.push_back(curr_mem_infos);
@@ -427,7 +428,8 @@ void MemReuseChecker::SetAddNewMembuInfos(const KernelDef *op_def, const std::ve
   std::vector<MembufPtr> add_new_curr_mem;
   for (const auto &mem : membuf_ptr_list) {
-    auto mem_checker = std::make_shared<Membuf>(mem->status_, mem->size_, mem->offset_, mem->index_, mem->used_kernel_);
+    auto mem_checker =
+      std::make_shared<Membuf>(mem->status_, mem->size_, mem->offset_, mem->index_, mem->type_, mem->used_kernel_);
     add_new_curr_mem.push_back(mem_checker);
   }
   add_new_mem_infos_.push_back(add_new_curr_mem);
@@ -451,6 +453,7 @@ void MemReuseChecker::ExportEachMembufInfo(std::ofstream &ofs) {
       << "mem_size\t"
       << "mem_head\t"
       << "mem_tail\t"
+      << "mem_type\t"
       << "used_kernel\n";
   size_t curr_used = 0;
   size_t curr_allocated = 0;
@@ -461,8 +464,8 @@ void MemReuseChecker::ExportEachMembufInfo(std::ofstream &ofs) {
         << "streamID[@" << membuf->used_kernel_->stream_id() << "]"
         << "\t"
         << "#" << static_cast<int>(membuf->status_) << "\t%" << membuf->index_ << "T"
-        << "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t" << membuf->offset_ + membuf->size_ << "\t"
-        << GetSplitName(used_kernel) << "\n";
+        << "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t\t" << membuf->offset_ + membuf->size_ << "\t"
+        << "\t" << static_cast<int>(membuf->type_) << "\t" << GetSplitName(used_kernel) << "\n";
     if (membuf->status_ == kReused) {
       curr_used += membuf->size_;
     }
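The exported table now carries a `mem_type` column, written as the raw enum value via `static_cast<int>(membuf->type_)`; since `MEMTYPE` declares its enumerators without explicit values, they map to 0..3. If a readable label is wanted when post-processing those reports, a small helper along these lines would do (a hypothetical convenience, not part of the PR):

```cpp
#include <string>

// Translates the integer written to the new mem_type column back into the
// MEMTYPE names introduced by this PR (declared without explicit values, so 0..3).
// Hypothetical post-processing helper, not part of the checker itself.
std::string MemTypeName(int mem_type) {
  switch (mem_type) {
    case 0: return "NEW";
    case 1: return "IN_STREAM_REUSE";
    case 2: return "BETWEEN_STREAMS_REUSE";
    case 3: return "KERNEL_DEPENDENCE_REUSE";
    default: return "UNKNOWN(" + std::to_string(mem_type) + ")";
  }
}
```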