diff --git a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc
index 2973e5529c..9908b5d03d 100644
--- a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc
+++ b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc
@@ -26,7 +26,6 @@ static constexpr uint32_t kLabelSwitchLabelId = 2;
 namespace mindspore {
 namespace device {
 namespace ascend {
-
 static void UpdateLabelGoto(NotNull<CNodePtr> node) {
   if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) {
     return;
@@ -164,7 +163,6 @@ uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> gr
 uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { return GetLabelNum(NOT_NULL(graph.get().get())); }
-
 }  // namespace ascend
 }  // namespace device
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/device/ascend/ascend_label_assign.h b/mindspore/ccsrc/device/ascend/ascend_label_assign.h
index 743976fba1..98055576eb 100644
--- a/mindspore/ccsrc/device/ascend/ascend_label_assign.h
+++ b/mindspore/ccsrc/device/ascend/ascend_label_assign.h
@@ -25,7 +25,6 @@
 namespace mindspore {
 namespace device {
 namespace ascend {
-
 class AscendLabelAssign {
  public:
   static AscendLabelAssign &GetInstance() {
diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc
index b0d9a96d47..ed20366acd 100644
--- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc
@@ -974,17 +974,5 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
   }
   MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
 }
-
-bool AnfRuntimeAlgorithm::IsWhileTrueGraph(const KernelGraphPtr &child_graph) {
-  auto call_nodes = child_graph->FindNodeByPrimitive(prim::kPrimCall);
-  for (const auto &call_node : call_nodes) {
-    auto graphs = GetCallNodeKernelGraph(call_node);
-    if (graphs.size() == 1 && graphs[0] == child_graph->parent_graph()) {
-      return true;
-    }
-  }
-  return false;
-}
-
 }  // namespace session
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h
index 950a897654..10ae5282e0 100644
--- a/mindspore/ccsrc/session/anf_runtime_algorithm.h
+++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h
@@ -185,7 +185,6 @@ class AnfRuntimeAlgorithm {
   static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
   static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
   static bool IsSwitchCall(const CNodePtr &call_node);
-  static bool IsWhileTrueGraph(const KernelGraphPtr &child_graph);
 };
 }  // namespace session
 using AnfAlgo = session::AnfRuntimeAlgorithm;
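Note: the IsWhileTrueGraph heuristic removed above classified a child graph as a while-true body when one of its calls jumps straight back to its parent graph; later hunks in this patch replace it with an explicit output-is-null flag on KernelGraph. A minimal standalone sketch of that heuristic, using toy types rather than MindSpore's API, just to show what is being dropped:

#include <memory>
#include <vector>

struct ToyGraph;
using ToyGraphPtr = std::shared_ptr<ToyGraph>;
struct ToyGraph {
  ToyGraphPtr parent;                     // enclosing graph
  std::vector<ToyGraphPtr> call_targets;  // graphs reached by this graph's call nodes
};

// True when some call inside `child` goes straight back to its parent,
// i.e. the shape the removed IsWhileTrueGraph looked for.
static bool LooksLikeWhileTrueBody(const ToyGraphPtr &child) {
  for (const auto &target : child->call_targets) {
    if (target == child->parent) {
      return true;
    }
  }
  return false;
}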
diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc
index 4dad370a34..2853caa732 100644
--- a/mindspore/ccsrc/session/ascend_control_parser.cc
+++ b/mindspore/ccsrc/session/ascend_control_parser.cc
@@ -83,7 +83,7 @@ CNodePtr AscendControlParser::GetNextRealKernel(const std::vector<CNodePtr> &lis
 NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
                                                           const CNodePtr &last_label,
-                                                          NotNull<std::set<KernelGraphPtr> *> memo) {
+                                                          const NotNull<std::set<KernelGraphPtr> *> memo) {
   MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString();
 
   // 1. recursive condition
@@ -180,7 +180,7 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod
 }
 
 void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
-                                      NotNull<std::set<KernelGraphPtr> *> memo) {
+                                      const NotNull<std::set<KernelGraphPtr> *> memo) {
   MS_LOG(INFO) << "process call func " << cur_node->DebugString();
 
   // 1 get kernel graph
@@ -212,7 +212,7 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
 }
 
 void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
-                                        const CNodePtr &next_node, NotNull<std::set<KernelGraphPtr> *> memo) {
+                                        const CNodePtr &next_node, const NotNull<std::set<KernelGraphPtr> *> memo) {
   MS_LOG(INFO) << "process switch node " << cur_node->DebugString();
 
   if (cur_node->size() < kCNodeSwitchLength) {
@@ -249,7 +249,8 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
 }
 
 void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
-                                             const CNodePtr &next_node, NotNull<std::set<KernelGraphPtr> *> memo) {
+                                             const CNodePtr &next_node,
+                                             const NotNull<std::set<KernelGraphPtr> *> memo) {
   MS_LOG(INFO) << "process switch node " << cur_node->DebugString();
 
   if (cur_node->size() < kCNodeSwitchLayerLength) {
@@ -353,7 +354,7 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
 }
 
 std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
-                                                        NotNull<std::set<KernelGraphPtr> *> memo) {
+                                                        const NotNull<std::set<KernelGraphPtr> *> memo) {
   MS_LOG(INFO) << "graph:" << graph->graph_id() << " start";
   auto print_vector = [&](std::vector<CNodePtr> vec) -> void {
     MS_LOG(INFO) << "graph:" << graph->graph_id() << "execution order";
diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h
index 50f431bd32..bb1aee76af 100644
--- a/mindspore/ccsrc/session/ascend_control_parser.h
+++ b/mindspore/ccsrc/session/ascend_control_parser.h
@@ -40,13 +40,14 @@ class AscendControlParser {
  private:
   static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
-                                              const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo);
+                                              const CNodePtr &last_label,
+                                              const NotNull<std::set<KernelGraphPtr> *> memo);
   static void RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
-                          NotNull<std::set<KernelGraphPtr> *> memo);
+                          const NotNull<std::set<KernelGraphPtr> *> memo);
   static void RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
-                            NotNull<std::set<KernelGraphPtr> *> memo);
+                            const NotNull<std::set<KernelGraphPtr> *> memo);
   static void RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
-                                 NotNull<std::set<KernelGraphPtr> *> memo);
+                                 const NotNull<std::set<KernelGraphPtr> *> memo);
   static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
                               const CNodePtr &last_label);
@@ -63,7 +64,8 @@ class AscendControlParser {
   static std::vector<uint32_t> GetLabelSwitchList(const CNodePtr &node);
   static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
                               NotNull<KernelGraphPtr> graph);
-  static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
+  static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
+                                            const NotNull<std::set<KernelGraphPtr> *> memo);
 };
 }  // namespace session
 }  // namespace mindspore
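Note: every recursion in the control parser (and in ascend_session below) now threads a memo set of already-visited graphs instead of only comparing against the parent graph. A minimal sketch of the pattern, assuming a toy graph type and plain std::set rather than MindSpore's NotNull wrapper:

#include <memory>
#include <set>
#include <vector>

struct ToyKernelGraph;
using ToyKernelGraphPtr = std::shared_ptr<ToyKernelGraph>;
struct ToyKernelGraph {
  std::vector<ToyKernelGraphPtr> child_graph_order;
};

static void RecurseGraph(const ToyKernelGraphPtr &graph, std::set<ToyKernelGraphPtr> *memo) {
  memo->insert(graph);  // mark before descending
  for (const auto &child : graph->child_graph_order) {
    if (memo->find(child) != memo->end()) {
      continue;  // already handled (e.g. a while-loop edge back to an ancestor): stop the cycle
    }
    RecurseGraph(child, memo);
  }
  // per-graph work happens here, bottom-up
}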
diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc
index 0da6fc067e..bfe4a670fa 100644
--- a/mindspore/ccsrc/session/ascend_session.cc
+++ b/mindspore/ccsrc/session/ascend_session.cc
@@ -171,19 +171,35 @@ std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
   return cnodes;
 }
 
-std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) {
-  size_t after_call_index = 0;
+static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePtr> &cnodes,
+                                                       const std::set<PrimitivePtr> &cut_prims) {
+  size_t after_cut_index = 0;
   std::vector<std::vector<CNodePtr>> ret;
-  for (size_t i = 0; i < cnodes.size(); i++) {
-    if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall) && !AnfAlgo::IsSwitchCall(cnodes[i])) {
-      auto call_kernel_graph = AnfAlgo::GetCallNodeKernelGraph(cnodes[i]);
-      auto prev_call_list = std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.begin() + i);
-      auto call_list = std::vector<CNodePtr>(1, cnodes[i]);
-      after_call_index = i + 1;
-      ret.push_back(prev_call_list);
-      ret.push_back(call_list);
-    } else if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) {
-      ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end()));
+  for (size_t i = 0; i < cnodes.size(); ++i) {
+    bool is_cut_node = false;
+    for (auto &prim : cut_prims) {
+      if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim)) {
+        is_cut_node = true;
+        break;
+      }
+    }
+    if (is_cut_node) {
+      // is call and not switch call,cut to 3 lists
+      if (!AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall)) {
+        // if is not a call,cut to 2 lists
+        ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i);
+        after_cut_index = i;
+      } else if (!AnfAlgo::IsSwitchCall(cnodes[i])) {
+        ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i);
+        ret.emplace_back(1, cnodes[i]);
+        after_cut_index = i + 1;
+        continue;
+      }
+    }
+    // get last child graph list
+    if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) {
+      ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.end());
+      continue;
     }
   }
   return ret;
@@ -191,7 +207,7 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
 
 // if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of
 // graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2]
-static void UpdateRealInput(KernelGraph *graph) {
+static void UpdateRealInput(NotNull<KernelGraph *> graph) {
   auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
   auto bind_call_arg_with_parameter = [&](const std::vector<AnfNodePtr> &parameters, const std::vector<AnfNodePtr> &args,
                                           KernelGraph *child_graph) -> void {
@@ -253,16 +269,17 @@ static void UpdateRealInput(KernelGraph *graph) {
   }
 }
 
-void RecurseToUpdateCallRealInput(KernelGraph *graph) {
-  MS_EXCEPTION_IF_NULL(graph);
+static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph,
+                                         const NotNull<std::set<KernelGraphPtr> *> memo) {
+  memo->insert(graph.get());
   MS_LOG(INFO) << "start graph id:" << graph->graph_id();
   for (auto &child_graph : graph->child_graph_order()) {
-    if (child_graph == graph->parent_graph()) {
+    if (memo->find(child_graph) != memo->end()) {
      MS_LOG(INFO) << "Child graph:" << child_graph->graph_id() << ",parent graph:" << graph->parent_graph()->graph_id();
       continue;
     }
-    RecurseToUpdateCallRealInput(child_graph.get());
+    RecurseToUpdateCallRealInput(NOT_NULL(child_graph), memo);
   }
   // this action should from bottom to top
   graph->UpdateCallRealInput();
@@ -282,7 +299,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
   MS_LOG(INFO) << "start";
   auto graph = ConstructKernelGraph(func_graph);
   // split switch
-  SplitGraphs(graph);
+  SplitGraphs(NOT_NULL(graph));
   // insert goto labels and label_sets
   LinkChildGraphs(NOT_NULL(graph));
   // resource initialize
@@ -290,7 +307,8 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
   // assign label
   AssignLabel(NOT_NULL(graph));
   // recurse compile child graph
-  RecurseCompileGraph(graph);
+  std::set<KernelGraphPtr> memo;
+  RecurseCompileGraph(NOT_NULL(graph), NOT_NULL(&memo));
   // root graph valiate,include genearte execute order and so on
   RootGraphExecutorValidate(NOT_NULL(graph));
   // adjust kernel
@@ -1423,24 +1441,43 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
   }
   MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
   std::vector<AnfNodePtr> call_node_inputs;
-  auto graph_inputs = new_kernel_graph->MutableInputs();
-  MS_EXCEPTION_IF_NULL(graph_inputs);
+  std::vector<AnfNodePtr> new_graph_inputs;
   // create new parameter from cnode
   for (auto &anf_node : list) {
     auto cnode = anf_node->cast<CNodePtr>();
     for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
       auto input = cnode->inputs()[input_idx];
       MS_EXCEPTION_IF_NULL(input);
-      if (input->isa<Parameter>()) {
-        graph_inputs->push_back(input);
+      AnfNodePtr new_parameter = nullptr;
+      // value node consider move to new graph
+      if (input->isa<ValueNode>()) {
+        cnode->set_input(input_idx, input);
+        continue;
+      } else if (input->isa<Parameter>()) {
+        // parameter reuse and should attention mulptiple use of one parameter
         cnode->set_input(input_idx, input);
+        new_parameter = input;
       } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) {
-        auto new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get());
+        // if is cnode and not in current child graph
+        new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get());
         cnode->set_input(input_idx, new_parameter);
+      } else {
+        // if is a cnode and in current graph
+        continue;
+      }
+      // if mulptiple use of one parameter or cnode, only set one parameter in graph inputs and one arg in call node
+      // args
+      if (std::find(call_node_inputs.begin(), call_node_inputs.end(), new_parameter) == call_node_inputs.end()) {
+        new_graph_inputs.push_back(new_parameter);
+        call_node_inputs.push_back(input);
       }
-      call_node_inputs.push_back(input);
     }
   }
+  // set graph inputs of new graph
+  auto graph_inputs = new_kernel_graph->MutableInputs();
+  MS_EXCEPTION_IF_NULL(graph_inputs);
+  graph_inputs->clear();
+  std::copy(new_graph_inputs.begin(), new_graph_inputs.end(), std::back_inserter(*graph_inputs));
   MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id();
   auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
   std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
@@ -1461,20 +1498,30 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
   return call_node_inputs;
 }
 
-void AscendSession::SplitGraphs(const KernelGraphPtr &root_graph) {
-  SplitGraph(root_graph);
+void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
+  std::set<KernelGraphPtr> memo;
+  // if root graph output is a call node ,the root graph is condition graph of 'if' sentence
+  auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first;
+  if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) {
+    SplitGraph(root_graph, {prim::kPrimReturn});
+    for (auto &child_graph : root_graph->child_graph_order()) {
+      RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo));
+    }
+  } else {
+    RecurseSplitGraph(root_graph, NOT_NULL(&memo));
+  }
+  memo.clear();
   // replace the real input if the real input is a call
-  RecurseToUpdateCallRealInput(root_graph.get());
+  RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo));
 }
 
-void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
+void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
   MS_LOG(INFO) << "start,graph_id:" << graph->graph_id();
-  MS_EXCEPTION_IF_NULL(graph);
   auto apply_list = GetCNodes(TopoSort(graph->get_return()));
   // update the root graph child graph order
-  AscendControlParser::UpdateChildGraphOrder(NOT_NULL(graph));
+  AscendControlParser::UpdateChildGraphOrder(graph);
   // get child list from current graph
-  std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list);
+  std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(apply_list, cut_prims);
   auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr {
     // if child graph list only has a call ,then return the exist call
     if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
@@ -1521,20 +1568,22 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
       pre_call_node = cur_call_node;
       cur_call_node = *iter;
       if (pre_call_node != nullptr && cur_call_node != nullptr) {
-        AscendControlParser::InsertControlDependToGraph(NOT_NULL(graph), NOT_NULL(cur_call_node),
-                                                        NOT_NULL(pre_call_node));
+        AscendControlParser::InsertControlDependToGraph(graph, NOT_NULL(cur_call_node), NOT_NULL(pre_call_node));
       }
     }
   }
-  AscendControlParser::UpdateChildGraphOrder(NOT_NULL(graph));
-  UpdateRealInput(graph.get());
-  auto graph_name = std::string("./kernel-graph-").append(std::to_string(graph->graph_id()));
-  DumpIR(graph_name, graph);
+  AscendControlParser::UpdateChildGraphOrder(graph);
+  UpdateRealInput(graph);
   MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end";
   // recurse to split child graph
+}
+
+void AscendSession::RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) {
+  memo->insert(graph.get());
+  SplitGraph(graph, {prim::kPrimCall});
   for (auto &child_graph : graph->child_graph_order()) {
-    if (child_graph != graph->parent_graph()) {
-      SplitGraph(child_graph);
+    if (memo->find(child_graph) == memo->end()) {
+      RecurseSplitGraph(NOT_NULL(child_graph), memo);
     }
   }
 }
@@ -1545,13 +1594,14 @@ void AscendSession::RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph) {
   AscendControlParser::ExecutorValidate(graph);
 }
 
-void AscendSession::RecurseCompileGraph(const KernelGraphPtr &graph) {
+void AscendSession::RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) {
+  memo->insert(graph.get());
   CompileChildGraph(graph);
   for (auto child_graph : graph->child_graph_order()) {
-    if (child_graph == graph->parent_graph()) {
+    if (memo->find(child_graph) != memo->end()) {
       continue;
     }
-    RecurseCompileGraph(child_graph);
+    RecurseCompileGraph(NOT_NULL(child_graph), memo);
   }
 }
diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h
index 773af22c65..529304714c 100755
--- a/mindspore/ccsrc/session/ascend_session.h
+++ b/mindspore/ccsrc/session/ascend_session.h
@@ -98,18 +98,15 @@ class AscendSession : public SessionBasic {
   void SetFinalGraphOutput(const ValuePtr &value);
   void SetFinalGraphOutput(const VectorRef &vec_output);
-  void SplitGraph(const KernelGraphPtr &graph);
+  void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims);
   // split graphs with recurse from root graph
-  void SplitGraphs(const KernelGraphPtr &root_graph);
+  void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
   void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
-  void IRFusion(const KernelGraphPtr &graph) {}
-  void SelectKernelGraphKernel(const KernelGraph &graph) {}
-  void ConvertPredictModel(const KernelGraphPtr graph) {}
-  void HardwareOptimizeGraphs(const KernelGraphPtr graph) {}
   void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
   std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, const std::vector<CNodePtr> &list);
-  void RecurseCompileGraph(const KernelGraphPtr &graph);
+  void RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
+  void RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
   // merge execution order list of child graphs
   void MergeGraphExecOrder();
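Note: the new GetChildList splits a graph's node list at "cut" primitives, emitting the nodes before the cut and the cut node itself as separate segments. A simplified sketch of the same cut-list idea over plain ints (the real code additionally special-cases switch calls and keys the trailing segment off kPrimReturn):

#include <functional>
#include <vector>

// Splits `nodes` into segments: everything before a cut node, then the cut node by itself,
// then a trailing segment up to the end.
static std::vector<std::vector<int>> SplitAtCuts(const std::vector<int> &nodes,
                                                 const std::function<bool(int)> &is_cut) {
  std::vector<std::vector<int>> segments;
  size_t after_cut_index = 0;
  for (size_t i = 0; i < nodes.size(); ++i) {
    if (is_cut(nodes[i])) {
      segments.emplace_back(nodes.begin() + after_cut_index, nodes.begin() + i);
      segments.emplace_back(1, nodes[i]);
      after_cut_index = i + 1;
    }
  }
  // trailing segment (the real code emits this when it reaches the graph's return node)
  segments.emplace_back(nodes.begin() + after_cut_index, nodes.end());
  return segments;
}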
diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc
index 9577fa5710..c1992b7cc0 100644
--- a/mindspore/ccsrc/session/kernel_graph.cc
+++ b/mindspore/ccsrc/session/kernel_graph.cc
@@ -50,7 +50,7 @@ std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
   std::vector<AnfNodePtr> real_inputs;
   auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(item_with_index.first->cast<CNodePtr>());
   for (const auto &child_graph : child_graphs) {
-    if (AnfAlgo::IsWhileTrueGraph(child_graph)) {
+    if (child_graph->get_output_null()) {
       continue;
     }
     auto real_input = child_graph->output();
@@ -592,7 +592,11 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
     MS_EXCEPTION_IF_NULL(output_node.first);
     auto output_cnode = output_node.first->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(output_cnode);
-    const auto &output_node_inputs = output_cnode->inputs();
+    auto &output_node_inputs = output_cnode->inputs();
+    // don't replace node if it is a control edge => output_node.second == 0
+    if (output_node.second == 0) {
+      continue;
+    }
     for (size_t i = 1; i < output_node_inputs.size(); i++) {
       if (output_node_inputs[i] == old_anf_node) {
         output_cnode->set_input(i, new_anf_node);
@@ -686,10 +690,12 @@ std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) {
 
 void KernelGraph::UpdateCallRealInput() {
   MS_LOG(INFO) << "Update graph id: " << graph_id_;
+  std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_map;
+  std::vector<std::pair<AnfNodePtr, AnfNodePtr>> replace_list;
   for (auto &it : real_inputs_) {
-    auto &parameter = it.first;
+    auto parameter = it.first;
     MS_EXCEPTION_IF_NULL(parameter);
-    auto &real_inputs = it.second;
+    auto real_inputs = it.second;
     std::vector<AnfNodePtr> new_real_inputs;
     std::set<AnfNodePtr> erase_real_inputs;
     for (auto &real_input : real_inputs) {
@@ -711,10 +717,16 @@ void KernelGraph::UpdateCallRealInput() {
                     << " insert real input:" << new_real_input->DebugString();
       (void)real_inputs.insert(new_real_input);
       if (new_real_input->isa<CNode>()) {
-        ReplaceNode(parameter, new_real_input);
+        replace_list.emplace_back(parameter, new_real_input);
+        parameter = new_real_input;
       }
     }
+    real_inputs_map[parameter] = real_inputs;
+  }
+  for (auto [parameter, arg] : replace_list) {
+    ReplaceNode(parameter, arg);
   }
+  real_inputs_ = real_inputs_map;
 }
 
 std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h
index 02881eb162..98a007d1a1 100644
--- a/mindspore/ccsrc/session/kernel_graph.h
+++ b/mindspore/ccsrc/session/kernel_graph.h
@@ -36,7 +36,7 @@ namespace session {
 using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
 class KernelGraph : public FuncGraph {
  public:
-  KernelGraph() : graph_id_(0) {
+  KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), null_output_(false) {
     inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
     execution_order_ = {};
     executable_ = true;
@@ -134,6 +134,8 @@ class KernelGraph : public FuncGraph {
   CNodePtr get_start_label() { return start_label_; }
   void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; }
   CNodePtr get_end_goto() { return end_goto_; }
+  bool get_output_null() { return null_output_; }
+  void set_output_null(bool is_output_null) { null_output_ = is_output_null; }
 
  private:
   // remove value node form graph
@@ -188,6 +190,7 @@ class KernelGraph : public FuncGraph {
 
   CNodePtr start_label_;
   CNodePtr end_goto_;
+  bool null_output_;
 };
 }  // namespace session
 using KernelGraphPtr = std::shared_ptr<KernelGraph>;
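Note: UpdateCallRealInput now defers both the rebuild of real_inputs_ and the ReplaceNode side effects until iteration has finished, instead of mutating the container it is walking. A minimal sketch of that collect-then-apply pattern, with toy int keys standing in for AnfNodePtr:

#include <map>
#include <set>
#include <utility>
#include <vector>

class ToyGraph {
 public:
  // forward_to maps an old key to the node that should replace it.
  void UpdateInputs(const std::map<int, int> &forward_to) {
    std::map<int, std::set<int>> rebuilt;
    std::vector<std::pair<int, int>> replace_list;
    for (const auto &item : inputs_) {
      int key = item.first;
      auto it = forward_to.find(key);
      if (it != forward_to.end()) {
        replace_list.emplace_back(key, it->second);  // queue the side effect
        key = it->second;
      }
      rebuilt[key] = item.second;
    }
    // side effects (ReplaceNode in the real code) run only after the loop
    for (const auto &swap : replace_list) {
      ReplaceNode(swap.first, swap.second);
    }
    inputs_ = rebuilt;
  }

 private:
  void ReplaceNode(int /*old_node*/, int /*new_node*/) {}
  std::map<int, std::set<int>> inputs_;
};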
diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc
index a3f1ccbbb3..b47096a670 100644
--- a/mindspore/ccsrc/session/session_basic.cc
+++ b/mindspore/ccsrc/session/session_basic.cc
@@ -543,15 +543,12 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
 
 std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
-  if (front_backend_graph_map_.find(func_graph) != front_backend_graph_map_.end()) {
-    MS_LOG(INFO) << "FuncGraph: " << func_graph->ToString() << " has been transformed to KernelGraph.";
-    return front_backend_graph_map_[func_graph];
-  }
   auto node_list = TopoSort(func_graph->get_return());
   auto graph = NewKernelGraph();
   front_backend_graph_map_[func_graph] = graph;
   MS_LOG(INFO) << "Create graph: " << graph->graph_id();
+  bool is_trace_back = false;
   for (const auto &node : node_list) {
     MS_EXCEPTION_IF_NULL(node);
     MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
@@ -564,8 +561,14 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
         (void)CreateNewValueNode(node, graph.get());
       } else {
         // if input is a ValueNode
-        auto child_graph = ConstructKernelGraph(AnfAlgo::GetValueNodeFuncGraph(node));
-        auto new_value_node = CreateValueNodeKernelGraph(node, graph.get());
+        FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
+        if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
+          MS_LOG(INFO) << "FuncGraph: " << child_graph->ToString() << " has been transformed to KernelGraph.";
+          is_trace_back = true;
+        } else {
+          (void)ConstructKernelGraph(child_graph);
+        }
+        (void)CreateValueNodeKernelGraph(node, graph.get());
       }
       continue;
     } else {
@@ -582,6 +585,8 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
       }
     }
   }
+  // if a graph jump back unconditionally, return op of this graph will never be executed, so output is null.
+  graph->set_output_null(is_trace_back);
   auto graph_inputs = graph->MutableInputs();
   MS_EXCEPTION_IF_NULL(graph_inputs);
   graph_inputs->clear();
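Note: ConstructKernelGraph now detects a child func graph that is already present in front_backend_graph_map_ and treats that call as a jump back to an ancestor, marking the current graph's output as null instead of recursing again. A minimal sketch of that idea with toy types and hypothetical names, not MindSpore's actual classes:

#include <map>
#include <memory>
#include <vector>

struct FrontGraph;
using FrontGraphPtr = std::shared_ptr<FrontGraph>;
struct FrontGraph {
  std::vector<FrontGraphPtr> called_graphs;
};

struct BackGraph {
  bool output_null = false;
};
using BackGraphPtr = std::shared_ptr<BackGraph>;

static std::map<FrontGraphPtr, BackGraphPtr> front_backend_map;

static BackGraphPtr Construct(const FrontGraphPtr &front) {
  auto back = std::make_shared<BackGraph>();
  front_backend_map[front] = back;  // register before visiting children
  bool is_trace_back = false;
  for (const auto &child : front->called_graphs) {
    if (front_backend_map.find(child) != front_backend_map.end()) {
      is_trace_back = true;  // already converted: this call jumps back to an ancestor
      continue;
    }
    (void)Construct(child);
  }
  back->output_null = is_trace_back;  // its return never runs, so it produces no output
  return back;
}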