diff --git a/mindspore/ccsrc/backend/optimizer/somas/somas.cc b/mindspore/ccsrc/backend/optimizer/somas/somas.cc index 90e75e9a56..80ce15b63c 100644 --- a/mindspore/ccsrc/backend/optimizer/somas/somas.cc +++ b/mindspore/ccsrc/backend/optimizer/somas/somas.cc @@ -400,11 +400,17 @@ bool Somas::InitSomasTensors(const session::KernelGraph *graph) { void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); + std::vector<CNodePtr> kernel_cnodes; streams_list_ = {}; nodes_list_ = {}; size_t node_index = 0; - auto kernel_cnodes = graph->execution_order(); - for (const auto &kernel : kernel_cnodes) { + if (graph->subgraph_multi_call()) { + kernel_cnodes = graph->mem_reuse_exec_order(); + } else { + kernel_cnodes = graph->execution_order(); + } + for (size_t i = 0; i < kernel_cnodes.size(); i++) { + auto kernel = kernel_cnodes[i]; SomasStreamPtr stream; auto stream_id = AnfAlgo::GetStreamId(kernel); auto it = find_if(streams_list_.begin(), streams_list_.end(), @@ -427,7 +433,8 @@ void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) { nodes_list_.push_back(node); stream->nodes_.push_back(node); auto key = kernel.get(); - nodes_map_[key] = node; + auto &nodes = nodes_map_[key]; + nodes.push_back(node); node_index++; } } @@ -438,7 +445,8 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph size_t tensor_index = 0; auto kernel_cnodes = graph->execution_order(); for (const auto &kernel : kernel_cnodes) { - auto node = nodes_map_[kernel.get()]; + auto nodes = nodes_map_[kernel.get()]; + auto node = nodes[0]; MS_EXCEPTION_IF_NULL(node); auto stream = node->GetStream(); MS_EXCEPTION_IF_NULL(stream); @@ -454,7 +462,7 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph // Set all output tensor lifelong to true. 
auto tensor = std::make_shared<SomasTensor>(output_tensor_index, node, stream, size, kLifeLongNone); tensor->lifetime_.start_ = node->GetId(); - tensor->lifetime_.end_ = node->GetId(); + tensor->lifetime_.end_ = (nodes.size() > 1) ? nodes.back()->GetId() : node->GetId(); tensor->type_ = kOutputOnly; if (AnfAlgo::OutputAddrExist(kernel, index)) { tensor->aligned_size_ = 0; @@ -463,8 +471,10 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph tensors_list_.push_back(tensor); tensors_map_[output_tensor_index] = tensor; stream->tensors_.push_back(tensor); - node->tensors_.insert(tensor); - node->output_tensors_.push_back(tensor); + std::for_each(nodes.begin(), nodes.end(), [tensor](auto &node) { + node->tensors_.insert(tensor); + node->output_tensors_.push_back(tensor); + }); index++; } @@ -477,15 +487,17 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph SomasTensorPtr tensor = std::make_shared<SomasTensor>(workspace_tensor_index, node, stream, size, kLifeLongNone); tensor->type_ = kWorkspace; tensor->lifetime_.start_ = node->GetId(); - tensor->lifetime_.end_ = node->GetId(); + tensor->lifetime_.end_ = (nodes.size() > 1) ? 
nodes.back()->GetId() : node->GetId(); if (AnfAlgo::WorkspaceAddrExist(kernel, index)) { tensor->aligned_size_ = 0; } tensors_list_.push_back(tensor); tensors_map_[workspace_tensor_index] = tensor; stream->tensors_.push_back(tensor); - node->tensors_.insert(tensor); - node->workspace_tensors_.push_back(tensor); + std::for_each(nodes.begin(), nodes.end(), [tensor](auto &node) { + node->tensors_.insert(tensor); + node->workspace_tensors_.push_back(tensor); + }); index++; } } @@ -505,7 +517,8 @@ void Somas::InitSomasInputTensors(const session::KernelGraph *graph) { } } void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) { - auto node = nodes_map_[kernel.get()]; + auto nodes = nodes_map_[kernel.get()]; + auto node = nodes[0]; MS_EXCEPTION_IF_NULL(node); auto stream = node->GetStream(); MS_EXCEPTION_IF_NULL(stream); @@ -543,7 +556,7 @@ void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) { MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input " << i << " [" << prenode_index.first->fullname_with_scope() << "] is not init."; } - auto pre_somas_node = iter->second; + auto pre_somas_node = iter->second.at(0); if (prenode_index.second > pre_somas_node->output_tensors_.size()) { MS_LOG(EXCEPTION) << "Output index " << prenode_index.second << " exceed input node [" << prenode_index.first->fullname_with_scope() << "]'s outputs size " @@ -551,15 +564,18 @@ void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) { } auto input_somas_tensor = pre_somas_node->output_tensors_[prenode_index.second]; MS_EXCEPTION_IF_NULL(input_somas_tensor); - node->input_tensors_.push_back(input_somas_tensor); + std::for_each(nodes.begin(), nodes.end(), + [input_somas_tensor](auto &node) { node->input_tensors_.push_back(input_somas_tensor); }); real_input_index++; if (input_somas_tensor->type_ == kOutputOnly) { input_somas_tensor->type_ = kCommon; } - input_somas_tensor->destinations_.insert(node); 
input_somas_tensor->destinationStreams_.insert(stream); - if (input_somas_tensor->lifetime_.end_ < node->GetId()) { - input_somas_tensor->lifetime_.end_ = node->GetId(); + for (auto &repeat_node : nodes) { + input_somas_tensor->destinations_.insert(repeat_node); + if (input_somas_tensor->lifetime_.end_ < repeat_node->GetId()) { + input_somas_tensor->lifetime_.end_ = repeat_node->GetId(); + } } if (node != pre_somas_node) { @@ -574,7 +590,7 @@ void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) { } void Somas::InitAtomicCleanInputs(bool enable_fusion_clear, const CNodePtr &kernel) { - auto node = nodes_map_[kernel.get()]; + auto node = nodes_map_[kernel.get()].at(0); MS_EXCEPTION_IF_NULL(node); auto stream = node->GetStream(); MS_EXCEPTION_IF_NULL(stream); @@ -588,7 +604,7 @@ void Somas::InitAtomicCleanInputs(bool enable_fusion_clear, const CNodePtr &kern MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input [" << pre_node->fullname_with_scope() << "] is not init."; } - auto pre_somas_node = iter->second; + auto pre_somas_node = iter->second.at(0); // set clean output tensors if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs); @@ -698,7 +714,8 @@ void Somas::GetNextOutputProcess(const session::KernelGraph *graph) { } auto iter = nodes_map_.find(kernel.get()); if (iter != nodes_map_.end()) { - auto getnext_output_tensors = iter->second->output_tensors_; + auto &node = iter->second.at(0); + auto getnext_output_tensors = node->output_tensors_; for (auto &tensor : getnext_output_tensors) { total_size += tensor->GetAlignedSize(); tensor->lifelong_value_ = kLifeLongGraphAll; @@ -720,7 +737,8 @@ void Somas::IndependentNodeOutputProcess(const session::KernelGraph *graph) { } auto iter = nodes_map_.find(kernel.get()); if (iter != nodes_map_.end()) { - auto semi_reuse_output_tensors = iter->second->output_tensors_; + auto &node = 
iter->second.at(0); + auto semi_reuse_output_tensors = node->output_tensors_; for (auto &tensor : semi_reuse_output_tensors) { total_size += tensor->GetAlignedSize(); tensor->lifelong_value_ = kLifeLongGraphAll; @@ -749,9 +767,9 @@ void Somas::SummaryInputProcess(const session::KernelGraph *graph) { size_t index = IntToSize(node_item.second.second); auto iter = nodes_map_.find(node.get()); if (iter != nodes_map_.end()) { - auto input_node = iter->second; + auto input_node = iter->second.at(0); if (index < input_node->output_tensors_.size()) { - auto tensor = iter->second->output_tensors_[index]; + auto tensor = input_node->output_tensors_[index]; tensor->lifelong_value_ = kLifeLongGraphAll; tensor->type_ = kSummaryInput; total_summary_size += tensor->GetAlignedSize(); @@ -789,7 +807,8 @@ void Somas::RefNodeProcess(const session::KernelGraph *graph) { if (graph->IsInRefOutputMap(out_pair)) { auto origin_pair = graph->GetRefCorrespondOutput(out_pair); MS_EXCEPTION_IF_NULL(origin_pair.first); - auto output_tensor = nodes_map_[kernel.get()]->output_tensors_[out_index]; + auto &node = nodes_map_[kernel.get()].at(0); + auto output_tensor = node->output_tensors_[out_index]; MS_EXCEPTION_IF_NULL(output_tensor); output_tensor->type_ = kRefNodeOutput; total_output_size += size; @@ -797,7 +816,8 @@ void Somas::RefNodeProcess(const session::KernelGraph *graph) { if (AnfAlgo::IsRealCNodeKernel(origin_pair.first)) { auto ori_node = origin_pair.first->cast<CNodePtr>(); auto ori_index = origin_pair.second; - auto input_tensor = nodes_map_[ori_node.get()]->output_tensors_[ori_index]; + auto &repeat_node = nodes_map_[ori_node.get()].at(0); + auto input_tensor = repeat_node->output_tensors_[ori_index]; MS_EXCEPTION_IF_NULL(input_tensor); input_tensor->type_ = kRefNodeInput; total_input_size += input_tensor->aligned_size_; @@ -821,7 +841,7 @@ void Somas::NonTaskSplitProcess(const session::KernelGraph *graph) { auto op_name = AnfAlgo::GetCNodeName(kernel); if ((op_name == kSplitOpName || op_name 
== kSplitVOpName) && AnfAlgo::HasNodeAttr(kAttrNonTask, kernel)) { std::vector refnode_input_output; - auto node = nodes_map_[kernel.get()]; + auto node = nodes_map_[kernel.get()].at(0); if (node->input_tensors_.size() == 0) { MS_LOG(EXCEPTION) << op_name << " has no input tensor, can not do split non_task process."; } @@ -852,7 +872,7 @@ void Somas::UnReuseNodeProcess(const session::KernelGraph *graph) { if (iter != full_name_list.end()) { MS_LOG(INFO) << "Set UnReuse Node in somas, Node:" << full_name; auto key = kernel.get(); - auto somas_node = nodes_map_[key]; + auto somas_node = nodes_map_[key].at(0); // input auto inputs = somas_node->input_tensors_; for (auto &input : inputs) { @@ -1749,11 +1769,12 @@ uint8_t *Somas::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const { auto iter = nodes_map_.find(key); uint8_t *ptr = nullptr; if (iter != nodes_map_.end()) { - if (index >= iter->second->output_tensors_.size()) { + auto &node = iter->second.at(0); + if (index >= node->output_tensors_.size()) { MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" - << iter->second->output_tensors_.size() << "]"; + << node->output_tensors_.size() << "]"; } - auto output_tensor = iter->second->output_tensors_[index]; + auto output_tensor = node->output_tensors_[index]; ptr = mem_base_addr_ + output_tensor->offset_; } else { MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in nodes_map"; @@ -1766,11 +1787,12 @@ uint8_t *Somas::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const auto iter = nodes_map_.find(key); uint8_t *ptr = nullptr; if (iter != nodes_map_.end()) { - if (index >= iter->second->workspace_tensors_.size()) { + auto &node = iter->second.at(0); + if (index >= node->workspace_tensors_.size()) { MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" - << iter->second->workspace_tensors_.size() << "]"; + << node->workspace_tensors_.size() << "]"; } - auto 
workspace_tensor = iter->second->workspace_tensors_[index]; + auto workspace_tensor = node->workspace_tensors_[index]; ptr = mem_base_addr_ + workspace_tensor->offset_; } return ptr; diff --git a/mindspore/ccsrc/backend/optimizer/somas/somas.h b/mindspore/ccsrc/backend/optimizer/somas/somas.h index 4a04ab4323..84197a04f6 100644 --- a/mindspore/ccsrc/backend/optimizer/somas/somas.h +++ b/mindspore/ccsrc/backend/optimizer/somas/somas.h @@ -63,7 +63,7 @@ class Somas { std::string hash_id_; // Maps std::unordered_map<size_t, SomasTensorPtr> tensors_map_; - std::map<void *, SomasNodePtr> nodes_map_; + std::map<void *, std::vector<SomasNodePtr>> nodes_map_; std::map<void *, std::vector<SomasParameterPtr>> parameters_map_; // Vectors diff --git a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc index 181941e610..985d376cdf 100644 --- a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc +++ b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc @@ -296,6 +296,15 @@ class AscendAutoMonadContext : public BaseContext { // Set flag to indicate whether has already created an stack or not. void SetInitedStack(bool flag) { inited_stack_ = flag; } + // The graphs has recursion. + bool HasRecursiveCall() const { return has_recursive_call_; } + // The graphs has subgraph multi-call. + bool HasSubgraphMultiCall() const { return has_subgraph_multicall_; } + // set flag to indicate whether has recursion. + void SetRecursiveCall(bool flag) { has_recursive_call_ = flag; } + // set flag to indicate whether has multi-call. + void SetSubGraphMultiCall(bool flag) { has_subgraph_multicall_ = flag; } + // Map kernel_graph to its call info. OrderedMap<KernelGraphPtr, CallInfo> call_info_map; @@ -311,6 +320,10 @@ class AscendAutoMonadContext : public BaseContext { // Create an stack for multi-call and non-tail recursion. bool inited_stack_ = false; + // The graphs has recursion or not. + bool has_recursive_call_ = false; + // The graphs has subgraph multi-call or not. 
+ bool has_subgraph_multicall_ = false; }; // @@ -643,6 +656,11 @@ class AscendAutoMonadConverter { } // Handle recursive call. kernel_graph_->SetExecOrderByDefault(); + if (call_info_.recursive) { + const auto &nodes = kernel_graph_->execution_order(); + AnfAlgo::SetNodeAttr(kAttrRecursiveStart, prim::kValueOne, *nodes.begin()); + AnfAlgo::SetNodeAttr(kAttrRecursiveEnd, prim::kValueOne, *nodes.rbegin()); + } for (auto &call_site : call_info_.call_sites) { if (need_stackops_ && call_site.recursive) { MS_LOG(INFO) << "graph:" << kernel_graph_->ToString() << ", loop call_site:" << call_site.cnode->DebugString(); @@ -661,6 +679,7 @@ class AscendAutoMonadConverter { auto stack_destroy = StackDestroy(top_graph); AnfAlgo::KeepOrder(top_graph, *exec_order.rbegin(), stack_destroy); top_graph->SetExecOrderByDefault(); + context_.SetRecursiveCall(true); context_.SetInitedStack(true); } } @@ -812,6 +831,9 @@ class AscendAutoMonadConverter { // Create LabelGoto or LabelSwitch node. auto label_goto_switch = MakeLabelGotoSwitch(cnode, graphes, labels); call_site->conversion_cnode = label_goto_switch; + if (call_site->recursive) { + AnfAlgo::SetNodeAttr(kAttrRecursive, prim::kValueOne, label_goto_switch); + } // Setup return label and output if required. if (call_site->return_label != kNoLabel) { @@ -931,7 +953,11 @@ class AscendAutoMonadConverter { MS_EXCEPTION_IF_NULL(label_param); auto return_switch = LabelSwitch(label_param, return_labels); AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_switch); + if (!call_info_.recursive) { + AnfAlgo::SetNodeAttr(kAttrMultiCallEnd, prim::kValueOne, return_switch); + } kernel_graph_->set_end_goto(return_switch); + context_.SetSubGraphMultiCall(true); } // Assign graph output to the output parameter. 
@@ -1650,6 +1676,8 @@ void AscendAutoMonad::Run() { CallInfoFinder::Run(&context); AscendAutoMonadConverter::Run(&context); kernel_graph_->set_label_num(context.CurrentLabel() + 1); + kernel_graph_->set_recursive_call(context.HasRecursiveCall()); + kernel_graph_->set_subgraph_multi_call(context.HasSubgraphMultiCall()); MS_LOG(DEBUG) << "Ascend auto-monad finish."; DumpGraphForDebug(kernel_graph_); } diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 31201780e6..ebcc2f321d 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -1034,9 +1034,139 @@ void AscendSession::BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kerne MS_LOG(INFO) << "Finish!"; } +static CNodePtr GetNextLabelSet(const std::vector<CNodePtr> &kernel_nodes, uint32_t index) { + uint32_t node_sizes = kernel_nodes.size(); + if (index >= node_sizes - 1) { + MS_LOG(EXCEPTION) << "there is no node after this node:" << kernel_nodes[index]->DebugString(); + } + auto kernel = kernel_nodes[index + 1]; + if (AnfAlgo::GetCNodeName(kernel) != kLabelSetOpName) { + MS_LOG(EXCEPTION) << "the node is not labelset follow labelgoto/labelswitch, node: " + << kernel_nodes[index]->DebugString(); + } + return kernel; +} + +static std::vector<CNodePtr> HandleRecursiveCall(const std::vector<CNodePtr> &kernel_cnodes, const uint32_t &back_label, + uint32_t *index, std::vector<CNodePtr> *back) { + MS_EXCEPTION_IF_NULL(index); + MS_EXCEPTION_IF_NULL(back); + std::vector<CNodePtr> front; + std::vector<CNodePtr> back_temp; + bool back_flag = false; + for (uint32_t i = *index; i < kernel_cnodes.size(); i++) { + if (!back_flag) { + front.emplace_back(kernel_cnodes[i]); + } else { + back->emplace_back(kernel_cnodes[i]); + } + if (AnfAlgo::HasNodeAttr(kAttrRecursiveEnd, kernel_cnodes[i])) { + *index = i; + back->insert(back->end(), back_temp.begin(), back_temp.end()); + return front; + } + if (AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) { + back_flag = 
true; + if (AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], back_label)) { + continue; + } else { + auto temp = HandleRecursiveCall(kernel_cnodes, back_label, &(++i), &back_temp); + front.insert(front.end(), temp.begin(), temp.end()); + continue; + } + } + } + return front; +} + +static void UnfoldRecursiveExecOrder(KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + if (!kernel_graph->recursive_call()) { + return; + } + auto kernel_cnodes = kernel_graph->mem_reuse_exec_order(); + std::vector<CNodePtr> mem_reuse_order; + mem_reuse_order.reserve(kernel_cnodes.size()); + for (uint32_t i = 0; i < kernel_cnodes.size(); i++) { + if (!AnfAlgo::HasNodeAttr(kAttrRecursiveStart, kernel_cnodes[i])) { + mem_reuse_order.emplace_back(kernel_cnodes[i]); + continue; + } + auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex); + std::vector<CNodePtr> back; + auto front = HandleRecursiveCall(kernel_cnodes, label_id, &i, &back); + mem_reuse_order.insert(mem_reuse_order.end(), front.begin(), front.end()); + mem_reuse_order.insert(mem_reuse_order.end(), back.begin(), back.end()); + } + kernel_graph->set_mem_reuse_exec_order(mem_reuse_order); +} + +static void GetSubGraphExecOrder(const KernelGraph *kernel_graph, uint32_t index, const CNodePtr &back_node, + std::vector<CNodePtr> *mem_reuse_order) { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(mem_reuse_order); + auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(back_node, kAttrLabelIndex); + auto kernel_cnodes = kernel_graph->execution_order(); + for (auto i = index; i < kernel_cnodes.size(); i++) { + mem_reuse_order->emplace_back(kernel_cnodes[i]); + if (AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], label_id)) { + return; + } + } +} + +void InitMemReuseExecOrder(KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + if (!kernel_graph->subgraph_multi_call()) { + return; + } + std::unordered_map<uint32_t, uint32_t> label_id_index_map; + auto kernel_cnodes = kernel_graph->execution_order(); + std::vector<CNodePtr> mem_reuse_order; + for (size_t i 
= 0; i < kernel_cnodes.size(); i++) { + mem_reuse_order.emplace_back(kernel_cnodes[i]); + if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSwitch) && + !AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) && + !AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) { + auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(kernel_cnodes[i], kAttrLabelSwitchList); + for (auto label_id : label_list) { + if (label_id_index_map.find(label_id) == label_id_index_map.end()) { + continue; + } + auto back_node = GetNextLabelSet(kernel_cnodes, i); + GetSubGraphExecOrder(kernel_graph, label_id_index_map[label_id], back_node, &mem_reuse_order); + } + continue; + } + if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelGoto) && + !AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) && + !AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) { + auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex); + if (label_id_index_map.find(label_id) == label_id_index_map.end()) { + continue; + } + auto back_node = GetNextLabelSet(kernel_cnodes, i); + GetSubGraphExecOrder(kernel_graph, label_id_index_map[label_id], back_node, &mem_reuse_order); + continue; + } + if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSet) && + !AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) { + auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex); + if (label_id_index_map.find(label_id) != label_id_index_map.end()) { + MS_LOG(EXCEPTION) << "Two labelsets with same label id."; + } + label_id_index_map[label_id] = i; + continue; + } + } + kernel_graph->set_mem_reuse_exec_order(mem_reuse_order); + UnfoldRecursiveExecOrder(kernel_graph); +} + void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const { MS_LOG(INFO) << "Start!"; MS_EXCEPTION_IF_NULL(kernel_graph); + InitMemReuseExecOrder(kernel_graph); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); 
MS_EXCEPTION_IF_NULL(runtime_instance); runtime_instance->AssignMemory(kernel_graph); diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index e81698f74f..b6766cd284 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -41,6 +41,7 @@ class KernelGraph : public FuncGraph { KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), current_epoch_(0), is_dynamic_shape_(false) { inputs_ = std::make_shared<std::vector<AnfNodePtr>>(); execution_order_ = {}; + mem_reuse_exec_order_ = {}; executable_ = true; summary_node_exist_ = false; stream_distinction_label_ = kInvalidDistincLabel; @@ -51,6 +52,7 @@ class KernelGraph : public FuncGraph { inputs_ = graph.inputs_; child_graph_result_ = graph.child_graph_result_; execution_order_ = graph.execution_order_; + mem_reuse_exec_order_ = graph.mem_reuse_exec_order_; graph_id_ = graph.graph_id_; stream_distinction_label_ = graph.stream_distinction_label_; front_backend_anf_map_ = graph.front_backend_anf_map_; @@ -112,6 +114,9 @@ class KernelGraph : public FuncGraph { void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; } void set_execution_order(std::vector<CNodePtr> &&order) { execution_order_ = std::move(order); } const std::vector<CNodePtr> &execution_order() const { return execution_order_; } + // Set new exec_order for mem_reuse + void set_mem_reuse_exec_order(const std::vector<CNodePtr> &order) { mem_reuse_exec_order_ = order; } + const std::vector<CNodePtr> &mem_reuse_exec_order() const { return mem_reuse_exec_order_; } void SetExecOrderByDefault(); uint32_t graph_id() const { return graph_id_; } void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; } @@ -278,6 +283,14 @@ class KernelGraph : public FuncGraph { uint32_t label_num() const { return label_num_; } void set_label_num(uint32_t num) { label_num_ = num; } + // The graphs has recursion. 
+ bool recursive_call() const { return has_recursive_call_; } + // The graphs has subgraph multi-call. + bool subgraph_multi_call() const { return has_subgraph_multicall_; } + // set flag to indicate whether has recursion. + void set_recursive_call(bool flag) { has_recursive_call_ = flag; } + // set flag to indicate whether has multi-call. + void set_subgraph_multi_call(bool flag) { has_subgraph_multicall_ = flag; } private: // remove value node form graph @@ -307,6 +320,7 @@ class KernelGraph : public FuncGraph { std::shared_ptr<std::vector<AnfNodePtr>> inputs_; std::vector<AnfNodePtr> child_graph_result_; std::vector<CNodePtr> execution_order_; + std::vector<CNodePtr> mem_reuse_exec_order_; // extra params and tensors for control flow std::vector<std::pair<AnfNodePtr, tensor::TensorPtr>> extra_param_tensor_; uint32_t graph_id_; @@ -360,6 +374,10 @@ class KernelGraph : public FuncGraph { bool has_optimizer_{false}; bool is_dynamic_shape_{false}; + // Indicate the graphs has recursion or multi-call or not as the root graph. + bool has_recursive_call_{false}; + bool has_subgraph_multicall_{false}; + // Number of labels. This is also the 'batch_num' for DavinciModel, // It should be 1 if no labels used for control flow. uint32_t label_num_ = 1; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 5ff9341482..4552c3233b 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -421,6 +421,10 @@ constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first"; constexpr auto kAttrIgnoreSideEffect = "ignore_side_effect"; constexpr auto kAttrSwitchLayer = "switch_layer"; constexpr auto kAttrReturn = "return"; +constexpr auto kAttrRecursiveStart = "recursive_start"; +constexpr auto kAttrRecursiveEnd = "recursive_end"; +constexpr auto kAttrRecursive = "recursive"; +constexpr auto kAttrMultiCallEnd = "multicall_end"; // attr value constexpr auto kValueTargetSwitch = "target_switch";