From: @liangzelang
Reviewed-by: @zhoufeng54, @kisnwang
Signed-off-by: @kisnwang
Merge ref: pull/15356/MERGE
@@ -400,11 +400,17 @@ bool Somas::InitSomasTensors(const session::KernelGraph *graph) {
 void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
+  std::vector<CNodePtr> kernel_cnodes;
   streams_list_ = {};
   nodes_list_ = {};
   size_t node_index = 0;
-  auto kernel_cnodes = graph->execution_order();
-  for (const auto &kernel : kernel_cnodes) {
+  if (graph->subgraph_multi_call()) {
+    kernel_cnodes = graph->mem_reuse_exec_order();
+  } else {
+    kernel_cnodes = graph->execution_order();
+  }
+  for (size_t i = 0; i < kernel_cnodes.size(); i++) {
+    auto kernel = kernel_cnodes[i];
     SomasStreamPtr stream;
     auto stream_id = AnfAlgo::GetStreamId(kernel);
     auto it = find_if(streams_list_.begin(), streams_list_.end(),
@@ -427,7 +433,8 @@ void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) {
     nodes_list_.push_back(node);
     stream->nodes_.push_back(node);
     auto key = kernel.get();
-    nodes_map_[key] = node;
+    auto &nodes = nodes_map_[key];
+    nodes.push_back(node);
     node_index++;
   }
 }
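Note: the two hunks above carry the core data-structure change of this patch. Once subgraph multi-call is enabled, the same CNode can occur several times in the mem-reuse execution order, so one map slot can no longer hold a single SomasNodePtr; every occurrence now appends its own SomasNode under the kernel's pointer key. A minimal self-contained sketch of the idea (toy Node type, not the MindSpore API):

    #include <cstddef>
    #include <map>
    #include <memory>
    #include <vector>

    struct Node { std::size_t id; };  // stand-in for SomasNode; only the execution id matters
    using NodePtr = std::shared_ptr<Node>;

    int main() {
      // Kernel A executes twice in the mem-reuse exec order: A, B, A.
      int kernelA = 0, kernelB = 0;
      std::vector<void *> exec_order = {&kernelA, &kernelB, &kernelA};

      std::map<void *, std::vector<NodePtr>> nodes_map;
      std::size_t node_index = 0;
      for (void *key : exec_order) {
        // A repeated execution appends a new node instead of overwriting the slot.
        nodes_map[key].push_back(std::make_shared<Node>(Node{node_index++}));
      }
      // nodes_map[&kernelA] now holds ids {0, 2}; nodes_map[&kernelB] holds {1}.
      return 0;
    }

Call sites that only need one representative node take the first execution, which is why the hunks below switch to iter->second.at(0).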
@@ -438,7 +445,8 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph
   size_t tensor_index = 0;
   auto kernel_cnodes = graph->execution_order();
   for (const auto &kernel : kernel_cnodes) {
-    auto node = nodes_map_[kernel.get()];
+    auto nodes = nodes_map_[kernel.get()];
+    auto node = nodes[0];
     MS_EXCEPTION_IF_NULL(node);
     auto stream = node->GetStream();
     MS_EXCEPTION_IF_NULL(stream);
@@ -454,7 +462,7 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph
       // Set all output tensor lifelong to true.
       auto tensor = std::make_shared<SomasTensor>(output_tensor_index, node, stream, size, kLifeLongNone);
       tensor->lifetime_.start_ = node->GetId();
-      tensor->lifetime_.end_ = node->GetId();
+      tensor->lifetime_.end_ = (nodes.size() > 1) ? nodes.back()->GetId() : node->GetId();
       tensor->type_ = kOutputOnly;
       if (AnfAlgo::OutputAddrExist(kernel, index)) {
         tensor->aligned_size_ = 0;
@@ -463,8 +471,10 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph
       tensors_list_.push_back(tensor);
       tensors_map_[output_tensor_index] = tensor;
       stream->tensors_.push_back(tensor);
-      node->tensors_.insert(tensor);
-      node->output_tensors_.push_back(tensor);
+      std::for_each(nodes.begin(), nodes.end(), [tensor](auto &node) {
+        node->tensors_.insert(tensor);
+        node->output_tensors_.push_back(tensor);
+      });
       index++;
     }
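Note: because a kernel may now map to several SomasNodes, its output tensor has to stay alive from the node's first execution to its last, which is what the ternary above encodes. A self-contained sketch of that lifetime rule (toy types; OutputLifetime is an invented helper):

    #include <cstddef>
    #include <vector>

    struct Lifetime { std::size_t start_, end_; };

    // node_ids: every execution id of the producing kernel, in schedule order.
    Lifetime OutputLifetime(const std::vector<std::size_t> &node_ids) {
      return {node_ids.front(),
              node_ids.size() > 1 ? node_ids.back() : node_ids.front()};
    }

    int main() {
      // A kernel executed at steps 0 and 2 (e.g. a duplicated subgraph call):
      // lifetime [0, 2], so nothing scheduled in between may reuse the buffer.
      Lifetime lt = OutputLifetime({0, 2});
      return (lt.start_ == 0 && lt.end_ == 2) ? 0 : 1;
    }

The workspace hunk below applies exactly the same rule to workspace tensors.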
@@ -477,15 +487,17 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph
       SomasTensorPtr tensor = std::make_shared<SomasTensor>(workspace_tensor_index, node, stream, size, kLifeLongNone);
       tensor->type_ = kWorkspace;
       tensor->lifetime_.start_ = node->GetId();
-      tensor->lifetime_.end_ = node->GetId();
+      tensor->lifetime_.end_ = (nodes.size() > 1) ? nodes.back()->GetId() : node->GetId();
       if (AnfAlgo::WorkspaceAddrExist(kernel, index)) {
         tensor->aligned_size_ = 0;
       }
       tensors_list_.push_back(tensor);
       tensors_map_[workspace_tensor_index] = tensor;
       stream->tensors_.push_back(tensor);
-      node->tensors_.insert(tensor);
-      node->workspace_tensors_.push_back(tensor);
+      std::for_each(nodes.begin(), nodes.end(), [tensor](auto &node) {
+        node->tensors_.insert(tensor);
+        node->workspace_tensors_.push_back(tensor);
+      });
       index++;
     }
   }
@@ -505,7 +517,8 @@ void Somas::InitSomasInputTensors(const session::KernelGraph *graph) {
   }
 }
 
 void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) {
-  auto node = nodes_map_[kernel.get()];
+  auto nodes = nodes_map_[kernel.get()];
+  auto node = nodes[0];
   MS_EXCEPTION_IF_NULL(node);
   auto stream = node->GetStream();
   MS_EXCEPTION_IF_NULL(stream);
@@ -543,7 +556,7 @@ void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) {
       MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input " << i << " ["
                         << prenode_index.first->fullname_with_scope() << "] is not init.";
     }
-    auto pre_somas_node = iter->second;
+    auto pre_somas_node = iter->second.at(0);
     if (prenode_index.second > pre_somas_node->output_tensors_.size()) {
       MS_LOG(EXCEPTION) << "Output index " << prenode_index.second << " exceed input node ["
                         << prenode_index.first->fullname_with_scope() << "]'s outputs size "
@@ -551,15 +564,18 @@ void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) {
     }
     auto input_somas_tensor = pre_somas_node->output_tensors_[prenode_index.second];
     MS_EXCEPTION_IF_NULL(input_somas_tensor);
-    node->input_tensors_.push_back(input_somas_tensor);
+    std::for_each(nodes.begin(), nodes.end(),
+                  [input_somas_tensor](auto &node) { node->input_tensors_.push_back(input_somas_tensor); });
     real_input_index++;
     if (input_somas_tensor->type_ == kOutputOnly) {
      input_somas_tensor->type_ = kCommon;
     }
-    input_somas_tensor->destinations_.insert(node);
     input_somas_tensor->destinationStreams_.insert(stream);
-    if (input_somas_tensor->lifetime_.end_ < node->GetId()) {
-      input_somas_tensor->lifetime_.end_ = node->GetId();
+    for (auto &repeat_node : nodes) {
+      input_somas_tensor->destinations_.insert(repeat_node);
+      if (input_somas_tensor->lifetime_.end_ < repeat_node->GetId()) {
+        input_somas_tensor->lifetime_.end_ = repeat_node->GetId();
+      }
     }
     if (node != pre_somas_node) {
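Note: the consumer side mirrors the producer side: every repeated execution of a consuming kernel is registered as a destination of the input tensor, and the latest one pushes the tensor's end-of-life forward. A self-contained sketch (toy types; AddConsumer is an invented name):

    #include <cstddef>
    #include <set>
    #include <vector>

    struct Tensor {
      std::size_t life_end = 0;
      std::set<std::size_t> destinations;  // execution ids that read this tensor
    };

    void AddConsumer(Tensor *t, const std::vector<std::size_t> &repeat_ids) {
      for (std::size_t id : repeat_ids) {
        t->destinations.insert(id);
        if (t->life_end < id) t->life_end = id;  // the last consuming execution wins
      }
    }

The patch keeps destinationStreams_.insert(stream) outside the loop, using the stream of the first execution's node.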
@@ -574,7 +590,7 @@ void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) {
 }
 
 void Somas::InitAtomicCleanInputs(bool enable_fusion_clear, const CNodePtr &kernel) {
-  auto node = nodes_map_[kernel.get()];
+  auto node = nodes_map_[kernel.get()].at(0);
   MS_EXCEPTION_IF_NULL(node);
   auto stream = node->GetStream();
   MS_EXCEPTION_IF_NULL(stream);
@@ -588,7 +604,7 @@ void Somas::InitAtomicCleanInputs(bool enable_fusion_clear, const CNodePtr &kern
       MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input ["
                         << pre_node->fullname_with_scope() << "] is not init.";
     }
-    auto pre_somas_node = iter->second;
+    auto pre_somas_node = iter->second.at(0);
     // set clean output tensors
     if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
       auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
@@ -698,7 +714,8 @@ void Somas::GetNextOutputProcess(const session::KernelGraph *graph) {
     }
     auto iter = nodes_map_.find(kernel.get());
     if (iter != nodes_map_.end()) {
-      auto getnext_output_tensors = iter->second->output_tensors_;
+      auto &node = iter->second.at(0);
+      auto getnext_output_tensors = node->output_tensors_;
       for (auto &tensor : getnext_output_tensors) {
         total_size += tensor->GetAlignedSize();
         tensor->lifelong_value_ = kLifeLongGraphAll;
@@ -720,7 +737,8 @@ void Somas::IndependentNodeOutputProcess(const session::KernelGraph *graph) {
     }
     auto iter = nodes_map_.find(kernel.get());
     if (iter != nodes_map_.end()) {
-      auto semi_reuse_output_tensors = iter->second->output_tensors_;
+      auto &node = iter->second.at(0);
+      auto semi_reuse_output_tensors = node->output_tensors_;
       for (auto &tensor : semi_reuse_output_tensors) {
         total_size += tensor->GetAlignedSize();
         tensor->lifelong_value_ = kLifeLongGraphAll;
@@ -749,9 +767,9 @@ void Somas::SummaryInputProcess(const session::KernelGraph *graph) {
       size_t index = IntToSize(node_item.second.second);
       auto iter = nodes_map_.find(node.get());
       if (iter != nodes_map_.end()) {
-        auto input_node = iter->second;
+        auto input_node = iter->second.at(0);
         if (index < input_node->output_tensors_.size()) {
-          auto tensor = iter->second->output_tensors_[index];
+          auto tensor = input_node->output_tensors_[index];
           tensor->lifelong_value_ = kLifeLongGraphAll;
           tensor->type_ = kSummaryInput;
           total_summary_size += tensor->GetAlignedSize();
@@ -789,7 +807,8 @@ void Somas::RefNodeProcess(const session::KernelGraph *graph) {
       if (graph->IsInRefOutputMap(out_pair)) {
         auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
         MS_EXCEPTION_IF_NULL(origin_pair.first);
-        auto output_tensor = nodes_map_[kernel.get()]->output_tensors_[out_index];
+        auto &node = nodes_map_[kernel.get()].at(0);
+        auto output_tensor = node->output_tensors_[out_index];
         MS_EXCEPTION_IF_NULL(output_tensor);
         output_tensor->type_ = kRefNodeOutput;
         total_output_size += size;
@@ -797,7 +816,8 @@ void Somas::RefNodeProcess(const session::KernelGraph *graph) {
         if (AnfAlgo::IsRealCNodeKernel(origin_pair.first)) {
           auto ori_node = origin_pair.first->cast<CNodePtr>();
           auto ori_index = origin_pair.second;
-          auto input_tensor = nodes_map_[ori_node.get()]->output_tensors_[ori_index];
+          auto &repeat_node = nodes_map_[ori_node.get()].at(0);
+          auto input_tensor = repeat_node->output_tensors_[ori_index];
           MS_EXCEPTION_IF_NULL(input_tensor);
           input_tensor->type_ = kRefNodeInput;
           total_input_size += input_tensor->aligned_size_;
@@ -821,7 +841,7 @@ void Somas::NonTaskSplitProcess(const session::KernelGraph *graph) {
     auto op_name = AnfAlgo::GetCNodeName(kernel);
     if ((op_name == kSplitOpName || op_name == kSplitVOpName) && AnfAlgo::HasNodeAttr(kAttrNonTask, kernel)) {
       std::vector<size_t> refnode_input_output;
-      auto node = nodes_map_[kernel.get()];
+      auto node = nodes_map_[kernel.get()].at(0);
       if (node->input_tensors_.size() == 0) {
         MS_LOG(EXCEPTION) << op_name << " has no input tensor, can not do split non_task process.";
       }
@@ -852,7 +872,7 @@ void Somas::UnReuseNodeProcess(const session::KernelGraph *graph) {
     if (iter != full_name_list.end()) {
       MS_LOG(INFO) << "Set UnReuse Node in somas, Node:" << full_name;
       auto key = kernel.get();
-      auto somas_node = nodes_map_[key];
+      auto somas_node = nodes_map_[key].at(0);
       // input
       auto inputs = somas_node->input_tensors_;
       for (auto &input : inputs) {
@@ -1749,11 +1769,12 @@ uint8_t *Somas::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {
   auto iter = nodes_map_.find(key);
   uint8_t *ptr = nullptr;
   if (iter != nodes_map_.end()) {
-    if (index >= iter->second->output_tensors_.size()) {
+    auto &node = iter->second.at(0);
+    if (index >= node->output_tensors_.size()) {
       MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:["
-                        << iter->second->output_tensors_.size() << "]";
+                        << node->output_tensors_.size() << "]";
     }
-    auto output_tensor = iter->second->output_tensors_[index];
+    auto output_tensor = node->output_tensors_[index];
     ptr = mem_base_addr_ + output_tensor->offset_;
   } else {
     MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in nodes_map";
@@ -1766,11 +1787,12 @@ uint8_t *Somas::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const
   auto iter = nodes_map_.find(key);
   uint8_t *ptr = nullptr;
   if (iter != nodes_map_.end()) {
-    if (index >= iter->second->workspace_tensors_.size()) {
+    auto &node = iter->second.at(0);
+    if (index >= node->workspace_tensors_.size()) {
      MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:["
-                        << iter->second->workspace_tensors_.size() << "]";
+                        << node->workspace_tensors_.size() << "]";
     }
-    auto workspace_tensor = iter->second->workspace_tensors_[index];
+    auto workspace_tensor = node->workspace_tensors_[index];
     ptr = mem_base_addr_ + workspace_tensor->offset_;
   }
   return ptr;
@@ -63,7 +63,7 @@ class Somas {
   std::string hash_id_;
   // Maps
   std::unordered_map<size_t, SomasTensorPtr> tensors_map_;
-  std::map<void *, SomasNodePtr> nodes_map_;
+  std::map<void *, std::vector<SomasNodePtr>> nodes_map_;
   std::map<void *, vector<SomasParameterPtr>> parameters_map_;
   // Vectors
@@ -296,6 +296,15 @@ class AscendAutoMonadContext : public BaseContext {
   // Set flag to indicate whether has already created an stack or not.
   void SetInitedStack(bool flag) { inited_stack_ = flag; }
 
+  // The graphs has recursion.
+  bool HasRecursiveCall() const { return has_recursive_call_; }
+  // The graphs has subgraph multi-call.
+  bool HasSubgraphMultiCall() const { return has_subgraph_multicall_; }
+  // set flag to indicate whether has recursion.
+  void SetRecursiveCall(bool flag) { has_recursive_call_ = flag; }
+  // set flag to indicate whether has multi-call.
+  void SetSubGraphMultiCall(bool flag) { has_subgraph_multicall_ = flag; }
+
   // Map kernel_graph to its call info.
   OrderedMap<KernelGraphPtr, CallInfo> call_info_map;
@@ -311,6 +320,10 @@ class AscendAutoMonadContext : public BaseContext {
   // Create an stack for multi-call and non-tail recursion.
   bool inited_stack_ = false;
 
+  // The graphs has recursion or not.
+  bool has_recursive_call_ = false;
+  // The graphs has subgraph multi-call or not.
+  bool has_subgraph_multicall_ = false;
 };
 
 //
@@ -643,6 +656,11 @@ class AscendAutoMonadConverter {
     }
     // Handle recursive call.
     kernel_graph_->SetExecOrderByDefault();
+    if (call_info_.recursive) {
+      const auto &nodes = kernel_graph_->execution_order();
+      AnfAlgo::SetNodeAttr(kAttrRecursiveStart, prim::kValueOne, *nodes.begin());
+      AnfAlgo::SetNodeAttr(kAttrRecursiveEnd, prim::kValueOne, *nodes.rbegin());
+    }
     for (auto &call_site : call_info_.call_sites) {
       if (need_stackops_ && call_site.recursive) {
         MS_LOG(INFO) << "graph:" << kernel_graph_->ToString() << ", loop call_site:" << call_site.cnode->DebugString();
@@ -661,6 +679,7 @@ class AscendAutoMonadConverter {
       auto stack_destroy = StackDestroy(top_graph);
       AnfAlgo::KeepOrder(top_graph, *exec_order.rbegin(), stack_destroy);
       top_graph->SetExecOrderByDefault();
+      context_.SetRecursiveCall(true);
       context_.SetInitedStack(true);
     }
   }
@@ -812,6 +831,9 @@ class AscendAutoMonadConverter {
     // Create LabelGoto or LabelSwitch node.
     auto label_goto_switch = MakeLabelGotoSwitch(cnode, graphes, labels);
     call_site->conversion_cnode = label_goto_switch;
+    if (call_site->recursive) {
+      AnfAlgo::SetNodeAttr(kAttrRecursive, prim::kValueOne, label_goto_switch);
+    }
 
     // Setup return label and output if required.
     if (call_site->return_label != kNoLabel) {
@@ -931,7 +953,11 @@ class AscendAutoMonadConverter {
     MS_EXCEPTION_IF_NULL(label_param);
     auto return_switch = LabelSwitch(label_param, return_labels);
     AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_switch);
+    if (!call_info_.recursive) {
+      AnfAlgo::SetNodeAttr(kAttrMultiCallEnd, prim::kValueOne, return_switch);
+    }
     kernel_graph_->set_end_goto(return_switch);
+    context_.SetSubGraphMultiCall(true);
   }
 
   // Assign graph output to the output parameter.
@@ -1650,6 +1676,8 @@ void AscendAutoMonad::Run() {
   CallInfoFinder::Run(&context);
   AscendAutoMonadConverter::Run(&context);
   kernel_graph_->set_label_num(context.CurrentLabel() + 1);
+  kernel_graph_->set_recursive_call(context.HasRecursiveCall());
+  kernel_graph_->set_subgraph_multi_call(context.HasSubgraphMultiCall());
   MS_LOG(DEBUG) << "Ascend auto-monad finish.";
   DumpGraphForDebug(kernel_graph_);
 }
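Note: the hunks above form a hand-off chain: the converter records recursion and multi-call on the AscendAutoMonadContext while lowering call sites, and AscendAutoMonad::Run copies the result onto the root KernelGraph, so later passes branch on cached flags instead of re-analyzing the IR. A condensed stand-alone model of that flow (simplified stand-in types, not the real classes):

    #include <cassert>

    struct Context { bool recursive = false, multicall = false; };
    struct Graph { bool recursive_call = false, subgraph_multi_call = false; };

    int main() {
      Context ctx;
      ctx.recursive = true;  // converter saw a recursive call site
      ctx.multicall = true;  // converter emitted a multi-call return switch

      Graph g;
      g.recursive_call = ctx.recursive;  // done once in AscendAutoMonad::Run
      g.subgraph_multi_call = ctx.multicall;

      assert(g.subgraph_multi_call);  // gates InitMemReuseExecOrder below
      assert(g.recursive_call);       // gates UnfoldRecursiveExecOrder below
      return 0;
    }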
@@ -1034,9 +1034,139 @@ void AscendSession::BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kerne
   MS_LOG(INFO) << "Finish!";
 }
 
+static CNodePtr GetNextLabelSet(const std::vector<CNodePtr> &kernel_nodes, uint32_t index) {
+  uint32_t node_sizes = kernel_nodes.size();
+  if (index >= node_sizes - 1) {
+    MS_LOG(EXCEPTION) << "there is no node after this node:" << kernel_nodes[index]->DebugString();
+  }
+  auto kernel = kernel_nodes[index + 1];
+  if (AnfAlgo::GetCNodeName(kernel) != kLabelSetOpName) {
+    MS_LOG(EXCEPTION) << "the node is not labelset follow labelgoto/labelswitch, node: "
+                      << kernel_nodes[index]->DebugString();
+  }
+  return kernel;
+}
+
+static std::vector<CNodePtr> HandleRecursiveCall(const std::vector<CNodePtr> &kernel_cnodes, const uint32_t &back_label,
+                                                 uint32_t *index, std::vector<CNodePtr> *back) {
+  MS_EXCEPTION_IF_NULL(index);
+  MS_EXCEPTION_IF_NULL(back);
+  std::vector<CNodePtr> front;
+  std::vector<CNodePtr> back_temp;
+  bool back_flag = false;
+  for (uint32_t i = *index; i < kernel_cnodes.size(); i++) {
+    if (!back_flag) {
+      front.emplace_back(kernel_cnodes[i]);
+    } else {
+      back->emplace_back(kernel_cnodes[i]);
+    }
+    if (AnfAlgo::HasNodeAttr(kAttrRecursiveEnd, kernel_cnodes[i])) {
+      *index = i;
+      back->insert(back->end(), back_temp.begin(), back_temp.end());
+      return front;
+    }
+    if (AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) {
+      back_flag = true;
+      if (AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], back_label)) {
+        continue;
+      } else {
+        auto temp = HandleRecursiveCall(kernel_cnodes, back_label, &(++i), &back_temp);
+        front.insert(front.end(), temp.begin(), temp.end());
+        continue;
+      }
+    }
+  }
+  return front;
+}
+
+static void UnfoldRecursiveExecOrder(KernelGraph *kernel_graph) {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  if (!kernel_graph->recursive_call()) {
+    return;
+  }
+  auto kernel_cnodes = kernel_graph->mem_reuse_exec_order();
+  std::vector<CNodePtr> mem_reuse_order;
+  mem_reuse_order.reserve(kernel_cnodes.size());
+  for (uint32_t i = 0; i < kernel_cnodes.size(); i++) {
+    if (!AnfAlgo::HasNodeAttr(kAttrRecursiveStart, kernel_cnodes[i])) {
+      mem_reuse_order.emplace_back(kernel_cnodes[i]);
+      continue;
+    }
+    auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
+    std::vector<CNodePtr> back;
+    auto front = HandleRecursiveCall(kernel_cnodes, label_id, &i, &back);
+    mem_reuse_order.insert(mem_reuse_order.end(), front.begin(), front.end());
+    mem_reuse_order.insert(mem_reuse_order.end(), back.begin(), back.end());
+  }
+  kernel_graph->set_mem_reuse_exec_order(mem_reuse_order);
+}
+
+static void GetSubGraphExecOrder(const KernelGraph *kernel_graph, uint32_t index, const CNodePtr &back_node,
+                                 std::vector<CNodePtr> *mem_reuse_order) {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  MS_EXCEPTION_IF_NULL(mem_reuse_order);
+  auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(back_node, kAttrLabelIndex);
+  auto kernel_cnodes = kernel_graph->execution_order();
+  for (auto i = index; i < kernel_cnodes.size(); i++) {
+    mem_reuse_order->emplace_back(kernel_cnodes[i]);
+    if (AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], label_id)) {
+      return;
+    }
+  }
+}
+
+void InitMemReuseExecOrder(KernelGraph *kernel_graph) {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  if (!kernel_graph->subgraph_multi_call()) {
+    return;
+  }
+  std::unordered_map<uint32_t, uint32_t> label_id_index_map;
+  auto kernel_cnodes = kernel_graph->execution_order();
+  std::vector<CNodePtr> mem_reuse_order;
+  for (size_t i = 0; i < kernel_cnodes.size(); i++) {
+    mem_reuse_order.emplace_back(kernel_cnodes[i]);
+    if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSwitch) &&
+        !AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) &&
+        !AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) {
+      auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(kernel_cnodes[i], kAttrLabelSwitchList);
+      for (auto label_id : label_list) {
+        if (label_id_index_map.find(label_id) == label_id_index_map.end()) {
+          continue;
+        }
+        auto back_node = GetNextLabelSet(kernel_cnodes, i);
+        GetSubGraphExecOrder(kernel_graph, label_id_index_map[label_id], back_node, &mem_reuse_order);
+      }
+      continue;
+    }
+    if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelGoto) &&
+        !AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) &&
+        !AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) {
+      auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
+      if (label_id_index_map.find(label_id) == label_id_index_map.end()) {
+        continue;
+      }
+      auto back_node = GetNextLabelSet(kernel_cnodes, i);
+      GetSubGraphExecOrder(kernel_graph, label_id_index_map[label_id], back_node, &mem_reuse_order);
+      continue;
+    }
+    if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSet) &&
+        !AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) {
+      auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
+      if (label_id_index_map.find(label_id) != label_id_index_map.end()) {
+        MS_LOG(EXCEPTION) << "Two labelsets with same label id.";
+      }
+      label_id_index_map[label_id] = i;
+      continue;
+    }
+  }
+  kernel_graph->set_mem_reuse_exec_order(mem_reuse_order);
+  UnfoldRecursiveExecOrder(kernel_graph);
+}
+
 void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
   MS_LOG(INFO) << "Start!";
   MS_EXCEPTION_IF_NULL(kernel_graph);
+  InitMemReuseExecOrder(kernel_graph);
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
   runtime_instance->AssignMemory(kernel_graph);
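Note: InitMemReuseExecOrder is the heart of the session-side change. The real execution order visits a shared subgraph's body only once, so SOMAS would assign it a single set of lifetimes even when several call sites use it; the pass therefore replays the subgraph body after every additional (backward) jump into it, producing the mem_reuse_exec_order that Somas::InitSomasStreamAndNode consumes above. The sketch below is a self-contained toy model of just that duplication (Step, BuildMemReuseOrder and the label encoding are invented; recursion, which the real code unfolds separately in UnfoldRecursiveExecOrder, is ignored):

    #include <algorithm>
    #include <cstdio>
    #include <map>
    #include <string>
    #include <vector>

    // A step is a kernel, a LabelSet (set_label >= 0), or a LabelGoto/LabelSwitch
    // (goto_labels non-empty); is_return marks return switches, which never duplicate.
    struct Step {
      std::string name;
      std::vector<int> goto_labels;
      int set_label = -1;
      bool is_return = false;
    };

    std::vector<Step> BuildMemReuseOrder(const std::vector<Step> &exec) {
      std::map<int, size_t> label_pos;  // LabelSet id -> position in exec order
      std::vector<Step> order;
      for (size_t i = 0; i < exec.size(); i++) {
        order.push_back(exec[i]);
        if (exec[i].set_label >= 0) label_pos[exec[i].set_label] = i;
        if (exec[i].is_return) continue;  // mirrors the !kAttrReturn guard
        for (int target : exec[i].goto_labels) {
          auto it = label_pos.find(target);
          if (it == label_pos.end()) continue;  // forward jump: body not emitted yet
          // Backward jump = another call into an already-emitted subgraph: replay
          // its body until it can switch back to the LabelSet after this call
          // (the real code fetches that label via GetNextLabelSet).
          int back_label = exec[i + 1].set_label;  // assumes well-formed input
          for (size_t j = it->second; j < exec.size(); j++) {
            order.push_back(exec[j]);
            const auto &g = exec[j].goto_labels;
            if (std::find(g.begin(), g.end(), back_label) != g.end()) break;
          }
        }
      }
      return order;
    }

    int main() {
      // Linearized order: caller 1, the shared subgraph body, then caller 2.
      std::vector<Step> exec = {
          {"K1"},
          {"CALL1", {0}},                // first call: label 0 not emitted yet
          {"RET1_SET", {}, 1},           // resume point of call 1
          {"SUB_SET", {}, 0},            // subgraph entry
          {"S"},                         // subgraph kernel
          {"RET_SW", {1, 2}, -1, true},  // return switch back to either caller
          {"CALL2", {0}},                // second call: label 0 already emitted
          {"RET2_SET", {}, 2},           // resume point of call 2
          {"K3"},
      };
      for (const auto &s : BuildMemReuseOrder(exec)) std::printf("%s ", s.name.c_str());
      std::printf("\n");
      return 0;
    }

On this input the result is K1 CALL1 RET1_SET SUB_SET S RET_SW CALL2 SUB_SET S RET_SW RET2_SET K3: the subgraph body occurs once per call, so each invocation gets its own lifetimes. Recursive regions are instead tagged with kAttrRecursiveStart/kAttrRecursiveEnd and split by HandleRecursiveCall into a "front" and "back" part around the back-edge label, which is why the same kernel can appear several times under one nodes_map_ key above.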
@@ -41,6 +41,7 @@ class KernelGraph : public FuncGraph {
   KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), current_epoch_(0), is_dynamic_shape_(false) {
     inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
     execution_order_ = {};
+    mem_reuse_exec_order_ = {};
     executable_ = true;
     summary_node_exist_ = false;
     stream_distinction_label_ = kInvalidDistincLabel;
@@ -51,6 +52,7 @@ class KernelGraph : public FuncGraph {
     inputs_ = graph.inputs_;
     child_graph_result_ = graph.child_graph_result_;
     execution_order_ = graph.execution_order_;
+    mem_reuse_exec_order_ = graph.mem_reuse_exec_order_;
     graph_id_ = graph.graph_id_;
     stream_distinction_label_ = graph.stream_distinction_label_;
     front_backend_anf_map_ = graph.front_backend_anf_map_;
@@ -112,6 +114,9 @@ class KernelGraph : public FuncGraph {
   void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; }
   void set_execution_order(std::vector<CNodePtr> &&order) { execution_order_ = std::move(order); }
   const std::vector<CNodePtr> &execution_order() const { return execution_order_; }
+  // Set new exec_order for mem_reuse
+  void set_mem_reuse_exec_order(const std::vector<CNodePtr> &order) { mem_reuse_exec_order_ = order; }
+  const std::vector<CNodePtr> &mem_reuse_exec_order() const { return mem_reuse_exec_order_; }
   void SetExecOrderByDefault();
   uint32_t graph_id() const { return graph_id_; }
   void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
@@ -278,6 +283,14 @@ class KernelGraph : public FuncGraph {
   uint32_t label_num() const { return label_num_; }
   void set_label_num(uint32_t num) { label_num_ = num; }
+  // The graphs has recursion.
+  bool recursive_call() const { return has_recursive_call_; }
+  // The graphs has subgraph multi-call.
+  bool subgraph_multi_call() const { return has_subgraph_multicall_; }
+  // set flag to indicate whether has recursion.
+  void set_recursive_call(bool flag) { has_recursive_call_ = flag; }
+  // set flag to indicate whether has multi-call.
+  void set_subgraph_multi_call(bool flag) { has_subgraph_multicall_ = flag; }
 
 private:
   // remove value node form graph
@@ -307,6 +320,7 @@ class KernelGraph : public FuncGraph {
   std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
   std::vector<AnfNodePtr> child_graph_result_;
   std::vector<CNodePtr> execution_order_;
+  std::vector<CNodePtr> mem_reuse_exec_order_;
   // extra params and tensors for control flow
   std::vector<std::pair<ParameterPtr, tensor::TensorPtr>> extra_param_tensor_;
   uint32_t graph_id_;
@@ -360,6 +374,10 @@ class KernelGraph : public FuncGraph {
   bool has_optimizer_{false};
   bool is_dynamic_shape_{false};
+  // Indicate the graphs has recursion or multi-call or not as the root graph.
+  bool has_recursive_call_{false};
+  bool has_subgraph_multicall_{false};
+
   // Number of labels. This is also the 'batch_num' for DavinciModel,
   // It should be 1 if no labels used for control flow.
   uint32_t label_num_ = 1;
@@ -421,6 +421,10 @@ constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first";
 constexpr auto kAttrIgnoreSideEffect = "ignore_side_effect";
 constexpr auto kAttrSwitchLayer = "switch_layer";
 constexpr auto kAttrReturn = "return";
+constexpr auto kAttrRecursiveStart = "recursive_start";
+constexpr auto kAttrRecursiveEnd = "recursive_end";
+constexpr auto kAttrRecursive = "recursive";
+constexpr auto kAttrMultiCallEnd = "multicall_end";
 // attr value
 constexpr auto kValueTargetSwitch = "target_switch";