Signed-off-by: zhoufeng <zhoufeng54@huawei.com>tags/v0.5.0-beta
| @@ -26,7 +26,6 @@ static constexpr uint32_t kLabelSwitchLabelId = 2; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace ascend { | namespace ascend { | ||||
| static void UpdateLabelGoto(NotNull<CNodePtr> node) { | static void UpdateLabelGoto(NotNull<CNodePtr> node) { | ||||
| if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { | if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { | ||||
| return; | return; | ||||
| @@ -164,7 +163,6 @@ uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> gr | |||||
| uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { | uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { | ||||
| return GetLabelNum(NOT_NULL(graph.get().get())); | return GetLabelNum(NOT_NULL(graph.get().get())); | ||||
| } | } | ||||
| } // namespace ascend | } // namespace ascend | ||||
| } // namespace device | } // namespace device | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -25,7 +25,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace ascend { | namespace ascend { | ||||
| class AscendLabelAssign { | class AscendLabelAssign { | ||||
| public: | public: | ||||
| static AscendLabelAssign &GetInstance() { | static AscendLabelAssign &GetInstance() { | ||||
| @@ -974,17 +974,5 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) { | |||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString(); | MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString(); | ||||
| } | } | ||||
| bool AnfRuntimeAlgorithm::IsWhileTrueGraph(const KernelGraphPtr &child_graph) { | |||||
| auto call_nodes = child_graph->FindNodeByPrimitive(prim::kPrimCall); | |||||
| for (const auto &call_node : call_nodes) { | |||||
| auto graphs = GetCallNodeKernelGraph(call_node); | |||||
| if (graphs.size() == 1 && graphs[0] == child_graph->parent_graph()) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -185,7 +185,6 @@ class AnfRuntimeAlgorithm { | |||||
| static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); | static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); | ||||
| static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node); | static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node); | ||||
| static bool IsSwitchCall(const CNodePtr &call_node); | static bool IsSwitchCall(const CNodePtr &call_node); | ||||
| static bool IsWhileTrueGraph(const KernelGraphPtr &child_graph); | |||||
| }; | }; | ||||
| } // namespace session | } // namespace session | ||||
| using AnfAlgo = session::AnfRuntimeAlgorithm; | using AnfAlgo = session::AnfRuntimeAlgorithm; | ||||
| @@ -83,7 +83,7 @@ CNodePtr AscendControlParser::GetNextRealKernel(const std::vector<CNodePtr> &lis | |||||
| NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, | NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, | ||||
| const CNodePtr &last_label, | const CNodePtr &last_label, | ||||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString(); | MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString(); | ||||
| // 1. recursive condition | // 1. recursive condition | ||||
| @@ -180,7 +180,7 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod | |||||
| } | } | ||||
| void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | ||||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| MS_LOG(INFO) << "process call func " << cur_node->DebugString(); | MS_LOG(INFO) << "process call func " << cur_node->DebugString(); | ||||
| // 1 get kernel graph | // 1 get kernel graph | ||||
| @@ -212,7 +212,7 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP | |||||
| } | } | ||||
| void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | ||||
| const CNodePtr &next_node, NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| const CNodePtr &next_node, const NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); | MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); | ||||
| if (cur_node->size() < kCNodeSwitchLength) { | if (cur_node->size() < kCNodeSwitchLength) { | ||||
| @@ -249,7 +249,8 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod | |||||
| } | } | ||||
| void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | ||||
| const CNodePtr &next_node, NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| const CNodePtr &next_node, | |||||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); | MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); | ||||
| if (cur_node->size() < kCNodeSwitchLayerLength) { | if (cur_node->size() < kCNodeSwitchLayerLength) { | ||||
| @@ -353,7 +354,7 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) { | |||||
| } | } | ||||
| std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph, | std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph, | ||||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| MS_LOG(INFO) << "graph:" << graph->graph_id() << " start"; | MS_LOG(INFO) << "graph:" << graph->graph_id() << " start"; | ||||
| auto print_vector = [&](std::vector<CNodePtr> vec) -> void { | auto print_vector = [&](std::vector<CNodePtr> vec) -> void { | ||||
| MS_LOG(INFO) << "graph:" << graph->graph_id() << "execution order"; | MS_LOG(INFO) << "graph:" << graph->graph_id() << "execution order"; | ||||
| @@ -40,13 +40,14 @@ class AscendControlParser { | |||||
| private: | private: | ||||
| static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, | static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, | ||||
| const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| const CNodePtr &last_label, | |||||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| static void RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | static void RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | ||||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| static void RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | static void RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | ||||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| static void RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | static void RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | ||||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | ||||
| const CNodePtr &last_label); | const CNodePtr &last_label); | ||||
| @@ -63,7 +64,8 @@ class AscendControlParser { | |||||
| static std::vector<uint32_t> GetLabelSwitchList(const CNodePtr &node); | static std::vector<uint32_t> GetLabelSwitchList(const CNodePtr &node); | ||||
| static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, | static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, | ||||
| NotNull<KernelGraphPtr> graph); | NotNull<KernelGraphPtr> graph); | ||||
| static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, | |||||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| }; | }; | ||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -171,19 +171,35 @@ std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) { | |||||
| return cnodes; | return cnodes; | ||||
| } | } | ||||
| std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) { | |||||
| size_t after_call_index = 0; | |||||
| static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePtr> &cnodes, | |||||
| const std::set<PrimitivePtr> &cut_prims) { | |||||
| size_t after_cut_index = 0; | |||||
| std::vector<std::vector<CNodePtr>> ret; | std::vector<std::vector<CNodePtr>> ret; | ||||
| for (size_t i = 0; i < cnodes.size(); i++) { | |||||
| if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall) && !AnfAlgo::IsSwitchCall(cnodes[i])) { | |||||
| auto call_kernel_graph = AnfAlgo::GetCallNodeKernelGraph(cnodes[i]); | |||||
| auto prev_call_list = std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.begin() + i); | |||||
| auto call_list = std::vector<CNodePtr>(1, cnodes[i]); | |||||
| after_call_index = i + 1; | |||||
| ret.push_back(prev_call_list); | |||||
| ret.push_back(call_list); | |||||
| } else if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) { | |||||
| ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end())); | |||||
| for (size_t i = 0; i < cnodes.size(); ++i) { | |||||
| bool is_cut_node = false; | |||||
| for (auto &prim : cut_prims) { | |||||
| if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim)) { | |||||
| is_cut_node = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (is_cut_node) { | |||||
| // is call and not switch call,cut to 3 lists | |||||
| if (!AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall)) { | |||||
| // if is not a call,cut to 2 lists | |||||
| ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i); | |||||
| after_cut_index = i; | |||||
| } else if (!AnfAlgo::IsSwitchCall(cnodes[i])) { | |||||
| ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i); | |||||
| ret.emplace_back(1, cnodes[i]); | |||||
| after_cut_index = i + 1; | |||||
| continue; | |||||
| } | |||||
| } | |||||
| // get last child graph list | |||||
| if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) { | |||||
| ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.end()); | |||||
| continue; | |||||
| } | } | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| @@ -191,7 +207,7 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co | |||||
| // if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of | // if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of | ||||
| // graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] | // graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] | ||||
| static void UpdateRealInput(KernelGraph *graph) { | |||||
| static void UpdateRealInput(NotNull<KernelGraphPtr> graph) { | |||||
| auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); | auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); | ||||
| auto bind_call_arg_with_parameter = [&](const std::vector<AnfNodePtr> ¶meters, | auto bind_call_arg_with_parameter = [&](const std::vector<AnfNodePtr> ¶meters, | ||||
| const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void { | const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void { | ||||
| @@ -253,16 +269,17 @@ static void UpdateRealInput(KernelGraph *graph) { | |||||
| } | } | ||||
| } | } | ||||
| void RecurseToUpdateCallRealInput(KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph, | |||||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| memo->insert(graph.get()); | |||||
| MS_LOG(INFO) << "start graph id:" << graph->graph_id(); | MS_LOG(INFO) << "start graph id:" << graph->graph_id(); | ||||
| for (auto &child_graph : graph->child_graph_order()) { | for (auto &child_graph : graph->child_graph_order()) { | ||||
| if (child_graph == graph->parent_graph()) { | |||||
| if (memo->find(child_graph) != memo->end()) { | |||||
| MS_LOG(INFO) << "Child graph:" << child_graph->graph_id() | MS_LOG(INFO) << "Child graph:" << child_graph->graph_id() | ||||
| << ",parent graph:" << graph->parent_graph()->graph_id(); | << ",parent graph:" << graph->parent_graph()->graph_id(); | ||||
| continue; | continue; | ||||
| } | } | ||||
| RecurseToUpdateCallRealInput(child_graph.get()); | |||||
| RecurseToUpdateCallRealInput(NOT_NULL(child_graph), memo); | |||||
| } | } | ||||
| // this action should from bottom to top | // this action should from bottom to top | ||||
| graph->UpdateCallRealInput(); | graph->UpdateCallRealInput(); | ||||
| @@ -282,7 +299,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||||
| MS_LOG(INFO) << "start"; | MS_LOG(INFO) << "start"; | ||||
| auto graph = ConstructKernelGraph(func_graph); | auto graph = ConstructKernelGraph(func_graph); | ||||
| // split switch | // split switch | ||||
| SplitGraphs(graph); | |||||
| SplitGraphs(NOT_NULL(graph)); | |||||
| // insert goto labels and label_sets | // insert goto labels and label_sets | ||||
| LinkChildGraphs(NOT_NULL(graph)); | LinkChildGraphs(NOT_NULL(graph)); | ||||
| // resource initialize | // resource initialize | ||||
| @@ -290,7 +307,8 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||||
| // assign label | // assign label | ||||
| AssignLabel(NOT_NULL(graph)); | AssignLabel(NOT_NULL(graph)); | ||||
| // recurse compile child graph | // recurse compile child graph | ||||
| RecurseCompileGraph(graph); | |||||
| std::set<KernelGraphPtr> memo; | |||||
| RecurseCompileGraph(NOT_NULL(graph), NOT_NULL(&memo)); | |||||
| // root graph valiate,include genearte execute order and so on | // root graph valiate,include genearte execute order and so on | ||||
| RootGraphExecutorValidate(NOT_NULL(graph)); | RootGraphExecutorValidate(NOT_NULL(graph)); | ||||
| // adjust kernel | // adjust kernel | ||||
| @@ -1423,24 +1441,43 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt | |||||
| } | } | ||||
| MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id(); | MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id(); | ||||
| std::vector<AnfNodePtr> call_node_inputs; | std::vector<AnfNodePtr> call_node_inputs; | ||||
| auto graph_inputs = new_kernel_graph->MutableInputs(); | |||||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||||
| std::vector<AnfNodePtr> new_graph_inputs; | |||||
| // create new parameter from cnode | // create new parameter from cnode | ||||
| for (auto &anf_node : list) { | for (auto &anf_node : list) { | ||||
| auto cnode = anf_node->cast<CNodePtr>(); | auto cnode = anf_node->cast<CNodePtr>(); | ||||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | ||||
| auto input = cnode->inputs()[input_idx]; | auto input = cnode->inputs()[input_idx]; | ||||
| MS_EXCEPTION_IF_NULL(input); | MS_EXCEPTION_IF_NULL(input); | ||||
| if (input->isa<Parameter>()) { | |||||
| graph_inputs->push_back(input); | |||||
| AnfNodePtr new_parameter = nullptr; | |||||
| // value node consider move to new graph | |||||
| if (input->isa<ValueNode>()) { | |||||
| cnode->set_input(input_idx, input); | |||||
| continue; | |||||
| } else if (input->isa<Parameter>()) { | |||||
| // parameter reuse and should attention mulptiple use of one parameter | |||||
| cnode->set_input(input_idx, input); | cnode->set_input(input_idx, input); | ||||
| new_parameter = input; | |||||
| } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) { | } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) { | ||||
| auto new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get()); | |||||
| // if is cnode and not in current child graph | |||||
| new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get()); | |||||
| cnode->set_input(input_idx, new_parameter); | cnode->set_input(input_idx, new_parameter); | ||||
| } else { | |||||
| // if is a cnode and in current graph | |||||
| continue; | |||||
| } | |||||
| // if mulptiple use of one parameter or cnode, only set one parameter in graph inputs and one arg in call node | |||||
| // args | |||||
| if (std::find(call_node_inputs.begin(), call_node_inputs.end(), new_parameter) == call_node_inputs.end()) { | |||||
| new_graph_inputs.push_back(new_parameter); | |||||
| call_node_inputs.push_back(input); | |||||
| } | } | ||||
| call_node_inputs.push_back(input); | |||||
| } | } | ||||
| } | } | ||||
| // set graph inputs of new graph | |||||
| auto graph_inputs = new_kernel_graph->MutableInputs(); | |||||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||||
| graph_inputs->clear(); | |||||
| std::copy(new_graph_inputs.begin(), new_graph_inputs.end(), std::back_inserter(*graph_inputs)); | |||||
| MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id(); | MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id(); | ||||
| auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())); | auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())); | ||||
| std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve}; | std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve}; | ||||
| @@ -1461,20 +1498,30 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt | |||||
| return call_node_inputs; | return call_node_inputs; | ||||
| } | } | ||||
| void AscendSession::SplitGraphs(const KernelGraphPtr &root_graph) { | |||||
| SplitGraph(root_graph); | |||||
| void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) { | |||||
| std::set<KernelGraphPtr> memo; | |||||
| // if root graph output is a call node ,the root graph is condition graph of 'if' sentence | |||||
| auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first; | |||||
| if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) { | |||||
| SplitGraph(root_graph, {prim::kPrimReturn}); | |||||
| for (auto &child_graph : root_graph->child_graph_order()) { | |||||
| RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo)); | |||||
| } | |||||
| } else { | |||||
| RecurseSplitGraph(root_graph, NOT_NULL(&memo)); | |||||
| } | |||||
| memo.clear(); | |||||
| // replace the real input if the real input is a call | // replace the real input if the real input is a call | ||||
| RecurseToUpdateCallRealInput(root_graph.get()); | |||||
| RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo)); | |||||
| } | } | ||||
| void AscendSession::SplitGraph(const KernelGraphPtr &graph) { | |||||
| void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) { | |||||
| MS_LOG(INFO) << "start,graph_id:" << graph->graph_id(); | MS_LOG(INFO) << "start,graph_id:" << graph->graph_id(); | ||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto apply_list = GetCNodes(TopoSort(graph->get_return())); | auto apply_list = GetCNodes(TopoSort(graph->get_return())); | ||||
| // update the root graph child graph order | // update the root graph child graph order | ||||
| AscendControlParser::UpdateChildGraphOrder(NOT_NULL(graph)); | |||||
| AscendControlParser::UpdateChildGraphOrder(graph); | |||||
| // get child list from current graph | // get child list from current graph | ||||
| std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list); | |||||
| std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(apply_list, cut_prims); | |||||
| auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr { | auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr { | ||||
| // if child graph list only has a call ,then return the exist call | // if child graph list only has a call ,then return the exist call | ||||
| if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) { | if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) { | ||||
| @@ -1521,20 +1568,22 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { | |||||
| pre_call_node = cur_call_node; | pre_call_node = cur_call_node; | ||||
| cur_call_node = *iter; | cur_call_node = *iter; | ||||
| if (pre_call_node != nullptr && cur_call_node != nullptr) { | if (pre_call_node != nullptr && cur_call_node != nullptr) { | ||||
| AscendControlParser::InsertControlDependToGraph(NOT_NULL(graph), NOT_NULL(cur_call_node), | |||||
| NOT_NULL(pre_call_node)); | |||||
| AscendControlParser::InsertControlDependToGraph(graph, NOT_NULL(cur_call_node), NOT_NULL(pre_call_node)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| AscendControlParser::UpdateChildGraphOrder(NOT_NULL(graph)); | |||||
| UpdateRealInput(graph.get()); | |||||
| auto graph_name = std::string("./kernel-graph-").append(std::to_string(graph->graph_id())); | |||||
| DumpIR(graph_name, graph); | |||||
| AscendControlParser::UpdateChildGraphOrder(graph); | |||||
| UpdateRealInput(graph); | |||||
| MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end"; | MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end"; | ||||
| // recurse to split child graph | // recurse to split child graph | ||||
| } | |||||
| void AscendSession::RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| memo->insert(graph.get()); | |||||
| SplitGraph(graph, {prim::kPrimCall}); | |||||
| for (auto &child_graph : graph->child_graph_order()) { | for (auto &child_graph : graph->child_graph_order()) { | ||||
| if (child_graph != graph->parent_graph()) { | |||||
| SplitGraph(child_graph); | |||||
| if (memo->find(child_graph) == memo->end()) { | |||||
| RecurseSplitGraph(NOT_NULL(child_graph), memo); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1545,13 +1594,14 @@ void AscendSession::RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph) { | |||||
| AscendControlParser::ExecutorValidate(graph); | AscendControlParser::ExecutorValidate(graph); | ||||
| } | } | ||||
| void AscendSession::RecurseCompileGraph(const KernelGraphPtr &graph) { | |||||
| void AscendSession::RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| memo->insert(graph.get()); | |||||
| CompileChildGraph(graph); | CompileChildGraph(graph); | ||||
| for (auto child_graph : graph->child_graph_order()) { | for (auto child_graph : graph->child_graph_order()) { | ||||
| if (child_graph == graph->parent_graph()) { | |||||
| if (memo->find(child_graph) != memo->end()) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| RecurseCompileGraph(child_graph); | |||||
| RecurseCompileGraph(NOT_NULL(child_graph), memo); | |||||
| } | } | ||||
| } | } | ||||
| @@ -98,18 +98,15 @@ class AscendSession : public SessionBasic { | |||||
| void SetFinalGraphOutput(const ValuePtr &value); | void SetFinalGraphOutput(const ValuePtr &value); | ||||
| void SetFinalGraphOutput(const VectorRef &vec_output); | void SetFinalGraphOutput(const VectorRef &vec_output); | ||||
| void SplitGraph(const KernelGraphPtr &graph); | |||||
| void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims); | |||||
| // split graphs with recurse from root graph | // split graphs with recurse from root graph | ||||
| void SplitGraphs(const KernelGraphPtr &root_graph); | |||||
| void SplitGraphs(NotNull<KernelGraphPtr> root_graph); | |||||
| void LinkChildGraphs(NotNull<KernelGraphPtr> graph); | void LinkChildGraphs(NotNull<KernelGraphPtr> graph); | ||||
| void IRFusion(const KernelGraphPtr &graph) {} | |||||
| void SelectKernelGraphKernel(const KernelGraph &graph) {} | |||||
| void ConvertPredictModel(const KernelGraphPtr graph) {} | |||||
| void HardwareOptimizeGraphs(const KernelGraphPtr graph) {} | |||||
| void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph); | void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph); | ||||
| std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, | std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, | ||||
| const std::vector<CNodePtr> &list); | const std::vector<CNodePtr> &list); | ||||
| void RecurseCompileGraph(const KernelGraphPtr &graph); | |||||
| void RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| void RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| // merge execution order list of child graphs | // merge execution order list of child graphs | ||||
| void MergeGraphExecOrder(); | void MergeGraphExecOrder(); | ||||
| @@ -50,7 +50,7 @@ std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) { | |||||
| std::vector<AnfNodePtr> real_inputs; | std::vector<AnfNodePtr> real_inputs; | ||||
| auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(item_with_index.first->cast<CNodePtr>()); | auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(item_with_index.first->cast<CNodePtr>()); | ||||
| for (const auto &child_graph : child_graphs) { | for (const auto &child_graph : child_graphs) { | ||||
| if (AnfAlgo::IsWhileTrueGraph(child_graph)) { | |||||
| if (child_graph->get_output_null()) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto real_input = child_graph->output(); | auto real_input = child_graph->output(); | ||||
| @@ -592,7 +592,11 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf | |||||
| MS_EXCEPTION_IF_NULL(output_node.first); | MS_EXCEPTION_IF_NULL(output_node.first); | ||||
| auto output_cnode = output_node.first->cast<CNodePtr>(); | auto output_cnode = output_node.first->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(output_cnode); | MS_EXCEPTION_IF_NULL(output_cnode); | ||||
| const auto &output_node_inputs = output_cnode->inputs(); | |||||
| auto &output_node_inputs = output_cnode->inputs(); | |||||
| // don't replace node if it is a control edge => output_node.second == 0 | |||||
| if (output_node.second == 0) { | |||||
| continue; | |||||
| } | |||||
| for (size_t i = 1; i < output_node_inputs.size(); i++) { | for (size_t i = 1; i < output_node_inputs.size(); i++) { | ||||
| if (output_node_inputs[i] == old_anf_node) { | if (output_node_inputs[i] == old_anf_node) { | ||||
| output_cnode->set_input(i, new_anf_node); | output_cnode->set_input(i, new_anf_node); | ||||
| @@ -686,10 +690,12 @@ std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr ¶meter) { | |||||
| void KernelGraph::UpdateCallRealInput() { | void KernelGraph::UpdateCallRealInput() { | ||||
| MS_LOG(INFO) << "Update graph id: " << graph_id_; | MS_LOG(INFO) << "Update graph id: " << graph_id_; | ||||
| std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_map; | |||||
| std::vector<std::pair<AnfNodePtr, AnfNodePtr>> replace_list; | |||||
| for (auto &it : real_inputs_) { | for (auto &it : real_inputs_) { | ||||
| auto ¶meter = it.first; | |||||
| auto parameter = it.first; | |||||
| MS_EXCEPTION_IF_NULL(parameter); | MS_EXCEPTION_IF_NULL(parameter); | ||||
| auto &real_inputs = it.second; | |||||
| auto real_inputs = it.second; | |||||
| std::vector<AnfNodePtr> new_real_inputs; | std::vector<AnfNodePtr> new_real_inputs; | ||||
| std::set<AnfNodePtr> erase_real_inputs; | std::set<AnfNodePtr> erase_real_inputs; | ||||
| for (auto &real_input : real_inputs) { | for (auto &real_input : real_inputs) { | ||||
| @@ -711,10 +717,16 @@ void KernelGraph::UpdateCallRealInput() { | |||||
| << " insert real input:" << new_real_input->DebugString(); | << " insert real input:" << new_real_input->DebugString(); | ||||
| (void)real_inputs.insert(new_real_input); | (void)real_inputs.insert(new_real_input); | ||||
| if (new_real_input->isa<Parameter>()) { | if (new_real_input->isa<Parameter>()) { | ||||
| ReplaceNode(parameter, new_real_input); | |||||
| replace_list.emplace_back(parameter, new_real_input); | |||||
| parameter = new_real_input; | |||||
| } | } | ||||
| } | } | ||||
| real_inputs_map[parameter] = real_inputs; | |||||
| } | |||||
| for (auto [parameter, arg] : replace_list) { | |||||
| ReplaceNode(parameter, arg); | |||||
| } | } | ||||
| real_inputs_ = real_inputs_map; | |||||
| } | } | ||||
| std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } | std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } | ||||
| @@ -36,7 +36,7 @@ namespace session { | |||||
| using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>; | using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>; | ||||
| class KernelGraph : public FuncGraph { | class KernelGraph : public FuncGraph { | ||||
| public: | public: | ||||
| KernelGraph() : graph_id_(0) { | |||||
| KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), null_output_(false) { | |||||
| inputs_ = std::make_shared<std::vector<AnfNodePtr>>(); | inputs_ = std::make_shared<std::vector<AnfNodePtr>>(); | ||||
| execution_order_ = {}; | execution_order_ = {}; | ||||
| executable_ = true; | executable_ = true; | ||||
| @@ -134,6 +134,8 @@ class KernelGraph : public FuncGraph { | |||||
| CNodePtr get_start_label() { return start_label_; } | CNodePtr get_start_label() { return start_label_; } | ||||
| void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; } | void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; } | ||||
| CNodePtr get_end_goto() { return end_goto_; } | CNodePtr get_end_goto() { return end_goto_; } | ||||
| bool get_output_null() { return null_output_; } | |||||
| void set_output_null(bool is_output_null) { null_output_ = is_output_null; } | |||||
| private: | private: | ||||
| // remove value node form graph | // remove value node form graph | ||||
| @@ -188,6 +190,7 @@ class KernelGraph : public FuncGraph { | |||||
| CNodePtr start_label_; | CNodePtr start_label_; | ||||
| CNodePtr end_goto_; | CNodePtr end_goto_; | ||||
| bool null_output_; | |||||
| }; | }; | ||||
| } // namespace session | } // namespace session | ||||
| using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | ||||
| @@ -543,15 +543,12 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con | |||||
| std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) { | std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| if (front_backend_graph_map_.find(func_graph) != front_backend_graph_map_.end()) { | |||||
| MS_LOG(INFO) << "FuncGraph: " << func_graph->ToString() << " has been transformed to KernelGraph."; | |||||
| return front_backend_graph_map_[func_graph]; | |||||
| } | |||||
| auto node_list = TopoSort(func_graph->get_return()); | auto node_list = TopoSort(func_graph->get_return()); | ||||
| auto graph = NewKernelGraph(); | auto graph = NewKernelGraph(); | ||||
| front_backend_graph_map_[func_graph] = graph; | front_backend_graph_map_[func_graph] = graph; | ||||
| MS_LOG(INFO) << "Create graph: " << graph->graph_id(); | MS_LOG(INFO) << "Create graph: " << graph->graph_id(); | ||||
| bool is_trace_back = false; | |||||
| for (const auto &node : node_list) { | for (const auto &node : node_list) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); | MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); | ||||
| @@ -564,8 +561,14 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP | |||||
| (void)CreateNewValueNode(node, graph.get()); | (void)CreateNewValueNode(node, graph.get()); | ||||
| } else { | } else { | ||||
| // if input is a ValueNode<FuncGraph> | // if input is a ValueNode<FuncGraph> | ||||
| auto child_graph = ConstructKernelGraph(AnfAlgo::GetValueNodeFuncGraph(node)); | |||||
| auto new_value_node = CreateValueNodeKernelGraph(node, graph.get()); | |||||
| FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node); | |||||
| if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) { | |||||
| MS_LOG(INFO) << "FuncGraph: " << child_graph->ToString() << " has been transformed to KernelGraph."; | |||||
| is_trace_back = true; | |||||
| } else { | |||||
| (void)ConstructKernelGraph(child_graph); | |||||
| } | |||||
| (void)CreateValueNodeKernelGraph(node, graph.get()); | |||||
| } | } | ||||
| continue; | continue; | ||||
| } else { | } else { | ||||
| @@ -582,6 +585,8 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| // if a graph jump back unconditionally, return op of this graph will never be executed, so output is null. | |||||
| graph->set_output_null(is_trace_back); | |||||
| auto graph_inputs = graph->MutableInputs(); | auto graph_inputs = graph->MutableInputs(); | ||||
| MS_EXCEPTION_IF_NULL(graph_inputs); | MS_EXCEPTION_IF_NULL(graph_inputs); | ||||
| graph_inputs->clear(); | graph_inputs->clear(); | ||||