Merge pull request !933 from limingqi107/master tags/v0.3.0-alpha
| @@ -184,6 +184,10 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph | |||||
| bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { | bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| auto graph_id = graph->graph_id(); | auto graph_id = graph->graph_id(); | ||||
| auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id]; | |||||
| MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); | |||||
| // Reset the reference count. | |||||
| mem_reuse_util_ptr->ResetDynamicUsedRefCount(); | |||||
| // The inputs and outputs memory of communication kernel need be continuous, so separate processing. | // The inputs and outputs memory of communication kernel need be continuous, so separate processing. | ||||
| AllocCommunicationOpDynamicRes(graph); | AllocCommunicationOpDynamicRes(graph); | ||||
| @@ -360,16 +364,13 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, | |||||
| if (kernel_ref_count_ptr == nullptr) { | if (kernel_ref_count_ptr == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| // Can't free the output of graph. | |||||
| if (kernel_ref_count_ptr->ref_count_dynamic_use_ == memreuse::kMaxRefCount) { | |||||
| continue; | |||||
| } | |||||
| kernel_ref_count_ptr->ref_count_dynamic_use_--; | kernel_ref_count_ptr->ref_count_dynamic_use_--; | ||||
| if (kernel_ref_count_ptr->ref_count_dynamic_use_ < 0) { | |||||
| MS_LOG(EXCEPTION) << "Check dynamic reference count failed."; | |||||
| } | |||||
| if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { | if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { | ||||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); | auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); | ||||
| mem_manager_->FreeMemFromMemPool(device_address); | mem_manager_->FreeMemFromMemPool(device_address); | ||||
| // Reset the reference count. | |||||
| kernel_ref_count_ptr->ref_count_dynamic_use_ = kernel_ref_count_ptr->ref_count_; | |||||
| } | } | ||||
| } | } | ||||
| // Free the output of kernel, if output has no reference. | // Free the output of kernel, if output has no reference. | ||||
| @@ -288,6 +288,14 @@ void MemReuseUtil::SetGraphOutputRefCount() { | |||||
| #endif | #endif | ||||
| } | } | ||||
| void MemReuseUtil::ResetDynamicUsedRefCount() { | |||||
| for (auto iter = kernel_output_refs_.begin(); iter != kernel_output_refs_.end(); ++iter) { | |||||
| for (auto &ref_count : iter->second) { | |||||
| ref_count->ref_count_dynamic_use_ = ref_count->ref_count_; | |||||
| } | |||||
| } | |||||
| } | |||||
| void MemReuseUtil::SetAllInfo(KernelGraph *graph) { | void MemReuseUtil::SetAllInfo(KernelGraph *graph) { | ||||
| if (!InitDynamicKernelRef(graph)) { | if (!InitDynamicKernelRef(graph)) { | ||||
| MS_LOG(EXCEPTION) << "Init ReuseAssignDynamicMemory Fault"; | MS_LOG(EXCEPTION) << "Init ReuseAssignDynamicMemory Fault"; | ||||
| @@ -64,6 +64,8 @@ class MemReuseUtil { | |||||
| void SetReuseRefCount(); | void SetReuseRefCount(); | ||||
| // Set the reference count of graph output specially. | // Set the reference count of graph output specially. | ||||
| void SetGraphOutputRefCount(); | void SetGraphOutputRefCount(); | ||||
| // Reset the dynamic used reference count by ref_count_. | |||||
| void ResetDynamicUsedRefCount(); | |||||
| KernelRefCountPtr GetRef(const AnfNodePtr &node, int output_idx); | KernelRefCountPtr GetRef(const AnfNodePtr &node, int output_idx); | ||||
| KernelRefCountPtr GetKernelInputRef(const CNodePtr &kernel, size_t input_idx); | KernelRefCountPtr GetKernelInputRef(const CNodePtr &kernel, size_t input_idx); | ||||
| @@ -161,7 +161,8 @@ void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_li | |||||
| total_ori_value_size_ = CalculOriValue(graph); | total_ori_value_size_ = CalculOriValue(graph); | ||||
| total_ori_dy_size_ = CalculOriDy(graph); | total_ori_dy_size_ = CalculOriDy(graph); | ||||
| total_ori_wkspace_size_ = CalculOriWk(graph); | total_ori_wkspace_size_ = CalculOriWk(graph); | ||||
| std::string filename = "./memreuse.ir"; | |||||
| std::string graph_id = std::to_string(graph->graph_id()); | |||||
| std::string filename = "./memreuse_" + graph_id + ".ir"; | |||||
| std::ofstream ofs(filename); | std::ofstream ofs(filename); | ||||
| if (!ofs.is_open()) { | if (!ofs.is_open()) { | ||||
| MS_LOG(ERROR) << "Open file [" << filename << "] failed!"; | MS_LOG(ERROR) << "Open file [" << filename << "] failed!"; | ||||