From: @liangzelang Reviewed-by: @zhoufeng54,@kisnwang Signed-off-by: @kisnwangpull/15356/MERGE
| @@ -400,11 +400,17 @@ bool Somas::InitSomasTensors(const session::KernelGraph *graph) { | |||||
| void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) { | void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| std::vector<CNodePtr> kernel_cnodes; | |||||
| streams_list_ = {}; | streams_list_ = {}; | ||||
| nodes_list_ = {}; | nodes_list_ = {}; | ||||
| size_t node_index = 0; | size_t node_index = 0; | ||||
| auto kernel_cnodes = graph->execution_order(); | |||||
| for (const auto &kernel : kernel_cnodes) { | |||||
| if (graph->subgraph_multi_call()) { | |||||
| kernel_cnodes = graph->mem_reuse_exec_order(); | |||||
| } else { | |||||
| kernel_cnodes = graph->execution_order(); | |||||
| } | |||||
| for (size_t i = 0; i < kernel_cnodes.size(); i++) { | |||||
| auto kernel = kernel_cnodes[i]; | |||||
| SomasStreamPtr stream; | SomasStreamPtr stream; | ||||
| auto stream_id = AnfAlgo::GetStreamId(kernel); | auto stream_id = AnfAlgo::GetStreamId(kernel); | ||||
| auto it = find_if(streams_list_.begin(), streams_list_.end(), | auto it = find_if(streams_list_.begin(), streams_list_.end(), | ||||
| @@ -427,7 +433,8 @@ void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) { | |||||
| nodes_list_.push_back(node); | nodes_list_.push_back(node); | ||||
| stream->nodes_.push_back(node); | stream->nodes_.push_back(node); | ||||
| auto key = kernel.get(); | auto key = kernel.get(); | ||||
| nodes_map_[key] = node; | |||||
| auto &nodes = nodes_map_[key]; | |||||
| nodes.push_back(node); | |||||
| node_index++; | node_index++; | ||||
| } | } | ||||
| } | } | ||||
| @@ -438,7 +445,8 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph | |||||
| size_t tensor_index = 0; | size_t tensor_index = 0; | ||||
| auto kernel_cnodes = graph->execution_order(); | auto kernel_cnodes = graph->execution_order(); | ||||
| for (const auto &kernel : kernel_cnodes) { | for (const auto &kernel : kernel_cnodes) { | ||||
| auto node = nodes_map_[kernel.get()]; | |||||
| auto nodes = nodes_map_[kernel.get()]; | |||||
| auto node = nodes[0]; | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto stream = node->GetStream(); | auto stream = node->GetStream(); | ||||
| MS_EXCEPTION_IF_NULL(stream); | MS_EXCEPTION_IF_NULL(stream); | ||||
| @@ -454,7 +462,7 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph | |||||
| // Set all output tensor lifelong to true. | // Set all output tensor lifelong to true. | ||||
| auto tensor = std::make_shared<SomasTensor>(output_tensor_index, node, stream, size, kLifeLongNone); | auto tensor = std::make_shared<SomasTensor>(output_tensor_index, node, stream, size, kLifeLongNone); | ||||
| tensor->lifetime_.start_ = node->GetId(); | tensor->lifetime_.start_ = node->GetId(); | ||||
| tensor->lifetime_.end_ = node->GetId(); | |||||
| tensor->lifetime_.end_ = (nodes.size() > 1) ? nodes.back()->GetId() : node->GetId(); | |||||
| tensor->type_ = kOutputOnly; | tensor->type_ = kOutputOnly; | ||||
| if (AnfAlgo::OutputAddrExist(kernel, index)) { | if (AnfAlgo::OutputAddrExist(kernel, index)) { | ||||
| tensor->aligned_size_ = 0; | tensor->aligned_size_ = 0; | ||||
| @@ -463,8 +471,10 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph | |||||
| tensors_list_.push_back(tensor); | tensors_list_.push_back(tensor); | ||||
| tensors_map_[output_tensor_index] = tensor; | tensors_map_[output_tensor_index] = tensor; | ||||
| stream->tensors_.push_back(tensor); | stream->tensors_.push_back(tensor); | ||||
| node->tensors_.insert(tensor); | |||||
| node->output_tensors_.push_back(tensor); | |||||
| std::for_each(nodes.begin(), nodes.end(), [tensor](auto &node) { | |||||
| node->tensors_.insert(tensor); | |||||
| node->output_tensors_.push_back(tensor); | |||||
| }); | |||||
| index++; | index++; | ||||
| } | } | ||||
| @@ -477,15 +487,17 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph | |||||
| SomasTensorPtr tensor = std::make_shared<SomasTensor>(workspace_tensor_index, node, stream, size, kLifeLongNone); | SomasTensorPtr tensor = std::make_shared<SomasTensor>(workspace_tensor_index, node, stream, size, kLifeLongNone); | ||||
| tensor->type_ = kWorkspace; | tensor->type_ = kWorkspace; | ||||
| tensor->lifetime_.start_ = node->GetId(); | tensor->lifetime_.start_ = node->GetId(); | ||||
| tensor->lifetime_.end_ = node->GetId(); | |||||
| tensor->lifetime_.end_ = (nodes.size() > 1) ? nodes.back()->GetId() : node->GetId(); | |||||
| if (AnfAlgo::WorkspaceAddrExist(kernel, index)) { | if (AnfAlgo::WorkspaceAddrExist(kernel, index)) { | ||||
| tensor->aligned_size_ = 0; | tensor->aligned_size_ = 0; | ||||
| } | } | ||||
| tensors_list_.push_back(tensor); | tensors_list_.push_back(tensor); | ||||
| tensors_map_[workspace_tensor_index] = tensor; | tensors_map_[workspace_tensor_index] = tensor; | ||||
| stream->tensors_.push_back(tensor); | stream->tensors_.push_back(tensor); | ||||
| node->tensors_.insert(tensor); | |||||
| node->workspace_tensors_.push_back(tensor); | |||||
| std::for_each(nodes.begin(), nodes.end(), [tensor](auto &node) { | |||||
| node->tensors_.insert(tensor); | |||||
| node->workspace_tensors_.push_back(tensor); | |||||
| }); | |||||
| index++; | index++; | ||||
| } | } | ||||
| } | } | ||||
| @@ -505,7 +517,8 @@ void Somas::InitSomasInputTensors(const session::KernelGraph *graph) { | |||||
| } | } | ||||
| } | } | ||||
| void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) { | void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) { | ||||
| auto node = nodes_map_[kernel.get()]; | |||||
| auto nodes = nodes_map_[kernel.get()]; | |||||
| auto node = nodes[0]; | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto stream = node->GetStream(); | auto stream = node->GetStream(); | ||||
| MS_EXCEPTION_IF_NULL(stream); | MS_EXCEPTION_IF_NULL(stream); | ||||
| @@ -543,7 +556,7 @@ void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) { | |||||
| MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input " << i << " [" | MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input " << i << " [" | ||||
| << prenode_index.first->fullname_with_scope() << "] is not init."; | << prenode_index.first->fullname_with_scope() << "] is not init."; | ||||
| } | } | ||||
| auto pre_somas_node = iter->second; | |||||
| auto pre_somas_node = iter->second.at(0); | |||||
| if (prenode_index.second > pre_somas_node->output_tensors_.size()) { | if (prenode_index.second > pre_somas_node->output_tensors_.size()) { | ||||
| MS_LOG(EXCEPTION) << "Output index " << prenode_index.second << " exceed input node [" | MS_LOG(EXCEPTION) << "Output index " << prenode_index.second << " exceed input node [" | ||||
| << prenode_index.first->fullname_with_scope() << "]'s outputs size " | << prenode_index.first->fullname_with_scope() << "]'s outputs size " | ||||
| @@ -551,15 +564,18 @@ void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) { | |||||
| } | } | ||||
| auto input_somas_tensor = pre_somas_node->output_tensors_[prenode_index.second]; | auto input_somas_tensor = pre_somas_node->output_tensors_[prenode_index.second]; | ||||
| MS_EXCEPTION_IF_NULL(input_somas_tensor); | MS_EXCEPTION_IF_NULL(input_somas_tensor); | ||||
| node->input_tensors_.push_back(input_somas_tensor); | |||||
| std::for_each(nodes.begin(), nodes.end(), | |||||
| [input_somas_tensor](auto &node) { node->input_tensors_.push_back(input_somas_tensor); }); | |||||
| real_input_index++; | real_input_index++; | ||||
| if (input_somas_tensor->type_ == kOutputOnly) { | if (input_somas_tensor->type_ == kOutputOnly) { | ||||
| input_somas_tensor->type_ = kCommon; | input_somas_tensor->type_ = kCommon; | ||||
| } | } | ||||
| input_somas_tensor->destinations_.insert(node); | |||||
| input_somas_tensor->destinationStreams_.insert(stream); | input_somas_tensor->destinationStreams_.insert(stream); | ||||
| if (input_somas_tensor->lifetime_.end_ < node->GetId()) { | |||||
| input_somas_tensor->lifetime_.end_ = node->GetId(); | |||||
| for (auto &repeat_node : nodes) { | |||||
| input_somas_tensor->destinations_.insert(repeat_node); | |||||
| if (input_somas_tensor->lifetime_.end_ < repeat_node->GetId()) { | |||||
| input_somas_tensor->lifetime_.end_ = repeat_node->GetId(); | |||||
| } | |||||
| } | } | ||||
| if (node != pre_somas_node) { | if (node != pre_somas_node) { | ||||
| @@ -574,7 +590,7 @@ void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) { | |||||
| } | } | ||||
| void Somas::InitAtomicCleanInputs(bool enable_fusion_clear, const CNodePtr &kernel) { | void Somas::InitAtomicCleanInputs(bool enable_fusion_clear, const CNodePtr &kernel) { | ||||
| auto node = nodes_map_[kernel.get()]; | |||||
| auto node = nodes_map_[kernel.get()].at(0); | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto stream = node->GetStream(); | auto stream = node->GetStream(); | ||||
| MS_EXCEPTION_IF_NULL(stream); | MS_EXCEPTION_IF_NULL(stream); | ||||
| @@ -588,7 +604,7 @@ void Somas::InitAtomicCleanInputs(bool enable_fusion_clear, const CNodePtr &kern | |||||
| MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input [" | MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input [" | ||||
| << pre_node->fullname_with_scope() << "] is not init."; | << pre_node->fullname_with_scope() << "] is not init."; | ||||
| } | } | ||||
| auto pre_somas_node = iter->second; | |||||
| auto pre_somas_node = iter->second.at(0); | |||||
| // set clean output tensors | // set clean output tensors | ||||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { | if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { | ||||
| auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs); | auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs); | ||||
| @@ -698,7 +714,8 @@ void Somas::GetNextOutputProcess(const session::KernelGraph *graph) { | |||||
| } | } | ||||
| auto iter = nodes_map_.find(kernel.get()); | auto iter = nodes_map_.find(kernel.get()); | ||||
| if (iter != nodes_map_.end()) { | if (iter != nodes_map_.end()) { | ||||
| auto getnext_output_tensors = iter->second->output_tensors_; | |||||
| auto &node = iter->second.at(0); | |||||
| auto getnext_output_tensors = node->output_tensors_; | |||||
| for (auto &tensor : getnext_output_tensors) { | for (auto &tensor : getnext_output_tensors) { | ||||
| total_size += tensor->GetAlignedSize(); | total_size += tensor->GetAlignedSize(); | ||||
| tensor->lifelong_value_ = kLifeLongGraphAll; | tensor->lifelong_value_ = kLifeLongGraphAll; | ||||
| @@ -720,7 +737,8 @@ void Somas::IndependentNodeOutputProcess(const session::KernelGraph *graph) { | |||||
| } | } | ||||
| auto iter = nodes_map_.find(kernel.get()); | auto iter = nodes_map_.find(kernel.get()); | ||||
| if (iter != nodes_map_.end()) { | if (iter != nodes_map_.end()) { | ||||
| auto semi_reuse_output_tensors = iter->second->output_tensors_; | |||||
| auto &node = iter->second.at(0); | |||||
| auto semi_reuse_output_tensors = node->output_tensors_; | |||||
| for (auto &tensor : semi_reuse_output_tensors) { | for (auto &tensor : semi_reuse_output_tensors) { | ||||
| total_size += tensor->GetAlignedSize(); | total_size += tensor->GetAlignedSize(); | ||||
| tensor->lifelong_value_ = kLifeLongGraphAll; | tensor->lifelong_value_ = kLifeLongGraphAll; | ||||
| @@ -749,9 +767,9 @@ void Somas::SummaryInputProcess(const session::KernelGraph *graph) { | |||||
| size_t index = IntToSize(node_item.second.second); | size_t index = IntToSize(node_item.second.second); | ||||
| auto iter = nodes_map_.find(node.get()); | auto iter = nodes_map_.find(node.get()); | ||||
| if (iter != nodes_map_.end()) { | if (iter != nodes_map_.end()) { | ||||
| auto input_node = iter->second; | |||||
| auto input_node = iter->second.at(0); | |||||
| if (index < input_node->output_tensors_.size()) { | if (index < input_node->output_tensors_.size()) { | ||||
| auto tensor = iter->second->output_tensors_[index]; | |||||
| auto tensor = input_node->output_tensors_[index]; | |||||
| tensor->lifelong_value_ = kLifeLongGraphAll; | tensor->lifelong_value_ = kLifeLongGraphAll; | ||||
| tensor->type_ = kSummaryInput; | tensor->type_ = kSummaryInput; | ||||
| total_summary_size += tensor->GetAlignedSize(); | total_summary_size += tensor->GetAlignedSize(); | ||||
| @@ -789,7 +807,8 @@ void Somas::RefNodeProcess(const session::KernelGraph *graph) { | |||||
| if (graph->IsInRefOutputMap(out_pair)) { | if (graph->IsInRefOutputMap(out_pair)) { | ||||
| auto origin_pair = graph->GetRefCorrespondOutput(out_pair); | auto origin_pair = graph->GetRefCorrespondOutput(out_pair); | ||||
| MS_EXCEPTION_IF_NULL(origin_pair.first); | MS_EXCEPTION_IF_NULL(origin_pair.first); | ||||
| auto output_tensor = nodes_map_[kernel.get()]->output_tensors_[out_index]; | |||||
| auto &node = nodes_map_[kernel.get()].at(0); | |||||
| auto output_tensor = node->output_tensors_[out_index]; | |||||
| MS_EXCEPTION_IF_NULL(output_tensor); | MS_EXCEPTION_IF_NULL(output_tensor); | ||||
| output_tensor->type_ = kRefNodeOutput; | output_tensor->type_ = kRefNodeOutput; | ||||
| total_output_size += size; | total_output_size += size; | ||||
| @@ -797,7 +816,8 @@ void Somas::RefNodeProcess(const session::KernelGraph *graph) { | |||||
| if (AnfAlgo::IsRealCNodeKernel(origin_pair.first)) { | if (AnfAlgo::IsRealCNodeKernel(origin_pair.first)) { | ||||
| auto ori_node = origin_pair.first->cast<CNodePtr>(); | auto ori_node = origin_pair.first->cast<CNodePtr>(); | ||||
| auto ori_index = origin_pair.second; | auto ori_index = origin_pair.second; | ||||
| auto input_tensor = nodes_map_[ori_node.get()]->output_tensors_[ori_index]; | |||||
| auto &repeat_node = nodes_map_[ori_node.get()].at(0); | |||||
| auto input_tensor = repeat_node->output_tensors_[ori_index]; | |||||
| MS_EXCEPTION_IF_NULL(input_tensor); | MS_EXCEPTION_IF_NULL(input_tensor); | ||||
| input_tensor->type_ = kRefNodeInput; | input_tensor->type_ = kRefNodeInput; | ||||
| total_input_size += input_tensor->aligned_size_; | total_input_size += input_tensor->aligned_size_; | ||||
| @@ -821,7 +841,7 @@ void Somas::NonTaskSplitProcess(const session::KernelGraph *graph) { | |||||
| auto op_name = AnfAlgo::GetCNodeName(kernel); | auto op_name = AnfAlgo::GetCNodeName(kernel); | ||||
| if ((op_name == kSplitOpName || op_name == kSplitVOpName) && AnfAlgo::HasNodeAttr(kAttrNonTask, kernel)) { | if ((op_name == kSplitOpName || op_name == kSplitVOpName) && AnfAlgo::HasNodeAttr(kAttrNonTask, kernel)) { | ||||
| std::vector<size_t> refnode_input_output; | std::vector<size_t> refnode_input_output; | ||||
| auto node = nodes_map_[kernel.get()]; | |||||
| auto node = nodes_map_[kernel.get()].at(0); | |||||
| if (node->input_tensors_.size() == 0) { | if (node->input_tensors_.size() == 0) { | ||||
| MS_LOG(EXCEPTION) << op_name << " has no input tensor, can not do split non_task process."; | MS_LOG(EXCEPTION) << op_name << " has no input tensor, can not do split non_task process."; | ||||
| } | } | ||||
| @@ -852,7 +872,7 @@ void Somas::UnReuseNodeProcess(const session::KernelGraph *graph) { | |||||
| if (iter != full_name_list.end()) { | if (iter != full_name_list.end()) { | ||||
| MS_LOG(INFO) << "Set UnReuse Node in somas, Node:" << full_name; | MS_LOG(INFO) << "Set UnReuse Node in somas, Node:" << full_name; | ||||
| auto key = kernel.get(); | auto key = kernel.get(); | ||||
| auto somas_node = nodes_map_[key]; | |||||
| auto somas_node = nodes_map_[key].at(0); | |||||
| // input | // input | ||||
| auto inputs = somas_node->input_tensors_; | auto inputs = somas_node->input_tensors_; | ||||
| for (auto &input : inputs) { | for (auto &input : inputs) { | ||||
| @@ -1749,11 +1769,12 @@ uint8_t *Somas::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const { | |||||
| auto iter = nodes_map_.find(key); | auto iter = nodes_map_.find(key); | ||||
| uint8_t *ptr = nullptr; | uint8_t *ptr = nullptr; | ||||
| if (iter != nodes_map_.end()) { | if (iter != nodes_map_.end()) { | ||||
| if (index >= iter->second->output_tensors_.size()) { | |||||
| auto &node = iter->second.at(0); | |||||
| if (index >= node->output_tensors_.size()) { | |||||
| MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" | MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" | ||||
| << iter->second->output_tensors_.size() << "]"; | |||||
| << node->output_tensors_.size() << "]"; | |||||
| } | } | ||||
| auto output_tensor = iter->second->output_tensors_[index]; | |||||
| auto output_tensor = node->output_tensors_[index]; | |||||
| ptr = mem_base_addr_ + output_tensor->offset_; | ptr = mem_base_addr_ + output_tensor->offset_; | ||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in nodes_map"; | MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in nodes_map"; | ||||
| @@ -1766,11 +1787,12 @@ uint8_t *Somas::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const | |||||
| auto iter = nodes_map_.find(key); | auto iter = nodes_map_.find(key); | ||||
| uint8_t *ptr = nullptr; | uint8_t *ptr = nullptr; | ||||
| if (iter != nodes_map_.end()) { | if (iter != nodes_map_.end()) { | ||||
| if (index >= iter->second->workspace_tensors_.size()) { | |||||
| auto &node = iter->second.at(0); | |||||
| if (index >= node->workspace_tensors_.size()) { | |||||
| MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" | MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" | ||||
| << iter->second->workspace_tensors_.size() << "]"; | |||||
| << node->workspace_tensors_.size() << "]"; | |||||
| } | } | ||||
| auto workspace_tensor = iter->second->workspace_tensors_[index]; | |||||
| auto workspace_tensor = node->workspace_tensors_[index]; | |||||
| ptr = mem_base_addr_ + workspace_tensor->offset_; | ptr = mem_base_addr_ + workspace_tensor->offset_; | ||||
| } | } | ||||
| return ptr; | return ptr; | ||||
| @@ -63,7 +63,7 @@ class Somas { | |||||
| std::string hash_id_; | std::string hash_id_; | ||||
| // Maps | // Maps | ||||
| std::unordered_map<size_t, SomasTensorPtr> tensors_map_; | std::unordered_map<size_t, SomasTensorPtr> tensors_map_; | ||||
| std::map<void *, SomasNodePtr> nodes_map_; | |||||
| std::map<void *, std::vector<SomasNodePtr>> nodes_map_; | |||||
| std::map<void *, vector<SomasParameterPtr>> parameters_map_; | std::map<void *, vector<SomasParameterPtr>> parameters_map_; | ||||
| // Vectors | // Vectors | ||||
| @@ -296,6 +296,15 @@ class AscendAutoMonadContext : public BaseContext { | |||||
| // Set flag to indicate whether has already created an stack or not. | // Set flag to indicate whether has already created an stack or not. | ||||
| void SetInitedStack(bool flag) { inited_stack_ = flag; } | void SetInitedStack(bool flag) { inited_stack_ = flag; } | ||||
| // The graphs has recursion. | |||||
| bool HasRecursiveCall() const { return has_recursive_call_; } | |||||
| // The graphs has subgraph multi-call. | |||||
| bool HasSubgraphMultiCall() const { return has_subgraph_multicall_; } | |||||
| // set flag to indicate whether has recursion. | |||||
| void SetRecursiveCall(bool flag) { has_recursive_call_ = flag; } | |||||
| // set flag to indicate whether has multi-call. | |||||
| void SetSubGraphMultiCall(bool flag) { has_subgraph_multicall_ = flag; } | |||||
| // Map kernel_graph to its call info. | // Map kernel_graph to its call info. | ||||
| OrderedMap<KernelGraphPtr, CallInfo> call_info_map; | OrderedMap<KernelGraphPtr, CallInfo> call_info_map; | ||||
| @@ -311,6 +320,10 @@ class AscendAutoMonadContext : public BaseContext { | |||||
| // Create an stack for multi-call and non-tail recursion. | // Create an stack for multi-call and non-tail recursion. | ||||
| bool inited_stack_ = false; | bool inited_stack_ = false; | ||||
| // The graphs has recursion or not. | |||||
| bool has_recursive_call_ = false; | |||||
| // The graphs has subgraph multi-call or not. | |||||
| bool has_subgraph_multicall_ = false; | |||||
| }; | }; | ||||
| // | // | ||||
| @@ -643,6 +656,11 @@ class AscendAutoMonadConverter { | |||||
| } | } | ||||
| // Handle recursive call. | // Handle recursive call. | ||||
| kernel_graph_->SetExecOrderByDefault(); | kernel_graph_->SetExecOrderByDefault(); | ||||
| if (call_info_.recursive) { | |||||
| const auto &nodes = kernel_graph_->execution_order(); | |||||
| AnfAlgo::SetNodeAttr(kAttrRecursiveStart, prim::kValueOne, *nodes.begin()); | |||||
| AnfAlgo::SetNodeAttr(kAttrRecursiveEnd, prim::kValueOne, *nodes.rbegin()); | |||||
| } | |||||
| for (auto &call_site : call_info_.call_sites) { | for (auto &call_site : call_info_.call_sites) { | ||||
| if (need_stackops_ && call_site.recursive) { | if (need_stackops_ && call_site.recursive) { | ||||
| MS_LOG(INFO) << "graph:" << kernel_graph_->ToString() << ", loop call_site:" << call_site.cnode->DebugString(); | MS_LOG(INFO) << "graph:" << kernel_graph_->ToString() << ", loop call_site:" << call_site.cnode->DebugString(); | ||||
| @@ -661,6 +679,7 @@ class AscendAutoMonadConverter { | |||||
| auto stack_destroy = StackDestroy(top_graph); | auto stack_destroy = StackDestroy(top_graph); | ||||
| AnfAlgo::KeepOrder(top_graph, *exec_order.rbegin(), stack_destroy); | AnfAlgo::KeepOrder(top_graph, *exec_order.rbegin(), stack_destroy); | ||||
| top_graph->SetExecOrderByDefault(); | top_graph->SetExecOrderByDefault(); | ||||
| context_.SetRecursiveCall(true); | |||||
| context_.SetInitedStack(true); | context_.SetInitedStack(true); | ||||
| } | } | ||||
| } | } | ||||
| @@ -812,6 +831,9 @@ class AscendAutoMonadConverter { | |||||
| // Create LabelGoto or LabelSwitch node. | // Create LabelGoto or LabelSwitch node. | ||||
| auto label_goto_switch = MakeLabelGotoSwitch(cnode, graphes, labels); | auto label_goto_switch = MakeLabelGotoSwitch(cnode, graphes, labels); | ||||
| call_site->conversion_cnode = label_goto_switch; | call_site->conversion_cnode = label_goto_switch; | ||||
| if (call_site->recursive) { | |||||
| AnfAlgo::SetNodeAttr(kAttrRecursive, prim::kValueOne, label_goto_switch); | |||||
| } | |||||
| // Setup return label and output if required. | // Setup return label and output if required. | ||||
| if (call_site->return_label != kNoLabel) { | if (call_site->return_label != kNoLabel) { | ||||
| @@ -931,7 +953,11 @@ class AscendAutoMonadConverter { | |||||
| MS_EXCEPTION_IF_NULL(label_param); | MS_EXCEPTION_IF_NULL(label_param); | ||||
| auto return_switch = LabelSwitch(label_param, return_labels); | auto return_switch = LabelSwitch(label_param, return_labels); | ||||
| AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_switch); | AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_switch); | ||||
| if (!call_info_.recursive) { | |||||
| AnfAlgo::SetNodeAttr(kAttrMultiCallEnd, prim::kValueOne, return_switch); | |||||
| } | |||||
| kernel_graph_->set_end_goto(return_switch); | kernel_graph_->set_end_goto(return_switch); | ||||
| context_.SetSubGraphMultiCall(true); | |||||
| } | } | ||||
| // Assign graph output to the output parameter. | // Assign graph output to the output parameter. | ||||
| @@ -1650,6 +1676,8 @@ void AscendAutoMonad::Run() { | |||||
| CallInfoFinder::Run(&context); | CallInfoFinder::Run(&context); | ||||
| AscendAutoMonadConverter::Run(&context); | AscendAutoMonadConverter::Run(&context); | ||||
| kernel_graph_->set_label_num(context.CurrentLabel() + 1); | kernel_graph_->set_label_num(context.CurrentLabel() + 1); | ||||
| kernel_graph_->set_recursive_call(context.HasRecursiveCall()); | |||||
| kernel_graph_->set_subgraph_multi_call(context.HasSubgraphMultiCall()); | |||||
| MS_LOG(DEBUG) << "Ascend auto-monad finish."; | MS_LOG(DEBUG) << "Ascend auto-monad finish."; | ||||
| DumpGraphForDebug(kernel_graph_); | DumpGraphForDebug(kernel_graph_); | ||||
| } | } | ||||
| @@ -1034,9 +1034,139 @@ void AscendSession::BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kerne | |||||
| MS_LOG(INFO) << "Finish!"; | MS_LOG(INFO) << "Finish!"; | ||||
| } | } | ||||
| static CNodePtr GetNextLabelSet(const std::vector<CNodePtr> &kernel_nodes, uint32_t index) { | |||||
| uint32_t node_sizes = kernel_nodes.size(); | |||||
| if (index >= node_sizes - 1) { | |||||
| MS_LOG(EXCEPTION) << "there is no node after this node:" << kernel_nodes[index]->DebugString(); | |||||
| } | |||||
| auto kernel = kernel_nodes[index + 1]; | |||||
| if (AnfAlgo::GetCNodeName(kernel) != kLabelSetOpName) { | |||||
| MS_LOG(EXCEPTION) << "the node is not labelset follow labelgoto/labelswitch, node: " | |||||
| << kernel_nodes[index]->DebugString(); | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| static std::vector<CNodePtr> HandleRecursiveCall(const std::vector<CNodePtr> &kernel_cnodes, const uint32_t &back_label, | |||||
| uint32_t *index, std::vector<CNodePtr> *back) { | |||||
| MS_EXCEPTION_IF_NULL(index); | |||||
| MS_EXCEPTION_IF_NULL(back); | |||||
| std::vector<CNodePtr> front; | |||||
| std::vector<CNodePtr> back_temp; | |||||
| bool back_flag = false; | |||||
| for (uint32_t i = *index; i < kernel_cnodes.size(); i++) { | |||||
| if (!back_flag) { | |||||
| front.emplace_back(kernel_cnodes[i]); | |||||
| } else { | |||||
| back->emplace_back(kernel_cnodes[i]); | |||||
| } | |||||
| if (AnfAlgo::HasNodeAttr(kAttrRecursiveEnd, kernel_cnodes[i])) { | |||||
| *index = i; | |||||
| back->insert(back->end(), back_temp.begin(), back_temp.end()); | |||||
| return front; | |||||
| } | |||||
| if (AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) { | |||||
| back_flag = true; | |||||
| if (AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], back_label)) { | |||||
| continue; | |||||
| } else { | |||||
| auto temp = HandleRecursiveCall(kernel_cnodes, back_label, &(++i), &back_temp); | |||||
| front.insert(front.end(), temp.begin(), temp.end()); | |||||
| continue; | |||||
| } | |||||
| } | |||||
| } | |||||
| return front; | |||||
| } | |||||
| static void UnfoldRecursiveExecOrder(KernelGraph *kernel_graph) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| if (!kernel_graph->recursive_call()) { | |||||
| return; | |||||
| } | |||||
| auto kernel_cnodes = kernel_graph->mem_reuse_exec_order(); | |||||
| std::vector<CNodePtr> mem_reuse_order; | |||||
| mem_reuse_order.reserve(kernel_cnodes.size()); | |||||
| for (uint32_t i = 0; i < kernel_cnodes.size(); i++) { | |||||
| if (!AnfAlgo::HasNodeAttr(kAttrRecursiveStart, kernel_cnodes[i])) { | |||||
| mem_reuse_order.emplace_back(kernel_cnodes[i]); | |||||
| continue; | |||||
| } | |||||
| auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex); | |||||
| std::vector<CNodePtr> back; | |||||
| auto front = HandleRecursiveCall(kernel_cnodes, label_id, &i, &back); | |||||
| mem_reuse_order.insert(mem_reuse_order.end(), front.begin(), front.end()); | |||||
| mem_reuse_order.insert(mem_reuse_order.end(), back.begin(), back.end()); | |||||
| } | |||||
| kernel_graph->set_mem_reuse_exec_order(mem_reuse_order); | |||||
| } | |||||
| static void GetSubGraphExecOrder(const KernelGraph *kernel_graph, uint32_t index, const CNodePtr &back_node, | |||||
| std::vector<CNodePtr> *mem_reuse_order) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| MS_EXCEPTION_IF_NULL(mem_reuse_order); | |||||
| auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(back_node, kAttrLabelIndex); | |||||
| auto kernel_cnodes = kernel_graph->execution_order(); | |||||
| for (auto i = index; i < kernel_cnodes.size(); i++) { | |||||
| mem_reuse_order->emplace_back(kernel_cnodes[i]); | |||||
| if (AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], label_id)) { | |||||
| return; | |||||
| } | |||||
| } | |||||
| } | |||||
| void InitMemReuseExecOrder(KernelGraph *kernel_graph) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| if (!kernel_graph->subgraph_multi_call()) { | |||||
| return; | |||||
| } | |||||
| std::unordered_map<uint32_t, uint32_t> label_id_index_map; | |||||
| auto kernel_cnodes = kernel_graph->execution_order(); | |||||
| std::vector<CNodePtr> mem_reuse_order; | |||||
| for (size_t i = 0; i < kernel_cnodes.size(); i++) { | |||||
| mem_reuse_order.emplace_back(kernel_cnodes[i]); | |||||
| if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSwitch) && | |||||
| !AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) && | |||||
| !AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) { | |||||
| auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(kernel_cnodes[i], kAttrLabelSwitchList); | |||||
| for (auto label_id : label_list) { | |||||
| if (label_id_index_map.find(label_id) == label_id_index_map.end()) { | |||||
| continue; | |||||
| } | |||||
| auto back_node = GetNextLabelSet(kernel_cnodes, i); | |||||
| GetSubGraphExecOrder(kernel_graph, label_id_index_map[label_id], back_node, &mem_reuse_order); | |||||
| } | |||||
| continue; | |||||
| } | |||||
| if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelGoto) && | |||||
| !AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) && | |||||
| !AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) { | |||||
| auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex); | |||||
| if (label_id_index_map.find(label_id) == label_id_index_map.end()) { | |||||
| continue; | |||||
| } | |||||
| auto back_node = GetNextLabelSet(kernel_cnodes, i); | |||||
| GetSubGraphExecOrder(kernel_graph, label_id_index_map[label_id], back_node, &mem_reuse_order); | |||||
| continue; | |||||
| } | |||||
| if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSet) && | |||||
| !AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) { | |||||
| auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex); | |||||
| if (label_id_index_map.find(label_id) != label_id_index_map.end()) { | |||||
| MS_LOG(EXCEPTION) << "Two labelsets with same label id."; | |||||
| } | |||||
| label_id_index_map[label_id] = i; | |||||
| continue; | |||||
| } | |||||
| } | |||||
| kernel_graph->set_mem_reuse_exec_order(mem_reuse_order); | |||||
| UnfoldRecursiveExecOrder(kernel_graph); | |||||
| } | |||||
| void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const { | void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const { | ||||
| MS_LOG(INFO) << "Start!"; | MS_LOG(INFO) << "Start!"; | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| InitMemReuseExecOrder(kernel_graph); | |||||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); | auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); | ||||
| MS_EXCEPTION_IF_NULL(runtime_instance); | MS_EXCEPTION_IF_NULL(runtime_instance); | ||||
| runtime_instance->AssignMemory(kernel_graph); | runtime_instance->AssignMemory(kernel_graph); | ||||
| @@ -41,6 +41,7 @@ class KernelGraph : public FuncGraph { | |||||
| KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), current_epoch_(0), is_dynamic_shape_(false) { | KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), current_epoch_(0), is_dynamic_shape_(false) { | ||||
| inputs_ = std::make_shared<std::vector<AnfNodePtr>>(); | inputs_ = std::make_shared<std::vector<AnfNodePtr>>(); | ||||
| execution_order_ = {}; | execution_order_ = {}; | ||||
| mem_reuse_exec_order_ = {}; | |||||
| executable_ = true; | executable_ = true; | ||||
| summary_node_exist_ = false; | summary_node_exist_ = false; | ||||
| stream_distinction_label_ = kInvalidDistincLabel; | stream_distinction_label_ = kInvalidDistincLabel; | ||||
| @@ -51,6 +52,7 @@ class KernelGraph : public FuncGraph { | |||||
| inputs_ = graph.inputs_; | inputs_ = graph.inputs_; | ||||
| child_graph_result_ = graph.child_graph_result_; | child_graph_result_ = graph.child_graph_result_; | ||||
| execution_order_ = graph.execution_order_; | execution_order_ = graph.execution_order_; | ||||
| mem_reuse_exec_order_ = graph.mem_reuse_exec_order_; | |||||
| graph_id_ = graph.graph_id_; | graph_id_ = graph.graph_id_; | ||||
| stream_distinction_label_ = graph.stream_distinction_label_; | stream_distinction_label_ = graph.stream_distinction_label_; | ||||
| front_backend_anf_map_ = graph.front_backend_anf_map_; | front_backend_anf_map_ = graph.front_backend_anf_map_; | ||||
| @@ -112,6 +114,9 @@ class KernelGraph : public FuncGraph { | |||||
// Replace the kernel execution order (copy overload).
void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; }
// Replace the kernel execution order (move overload, avoids a copy).
void set_execution_order(std::vector<CNodePtr> &&order) { execution_order_ = std::move(order); }
// Read-only view of the current kernel execution order.
const std::vector<CNodePtr> &execution_order() const { return execution_order_; }
| // Set new exec_order for mem_reuse | |||||
| void set_mem_reuse_exec_order(const std::vector<CNodePtr> &order) { mem_reuse_exec_order_ = order; } | |||||
| const std::vector<CNodePtr> &mem_reuse_exec_order() const { return mem_reuse_exec_order_; } | |||||
// Compute and store the default topological execution order.
void SetExecOrderByDefault();
// Unique id of this kernel graph.
uint32_t graph_id() const { return graph_id_; }
void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
| @@ -278,6 +283,14 @@ class KernelGraph : public FuncGraph { | |||||
// Number of labels used for control flow (also the 'batch_num' for DavinciModel).
uint32_t label_num() const { return label_num_; }
void set_label_num(uint32_t num) { label_num_ = num; }
// Whether the graph (as the root graph) contains recursion.
bool recursive_call() const { return has_recursive_call_; }
// Whether the graph (as the root graph) contains subgraph multi-call.
bool subgraph_multi_call() const { return has_subgraph_multicall_; }
// Set the flag indicating the graph contains recursion.
void set_recursive_call(bool flag) { has_recursive_call_ = flag; }
// Set the flag indicating the graph contains subgraph multi-call.
void set_subgraph_multi_call(bool flag) { has_subgraph_multicall_ = flag; }
| private: | private: | ||||
| // remove value node form graph | // remove value node form graph | ||||
| @@ -307,6 +320,7 @@ class KernelGraph : public FuncGraph { | |||||
| std::shared_ptr<std::vector<AnfNodePtr>> inputs_; | std::shared_ptr<std::vector<AnfNodePtr>> inputs_; | ||||
| std::vector<AnfNodePtr> child_graph_result_; | std::vector<AnfNodePtr> child_graph_result_; | ||||
std::vector<CNodePtr> execution_order_;
// Alternative execution order consumed only by memory reuse (SOMAS) when the
// graph has subgraph multi-call; built by InitMemReuseExecOrder.
std::vector<CNodePtr> mem_reuse_exec_order_;
| // extra params and tensors for control flow | // extra params and tensors for control flow | ||||
| std::vector<std::pair<ParameterPtr, tensor::TensorPtr>> extra_param_tensor_; | std::vector<std::pair<ParameterPtr, tensor::TensorPtr>> extra_param_tensor_; | ||||
| uint32_t graph_id_; | uint32_t graph_id_; | ||||
| @@ -360,6 +374,10 @@ class KernelGraph : public FuncGraph { | |||||
| bool has_optimizer_{false}; | bool has_optimizer_{false}; | ||||
| bool is_dynamic_shape_{false}; | bool is_dynamic_shape_{false}; | ||||
| // Indicate the graphs has recursion or multi-call or not as the root graph. | |||||
| bool has_recursive_call_{false}; | |||||
| bool has_subgraph_multicall_{false}; | |||||
| // Number of labels. This is also the 'batch_num' for DavinciModel, | // Number of labels. This is also the 'batch_num' for DavinciModel, | ||||
| // It should be 1 if no labels used for control flow. | // It should be 1 if no labels used for control flow. | ||||
| uint32_t label_num_ = 1; | uint32_t label_num_ = 1; | ||||
| @@ -421,6 +421,10 @@ constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first"; | |||||
| constexpr auto kAttrIgnoreSideEffect = "ignore_side_effect"; | constexpr auto kAttrIgnoreSideEffect = "ignore_side_effect"; | ||||
| constexpr auto kAttrSwitchLayer = "switch_layer"; | constexpr auto kAttrSwitchLayer = "switch_layer"; | ||||
| constexpr auto kAttrReturn = "return"; | constexpr auto kAttrReturn = "return"; | ||||
// Attribute keys marking control-flow label nodes when rebuilding the
// memory-reuse execution order (kAttrRecursive is checked on LabelSet nodes;
// the others presumably mark recursion/multi-call boundaries — confirm against
// UnfoldRecursiveExecOrder / GetSubGraphExecOrder).
constexpr auto kAttrRecursiveStart = "recursive_start";
constexpr auto kAttrRecursiveEnd = "recursive_end";
constexpr auto kAttrRecursive = "recursive";
constexpr auto kAttrMultiCallEnd = "multicall_end";
| // attr value | // attr value | ||||
| constexpr auto kValueTargetSwitch = "target_switch"; | constexpr auto kValueTargetSwitch = "target_switch"; | ||||