Signed-off-by: zhoufeng <zhoufeng54@huawei.com>tags/v0.5.0-beta
| @@ -26,7 +26,6 @@ static constexpr uint32_t kLabelSwitchLabelId = 2; | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| static void UpdateLabelGoto(NotNull<CNodePtr> node) { | |||
| if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { | |||
| return; | |||
| @@ -164,7 +163,6 @@ uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> gr | |||
| uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { | |||
| return GetLabelNum(NOT_NULL(graph.get().get())); | |||
| } | |||
| } // namespace ascend | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -25,7 +25,6 @@ | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| class AscendLabelAssign { | |||
| public: | |||
| static AscendLabelAssign &GetInstance() { | |||
| @@ -974,17 +974,5 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) { | |||
| } | |||
| MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString(); | |||
| } | |||
| bool AnfRuntimeAlgorithm::IsWhileTrueGraph(const KernelGraphPtr &child_graph) { | |||
| auto call_nodes = child_graph->FindNodeByPrimitive(prim::kPrimCall); | |||
| for (const auto &call_node : call_nodes) { | |||
| auto graphs = GetCallNodeKernelGraph(call_node); | |||
| if (graphs.size() == 1 && graphs[0] == child_graph->parent_graph()) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -185,7 +185,6 @@ class AnfRuntimeAlgorithm { | |||
| static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); | |||
| static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node); | |||
| static bool IsSwitchCall(const CNodePtr &call_node); | |||
| static bool IsWhileTrueGraph(const KernelGraphPtr &child_graph); | |||
| }; | |||
| } // namespace session | |||
| using AnfAlgo = session::AnfRuntimeAlgorithm; | |||
| @@ -83,7 +83,7 @@ CNodePtr AscendControlParser::GetNextRealKernel(const std::vector<CNodePtr> &lis | |||
| NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, | |||
| const CNodePtr &last_label, | |||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString(); | |||
| // 1. recursive condition | |||
| @@ -180,7 +180,7 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod | |||
| } | |||
| void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| MS_LOG(INFO) << "process call func " << cur_node->DebugString(); | |||
| // 1 get kernel graph | |||
| @@ -212,7 +212,7 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP | |||
| } | |||
| void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | |||
| const CNodePtr &next_node, NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| const CNodePtr &next_node, const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); | |||
| if (cur_node->size() < kCNodeSwitchLength) { | |||
| @@ -249,7 +249,8 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod | |||
| } | |||
| void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | |||
| const CNodePtr &next_node, NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| const CNodePtr &next_node, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); | |||
| if (cur_node->size() < kCNodeSwitchLayerLength) { | |||
| @@ -353,7 +354,7 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) { | |||
| } | |||
| std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph, | |||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| MS_LOG(INFO) << "graph:" << graph->graph_id() << " start"; | |||
| auto print_vector = [&](std::vector<CNodePtr> vec) -> void { | |||
| MS_LOG(INFO) << "graph:" << graph->graph_id() << "execution order"; | |||
| @@ -40,13 +40,14 @@ class AscendControlParser { | |||
| private: | |||
| static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, | |||
| const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo); | |||
| const CNodePtr &last_label, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static void RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static void RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static void RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | |||
| const CNodePtr &last_label); | |||
| @@ -63,7 +64,8 @@ class AscendControlParser { | |||
| static std::vector<uint32_t> GetLabelSwitchList(const CNodePtr &node); | |||
| static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, | |||
| NotNull<KernelGraphPtr> graph); | |||
| static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| }; | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -171,19 +171,35 @@ std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) { | |||
| return cnodes; | |||
| } | |||
| std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) { | |||
| size_t after_call_index = 0; | |||
| static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePtr> &cnodes, | |||
| const std::set<PrimitivePtr> &cut_prims) { | |||
| size_t after_cut_index = 0; | |||
| std::vector<std::vector<CNodePtr>> ret; | |||
| for (size_t i = 0; i < cnodes.size(); i++) { | |||
| if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall) && !AnfAlgo::IsSwitchCall(cnodes[i])) { | |||
| auto call_kernel_graph = AnfAlgo::GetCallNodeKernelGraph(cnodes[i]); | |||
| auto prev_call_list = std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.begin() + i); | |||
| auto call_list = std::vector<CNodePtr>(1, cnodes[i]); | |||
| after_call_index = i + 1; | |||
| ret.push_back(prev_call_list); | |||
| ret.push_back(call_list); | |||
| } else if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) { | |||
| ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end())); | |||
| for (size_t i = 0; i < cnodes.size(); ++i) { | |||
| bool is_cut_node = false; | |||
| for (auto &prim : cut_prims) { | |||
| if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim)) { | |||
| is_cut_node = true; | |||
| break; | |||
| } | |||
| } | |||
| if (is_cut_node) { | |||
| // is call and not switch call,cut to 3 lists | |||
| if (!AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall)) { | |||
| // if is not a call,cut to 2 lists | |||
| ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i); | |||
| after_cut_index = i; | |||
| } else if (!AnfAlgo::IsSwitchCall(cnodes[i])) { | |||
| ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i); | |||
| ret.emplace_back(1, cnodes[i]); | |||
| after_cut_index = i + 1; | |||
| continue; | |||
| } | |||
| } | |||
| // get last child graph list | |||
| if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) { | |||
| ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.end()); | |||
| continue; | |||
| } | |||
| } | |||
| return ret; | |||
| @@ -191,7 +207,7 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co | |||
| // if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of | |||
| // graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] | |||
| static void UpdateRealInput(KernelGraph *graph) { | |||
| static void UpdateRealInput(NotNull<KernelGraphPtr> graph) { | |||
| auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); | |||
| auto bind_call_arg_with_parameter = [&](const std::vector<AnfNodePtr> ¶meters, | |||
| const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void { | |||
| @@ -253,16 +269,17 @@ static void UpdateRealInput(KernelGraph *graph) { | |||
| } | |||
| } | |||
| void RecurseToUpdateCallRealInput(KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| memo->insert(graph.get()); | |||
| MS_LOG(INFO) << "start graph id:" << graph->graph_id(); | |||
| for (auto &child_graph : graph->child_graph_order()) { | |||
| if (child_graph == graph->parent_graph()) { | |||
| if (memo->find(child_graph) != memo->end()) { | |||
| MS_LOG(INFO) << "Child graph:" << child_graph->graph_id() | |||
| << ",parent graph:" << graph->parent_graph()->graph_id(); | |||
| continue; | |||
| } | |||
| RecurseToUpdateCallRealInput(child_graph.get()); | |||
| RecurseToUpdateCallRealInput(NOT_NULL(child_graph), memo); | |||
| } | |||
| // this action should from bottom to top | |||
| graph->UpdateCallRealInput(); | |||
| @@ -282,7 +299,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| MS_LOG(INFO) << "start"; | |||
| auto graph = ConstructKernelGraph(func_graph); | |||
| // split switch | |||
| SplitGraphs(graph); | |||
| SplitGraphs(NOT_NULL(graph)); | |||
| // insert goto labels and label_sets | |||
| LinkChildGraphs(NOT_NULL(graph)); | |||
| // resource initialize | |||
| @@ -290,7 +307,8 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| // assign label | |||
| AssignLabel(NOT_NULL(graph)); | |||
| // recurse compile child graph | |||
| RecurseCompileGraph(graph); | |||
| std::set<KernelGraphPtr> memo; | |||
| RecurseCompileGraph(NOT_NULL(graph), NOT_NULL(&memo)); | |||
| // root graph valiate,include genearte execute order and so on | |||
| RootGraphExecutorValidate(NOT_NULL(graph)); | |||
| // adjust kernel | |||
| @@ -1423,24 +1441,43 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt | |||
| } | |||
| MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id(); | |||
| std::vector<AnfNodePtr> call_node_inputs; | |||
| auto graph_inputs = new_kernel_graph->MutableInputs(); | |||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||
| std::vector<AnfNodePtr> new_graph_inputs; | |||
| // create new parameter from cnode | |||
| for (auto &anf_node : list) { | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | |||
| auto input = cnode->inputs()[input_idx]; | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| if (input->isa<Parameter>()) { | |||
| graph_inputs->push_back(input); | |||
| AnfNodePtr new_parameter = nullptr; | |||
| // value node consider move to new graph | |||
| if (input->isa<ValueNode>()) { | |||
| cnode->set_input(input_idx, input); | |||
| continue; | |||
| } else if (input->isa<Parameter>()) { | |||
| // parameter reuse and should attention mulptiple use of one parameter | |||
| cnode->set_input(input_idx, input); | |||
| new_parameter = input; | |||
| } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) { | |||
| auto new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get()); | |||
| // if is cnode and not in current child graph | |||
| new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get()); | |||
| cnode->set_input(input_idx, new_parameter); | |||
| } else { | |||
| // if is a cnode and in current graph | |||
| continue; | |||
| } | |||
| // if mulptiple use of one parameter or cnode, only set one parameter in graph inputs and one arg in call node | |||
| // args | |||
| if (std::find(call_node_inputs.begin(), call_node_inputs.end(), new_parameter) == call_node_inputs.end()) { | |||
| new_graph_inputs.push_back(new_parameter); | |||
| call_node_inputs.push_back(input); | |||
| } | |||
| call_node_inputs.push_back(input); | |||
| } | |||
| } | |||
| // set graph inputs of new graph | |||
| auto graph_inputs = new_kernel_graph->MutableInputs(); | |||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||
| graph_inputs->clear(); | |||
| std::copy(new_graph_inputs.begin(), new_graph_inputs.end(), std::back_inserter(*graph_inputs)); | |||
| MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id(); | |||
| auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())); | |||
| std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve}; | |||
| @@ -1461,20 +1498,30 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt | |||
| return call_node_inputs; | |||
| } | |||
| void AscendSession::SplitGraphs(const KernelGraphPtr &root_graph) { | |||
| SplitGraph(root_graph); | |||
| void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) { | |||
| std::set<KernelGraphPtr> memo; | |||
| // if root graph output is a call node ,the root graph is condition graph of 'if' sentence | |||
| auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first; | |||
| if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) { | |||
| SplitGraph(root_graph, {prim::kPrimReturn}); | |||
| for (auto &child_graph : root_graph->child_graph_order()) { | |||
| RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo)); | |||
| } | |||
| } else { | |||
| RecurseSplitGraph(root_graph, NOT_NULL(&memo)); | |||
| } | |||
| memo.clear(); | |||
| // replace the real input if the real input is a call | |||
| RecurseToUpdateCallRealInput(root_graph.get()); | |||
| RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo)); | |||
| } | |||
| void AscendSession::SplitGraph(const KernelGraphPtr &graph) { | |||
| void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) { | |||
| MS_LOG(INFO) << "start,graph_id:" << graph->graph_id(); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto apply_list = GetCNodes(TopoSort(graph->get_return())); | |||
| // update the root graph child graph order | |||
| AscendControlParser::UpdateChildGraphOrder(NOT_NULL(graph)); | |||
| AscendControlParser::UpdateChildGraphOrder(graph); | |||
| // get child list from current graph | |||
| std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list); | |||
| std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(apply_list, cut_prims); | |||
| auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr { | |||
| // if child graph list only has a call ,then return the exist call | |||
| if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) { | |||
| @@ -1521,20 +1568,22 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { | |||
| pre_call_node = cur_call_node; | |||
| cur_call_node = *iter; | |||
| if (pre_call_node != nullptr && cur_call_node != nullptr) { | |||
| AscendControlParser::InsertControlDependToGraph(NOT_NULL(graph), NOT_NULL(cur_call_node), | |||
| NOT_NULL(pre_call_node)); | |||
| AscendControlParser::InsertControlDependToGraph(graph, NOT_NULL(cur_call_node), NOT_NULL(pre_call_node)); | |||
| } | |||
| } | |||
| } | |||
| AscendControlParser::UpdateChildGraphOrder(NOT_NULL(graph)); | |||
| UpdateRealInput(graph.get()); | |||
| auto graph_name = std::string("./kernel-graph-").append(std::to_string(graph->graph_id())); | |||
| DumpIR(graph_name, graph); | |||
| AscendControlParser::UpdateChildGraphOrder(graph); | |||
| UpdateRealInput(graph); | |||
| MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end"; | |||
| // recurse to split child graph | |||
| } | |||
| void AscendSession::RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| memo->insert(graph.get()); | |||
| SplitGraph(graph, {prim::kPrimCall}); | |||
| for (auto &child_graph : graph->child_graph_order()) { | |||
| if (child_graph != graph->parent_graph()) { | |||
| SplitGraph(child_graph); | |||
| if (memo->find(child_graph) == memo->end()) { | |||
| RecurseSplitGraph(NOT_NULL(child_graph), memo); | |||
| } | |||
| } | |||
| } | |||
| @@ -1545,13 +1594,14 @@ void AscendSession::RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph) { | |||
| AscendControlParser::ExecutorValidate(graph); | |||
| } | |||
| void AscendSession::RecurseCompileGraph(const KernelGraphPtr &graph) { | |||
| void AscendSession::RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| memo->insert(graph.get()); | |||
| CompileChildGraph(graph); | |||
| for (auto child_graph : graph->child_graph_order()) { | |||
| if (child_graph == graph->parent_graph()) { | |||
| if (memo->find(child_graph) != memo->end()) { | |||
| continue; | |||
| } | |||
| RecurseCompileGraph(child_graph); | |||
| RecurseCompileGraph(NOT_NULL(child_graph), memo); | |||
| } | |||
| } | |||
| @@ -98,18 +98,15 @@ class AscendSession : public SessionBasic { | |||
| void SetFinalGraphOutput(const ValuePtr &value); | |||
| void SetFinalGraphOutput(const VectorRef &vec_output); | |||
| void SplitGraph(const KernelGraphPtr &graph); | |||
| void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims); | |||
| // split graphs with recurse from root graph | |||
| void SplitGraphs(const KernelGraphPtr &root_graph); | |||
| void SplitGraphs(NotNull<KernelGraphPtr> root_graph); | |||
| void LinkChildGraphs(NotNull<KernelGraphPtr> graph); | |||
| void IRFusion(const KernelGraphPtr &graph) {} | |||
| void SelectKernelGraphKernel(const KernelGraph &graph) {} | |||
| void ConvertPredictModel(const KernelGraphPtr graph) {} | |||
| void HardwareOptimizeGraphs(const KernelGraphPtr graph) {} | |||
| void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph); | |||
| std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, | |||
| const std::vector<CNodePtr> &list); | |||
| void RecurseCompileGraph(const KernelGraphPtr &graph); | |||
| void RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| void RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| // merge execution order list of child graphs | |||
| void MergeGraphExecOrder(); | |||
| @@ -50,7 +50,7 @@ std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) { | |||
| std::vector<AnfNodePtr> real_inputs; | |||
| auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(item_with_index.first->cast<CNodePtr>()); | |||
| for (const auto &child_graph : child_graphs) { | |||
| if (AnfAlgo::IsWhileTrueGraph(child_graph)) { | |||
| if (child_graph->get_output_null()) { | |||
| continue; | |||
| } | |||
| auto real_input = child_graph->output(); | |||
| @@ -592,7 +592,11 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf | |||
| MS_EXCEPTION_IF_NULL(output_node.first); | |||
| auto output_cnode = output_node.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(output_cnode); | |||
| const auto &output_node_inputs = output_cnode->inputs(); | |||
| auto &output_node_inputs = output_cnode->inputs(); | |||
| // don't replace node if it is a control edge => output_node.second == 0 | |||
| if (output_node.second == 0) { | |||
| continue; | |||
| } | |||
| for (size_t i = 1; i < output_node_inputs.size(); i++) { | |||
| if (output_node_inputs[i] == old_anf_node) { | |||
| output_cnode->set_input(i, new_anf_node); | |||
| @@ -686,10 +690,12 @@ std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr ¶meter) { | |||
| void KernelGraph::UpdateCallRealInput() { | |||
| MS_LOG(INFO) << "Update graph id: " << graph_id_; | |||
| std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_map; | |||
| std::vector<std::pair<AnfNodePtr, AnfNodePtr>> replace_list; | |||
| for (auto &it : real_inputs_) { | |||
| auto ¶meter = it.first; | |||
| auto parameter = it.first; | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| auto &real_inputs = it.second; | |||
| auto real_inputs = it.second; | |||
| std::vector<AnfNodePtr> new_real_inputs; | |||
| std::set<AnfNodePtr> erase_real_inputs; | |||
| for (auto &real_input : real_inputs) { | |||
| @@ -711,10 +717,16 @@ void KernelGraph::UpdateCallRealInput() { | |||
| << " insert real input:" << new_real_input->DebugString(); | |||
| (void)real_inputs.insert(new_real_input); | |||
| if (new_real_input->isa<Parameter>()) { | |||
| ReplaceNode(parameter, new_real_input); | |||
| replace_list.emplace_back(parameter, new_real_input); | |||
| parameter = new_real_input; | |||
| } | |||
| } | |||
| real_inputs_map[parameter] = real_inputs; | |||
| } | |||
| for (auto [parameter, arg] : replace_list) { | |||
| ReplaceNode(parameter, arg); | |||
| } | |||
| real_inputs_ = real_inputs_map; | |||
| } | |||
| std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } | |||
| @@ -36,7 +36,7 @@ namespace session { | |||
| using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>; | |||
| class KernelGraph : public FuncGraph { | |||
| public: | |||
| KernelGraph() : graph_id_(0) { | |||
| KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), null_output_(false) { | |||
| inputs_ = std::make_shared<std::vector<AnfNodePtr>>(); | |||
| execution_order_ = {}; | |||
| executable_ = true; | |||
| @@ -134,6 +134,8 @@ class KernelGraph : public FuncGraph { | |||
| CNodePtr get_start_label() { return start_label_; } | |||
| void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; } | |||
| CNodePtr get_end_goto() { return end_goto_; } | |||
| bool get_output_null() { return null_output_; } | |||
| void set_output_null(bool is_output_null) { null_output_ = is_output_null; } | |||
| private: | |||
| // remove value node form graph | |||
| @@ -188,6 +190,7 @@ class KernelGraph : public FuncGraph { | |||
| CNodePtr start_label_; | |||
| CNodePtr end_goto_; | |||
| bool null_output_; | |||
| }; | |||
| } // namespace session | |||
| using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | |||
| @@ -543,15 +543,12 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con | |||
| std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| if (front_backend_graph_map_.find(func_graph) != front_backend_graph_map_.end()) { | |||
| MS_LOG(INFO) << "FuncGraph: " << func_graph->ToString() << " has been transformed to KernelGraph."; | |||
| return front_backend_graph_map_[func_graph]; | |||
| } | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| auto graph = NewKernelGraph(); | |||
| front_backend_graph_map_[func_graph] = graph; | |||
| MS_LOG(INFO) << "Create graph: " << graph->graph_id(); | |||
| bool is_trace_back = false; | |||
| for (const auto &node : node_list) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); | |||
| @@ -564,8 +561,14 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP | |||
| (void)CreateNewValueNode(node, graph.get()); | |||
| } else { | |||
| // if input is a ValueNode<FuncGraph> | |||
| auto child_graph = ConstructKernelGraph(AnfAlgo::GetValueNodeFuncGraph(node)); | |||
| auto new_value_node = CreateValueNodeKernelGraph(node, graph.get()); | |||
| FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node); | |||
| if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) { | |||
| MS_LOG(INFO) << "FuncGraph: " << child_graph->ToString() << " has been transformed to KernelGraph."; | |||
| is_trace_back = true; | |||
| } else { | |||
| (void)ConstructKernelGraph(child_graph); | |||
| } | |||
| (void)CreateValueNodeKernelGraph(node, graph.get()); | |||
| } | |||
| continue; | |||
| } else { | |||
| @@ -582,6 +585,8 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP | |||
| } | |||
| } | |||
| } | |||
| // if a graph jump back unconditionally, return op of this graph will never be executed, so output is null. | |||
| graph->set_output_null(is_trace_back); | |||
| auto graph_inputs = graph->MutableInputs(); | |||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||
| graph_inputs->clear(); | |||