| @@ -942,7 +942,6 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN | |||
| } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) { | |||
| auto switch_node = input1->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(switch_node); | |||
| MS_LOG(INFO) << "switch : " << switch_node->DebugString(); | |||
| auto get_switch_kernel_graph = [&](size_t input_index) -> KernelGraphPtr { | |||
| auto partial = switch_node->input(input_index); | |||
| MS_EXCEPTION_IF_NULL(partial); | |||
| @@ -950,7 +949,6 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN | |||
| MS_EXCEPTION_IF_NULL(partial_cnode); | |||
| auto graph_node = partial_cnode->input(1); | |||
| MS_EXCEPTION_IF_NULL(graph_node); | |||
| MS_LOG(INFO) << graph_node->DebugString(); | |||
| auto graph_value_node = graph_node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(graph_value_node); | |||
| auto graph_value = graph_value_node->value(); | |||
| @@ -976,5 +974,17 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) { | |||
| } | |||
| MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString(); | |||
| } | |||
| // Returns true when some call node found in child_graph resolves to exactly one kernel graph and that graph is | |||
| // the parent graph of child_graph; returns false otherwise. | |||
| // Raises via MS_EXCEPTION_IF_NULL when child_graph is null. | |||
| bool AnfRuntimeAlgorithm::IsWhileTrueGraph(const KernelGraphPtr &child_graph) { | |||
| // child_graph is dereferenced below, so reject a null pointer up front | |||
| MS_EXCEPTION_IF_NULL(child_graph); | |||
| auto call_nodes = child_graph->FindNodeByPrimitive(prim::kPrimCall); | |||
| for (const auto &call_node : call_nodes) { | |||
| auto graphs = GetCallNodeKernelGraph(call_node); | |||
| // only a call with a single target graph that is the parent graph counts | |||
| if (graphs.size() == 1 && graphs[0] == child_graph->parent_graph()) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -185,6 +185,7 @@ class AnfRuntimeAlgorithm { | |||
| static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); | |||
| static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node); | |||
| static bool IsSwitchCall(const CNodePtr &call_node); | |||
| static bool IsWhileTrueGraph(const KernelGraphPtr &child_graph); | |||
| }; | |||
| } // namespace session | |||
| using AnfAlgo = session::AnfRuntimeAlgorithm; | |||
| @@ -18,6 +18,7 @@ | |||
| #include <map> | |||
| #include <tuple> | |||
| #include <set> | |||
| #include <list> | |||
| #include "operator/ops.h" | |||
| #include "ir/meta_tensor.h" | |||
| #include "ir/anf.h" | |||
| @@ -160,7 +161,7 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) { | |||
| std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) { | |||
| std::vector<CNodePtr> cnodes = {}; | |||
| size_t i = 0; | |||
| for (auto anf : anf_nodes) { | |||
| for (const auto &anf : anf_nodes) { | |||
| MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString(); | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| if (anf->isa<CNode>()) { | |||
| @@ -192,6 +193,8 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co | |||
| return ret; | |||
| } | |||
| // If a call has kernel inputs, it is a child graph split from ME, so these kernel inputs should be set as the real | |||
| // inputs of the graph. For example, if call input = (prim, graph, kernel1, kernel2), then real_input = [kernel1, kernel2]. | |||
| void UpdateRealInput(KernelGraph *graph) { | |||
| auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); | |||
| auto bind_call_partial_with_parameter = [&](const std::vector<AnfNodePtr> ¶meters, | |||
| @@ -239,6 +242,15 @@ void UpdateRealInput(KernelGraph *graph) { | |||
| } | |||
| } | |||
| } | |||
| // Worker for RecurseToUpdateCallRealInput: updates the call real inputs of graph, then descends into its child | |||
| // graphs. visited records the graphs already handled so that every graph is processed at most once. | |||
| void UpdateCallRealInputOnce(KernelGraph *graph, std::set<const KernelGraph *> *visited) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(visited); | |||
| // a child graph may call back into a graph that is already being processed; skip it so the walk terminates | |||
| if (!visited->insert(graph).second) { | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "start graph id:" << graph->graph_id(); | |||
| graph->UpdateCallRealInput(); | |||
| for (auto &child_graph : graph->child_graph_order()) { | |||
| UpdateCallRealInputOnce(child_graph.get(), visited); | |||
| } | |||
| } | |||
| // Calls UpdateCallRealInput on graph and on every graph reachable through child_graph_order(). | |||
| // Raises via MS_EXCEPTION_IF_NULL when graph, or any child graph reached, is null. | |||
| void RecurseToUpdateCallRealInput(KernelGraph *graph) { | |||
| std::set<const KernelGraph *> visited; | |||
| UpdateCallRealInputOnce(graph, &visited); | |||
| } | |||
| } // namespace | |||
| GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| @@ -254,7 +266,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| MS_LOG(INFO) << "start"; | |||
| auto graph = ConstructKernelGraph(func_graph); | |||
| // split switch | |||
| SplitGraph(graph); | |||
| SplitGraphs(graph); | |||
| // insert goto labels and label_sets | |||
| LinkChildGraphs(NOT_NULL(graph)); | |||
| // resource initialize | |||
| @@ -1366,8 +1378,8 @@ void AscendSession::SyncInitialTenosrToDevice() { | |||
| } | |||
| } | |||
| KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_graph, | |||
| const std::vector<CNodePtr> &list) { | |||
| KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, | |||
| const std::vector<CNodePtr> &list) { | |||
| MS_EXCEPTION_IF_NULL(new_kernel_graph); | |||
| MS_LOG(INFO) << "start split kernel graph:" << new_kernel_graph->graph_id(); | |||
| // count the output of every anf node | |||
| @@ -1376,9 +1388,6 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_ | |||
| for (auto &input : anf_node->inputs()) { | |||
| (void)has_output_nodes.insert(input); | |||
| } | |||
| if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) { | |||
| new_kernel_graph->set_return(anf_node->cast<CNodePtr>()); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id(); | |||
| // create new parameter from cnode | |||
| @@ -1386,6 +1395,7 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_ | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | |||
| auto input = cnode->inputs()[input_idx]; | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| if (!input->isa<CNode>()) { | |||
| cnode->set_input(input_idx, input); | |||
| continue; | |||
| @@ -1417,6 +1427,12 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_ | |||
| return new_kernel_graph; | |||
| } | |||
| // Entry point of graph splitting: splits starting from root_graph, then rewrites the real inputs that are calls. | |||
| void AscendSession::SplitGraphs(const KernelGraphPtr &root_graph) { | |||
| // step 1: split the root graph | |||
| SplitGraph(root_graph); | |||
| // step 2: walk from the root graph and replace every real input that is a call | |||
| KernelGraph *root = root_graph.get(); | |||
| RecurseToUpdateCallRealInput(root); | |||
| } | |||
| void AscendSession::SplitGraph(const KernelGraphPtr &graph) { | |||
| MS_LOG(INFO) << "start,graph_id:" << graph->graph_id(); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| @@ -1426,6 +1442,7 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { | |||
| // get child list from current graph | |||
| std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list); | |||
| auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr { | |||
| // if child graph list only has a call ,then return the exist call | |||
| if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) { | |||
| return child_graph_list[0]; | |||
| } | |||
| @@ -1440,22 +1457,22 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { | |||
| for (auto &child_graph_node : child_graph_list) { | |||
| AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get()); | |||
| } | |||
| SplitKernelGraph(child_graph, child_graph_list); | |||
| ConstructSplitedGraph(child_graph, child_graph_list); | |||
| auto new_call = graph->NewCNode(new_call_input); | |||
| AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call); | |||
| return new_call; | |||
| }; | |||
| if (child_graph_lists.size() > 1) { | |||
| std::list<AnfNodePtr> depend_input = {}; | |||
| for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) { | |||
| auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]); | |||
| if (call_index == 0) { | |||
| auto new_return_primitive = | |||
| graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()))); | |||
| graph->set_return(graph->NewCNode({new_return_primitive, call_node})); | |||
| continue; | |||
| } | |||
| InsertDependToGraph(graph->graph_id(), call_node); | |||
| depend_input.push_front(call_node); | |||
| } | |||
| depend_input.push_front(graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())))); | |||
| auto depend = graph->NewCNode(std::vector<AnfNodePtr>(depend_input.begin(), depend_input.end())); | |||
| auto new_return_primitive = | |||
| graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()))); | |||
| graph->set_return(graph->NewCNode({new_return_primitive, depend})); | |||
| } | |||
| graph->UpdateChildGraphOrder(); | |||
| UpdateRealInput(graph.get()); | |||
| @@ -97,15 +97,16 @@ class AscendSession : public SessionBasic { | |||
| void SetFinalGraphOutput(const VectorRef &vec_output); | |||
| void SplitGraph(const KernelGraphPtr &graph); | |||
| // split graphs recursively, starting from the root graph | |||
| void SplitGraphs(const KernelGraphPtr &root_graph); | |||
| void LinkChildGraphs(NotNull<KernelGraphPtr> graph); | |||
| void IRFusion(const KernelGraphPtr &graph) {} | |||
| void SelectKernelGraphKernel(const KernelGraph &graph) {} | |||
| void ConvertPredictModel(const KernelGraphPtr graph) {} | |||
| void HardwareOptimizeGraphs(const KernelGraphPtr graph) {} | |||
| void RootGraphExecutorValidate(KernelGraph *graph) {} | |||
| void RecurseUpdateAllChildGraohOrder(KernelGraph *root_graph); | |||
| KernelGraphPtr SplitKernelGraph(const KernelGraphPtr &new_kernel_graph, const std::vector<CNodePtr> &list); | |||
| KernelGraphPtr ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, const std::vector<CNodePtr> &list); | |||
| void ChildGraphCommunicationDecrease(std::vector<std::vector<AnfNodePtr>> *anf_node_lists); | |||
| // merge execution order list of child graphs | |||
| @@ -39,16 +39,35 @@ void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, | |||
| MS_LOG(DEBUG) << "Push que:" << node->DebugString(); | |||
| } | |||
| } | |||
| // Collects the nodes that actually produce the value of call_node. | |||
| // call_node is first resolved with AnfAlgo::VisitKernelWithReturnType(call_node, 0). If the resolved node is not a | |||
| // call it is returned as the only element. Otherwise the output of every kernel graph the call targets is resolved | |||
| // recursively and the results are concatenated; graphs for which AnfAlgo::IsWhileTrueGraph holds are skipped. | |||
| std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) { | |||
| const auto resolved = AnfAlgo::VisitKernelWithReturnType(call_node, 0); | |||
| const auto &resolved_node = resolved.first; | |||
| MS_EXCEPTION_IF_NULL(resolved_node); | |||
| const bool is_call = AnfAlgo::CheckPrimitiveType(resolved_node, prim::kPrimCall); | |||
| if (!is_call) { | |||
| return {resolved_node}; | |||
| } | |||
| std::vector<AnfNodePtr> collected; | |||
| for (const auto &target_graph : AnfAlgo::GetCallNodeKernelGraph(resolved_node->cast<CNodePtr>())) { | |||
| if (!AnfAlgo::IsWhileTrueGraph(target_graph)) { | |||
| // resolve the output of the target graph, which may itself be a call | |||
| auto nested = GetCallRealOutputs(target_graph->output()); | |||
| collected.insert(collected.end(), nested.begin(), nested.end()); | |||
| } | |||
| } | |||
| return collected; | |||
| } | |||
| } // namespace | |||
| std::vector<AnfNodePtr> KernelGraph::outputs() const { | |||
| MS_EXCEPTION_IF_NULL(output()); | |||
| if (IsPrimitiveCNode(output(), prim::kPrimMakeTuple)) { | |||
| auto graph_output = output(); | |||
| if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) { | |||
| auto make_tuple = output()->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| auto &inputs = make_tuple->inputs(); | |||
| return std::vector<AnfNodePtr>(inputs.begin() + 1, inputs.end()); | |||
| } | |||
| return std::vector<AnfNodePtr>(); | |||
| return std::vector<AnfNodePtr>(1, graph_output); | |||
| } | |||
| void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, | |||
| @@ -587,6 +606,9 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() { | |||
| void KernelGraph::UpdateChildGraphOrder() { | |||
| MS_LOG(INFO) << "graph id:" << graph_id_; | |||
| auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name())); | |||
| for (auto &old_child_graph : child_graph_order_) { | |||
| old_child_graph->set_parent_graph(nullptr); | |||
| } | |||
| child_graph_order_.clear(); | |||
| for (auto &call_node : call_nodes) { | |||
| MS_EXCEPTION_IF_NULL(call_node); | |||
| @@ -640,6 +662,9 @@ std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr ¶meter) { | |||
| } | |||
| void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg) { | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| MS_LOG(INFO) << "parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString(); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| if (real_inputs_.find(parameter) == real_inputs_.end()) { | |||
| @@ -649,6 +674,41 @@ void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &ar | |||
| (void)args.insert(arg); | |||
| } | |||
| // For every parameter recorded in real_inputs_, replace the real inputs that resolve to a call node with the real | |||
| // outputs of that call (see GetCallRealOutputs). | |||
| // Each parameter is handled in two phases: first scan its real inputs and collect what to erase and what to add, | |||
| // then apply the changes. Applying only after the scan means a call that is the last (or only) real input is still | |||
| // replaced, and the set is never modified while it is being iterated. | |||
| void KernelGraph::UpdateCallRealInput() { | |||
| MS_LOG(INFO) << "Update graph id: " << graph_id_; | |||
| for (auto &it : real_inputs_) { | |||
| auto &parameter = it.first; | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| auto &real_inputs = it.second; | |||
| std::set<AnfNodePtr> new_real_inputs; | |||
| std::set<AnfNodePtr> erase_real_inputs; | |||
| // phase 1: scan only, real_inputs is not modified here | |||
| for (auto &real_input : real_inputs) { | |||
| // if real input is a call node, find the child graph output to act as the new real input | |||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(real_input, 0); | |||
| MS_EXCEPTION_IF_NULL(item_with_index.first); | |||
| if (!AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) { | |||
| continue; | |||
| } | |||
| MS_LOG(INFO) << "paramter: " << parameter->DebugString() | |||
| << " erase real input:" << item_with_index.first->DebugString(); | |||
| // NOTE(review): the node erased is the one VisitKernelWithReturnType resolves to; if that can differ from the | |||
| // stored real_input, the stored entry is left in place - confirm this is intended | |||
| (void)erase_real_inputs.insert(item_with_index.first); | |||
| auto call_node_outputs = GetCallRealOutputs(item_with_index.first); | |||
| for (auto &call_node_output : call_node_outputs) { | |||
| MS_EXCEPTION_IF_NULL(call_node_output); | |||
| MS_LOG(INFO) << "paramter: " << parameter->DebugString() | |||
| << " insert real input:" << call_node_output->DebugString(); | |||
| (void)new_real_inputs.insert(call_node_output); | |||
| } | |||
| } | |||
| // phase 2: apply the collected changes, erase first and insert afterwards | |||
| for (auto &erase_node : erase_real_inputs) { | |||
| (void)real_inputs.erase(erase_node); | |||
| } | |||
| for (auto &new_real_input : new_real_inputs) { | |||
| (void)real_inputs.insert(new_real_input); | |||
| } | |||
| } | |||
| } | |||
| std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -127,6 +127,8 @@ class KernelGraph : public FuncGraph { | |||
| void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg); | |||
| // used to dump ir | |||
| std::string ToString() const override; | |||
| // update the real input if the node is a call | |||
| void UpdateCallRealInput(); | |||
| void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; } | |||
| CNodePtr get_start_label() { return start_label_; } | |||
| @@ -640,16 +640,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP | |||
| MS_EXCEPTION_IF_NULL(func_graph_node); | |||
| auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node); | |||
| ConstructKernelGraph(sub_func_graph); | |||
| } else if (prim->name() == kReturnOpName) { | |||
| std::vector<AnfNodePtr> outputs; | |||
| auto inputs = cnode->inputs(); | |||
| if (inputs.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "CNode[return] must have two inputs at least, actual inputs size is " << inputs.size(); | |||
| } | |||
| (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outputs)); | |||
| // add a make_tuple before return as graph output | |||
| graph->set_output(ConstructOutput(outputs, graph)); | |||
| continue; | |||
| } | |||
| } | |||
| @@ -659,6 +649,9 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP | |||
| new_cnode->set_abstract(cnode->abstract()); | |||
| new_cnode->set_scope(cnode->scope()); | |||
| graph->FrontBackendlMapAdd(node, new_cnode); | |||
| if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) { | |||
| graph->set_return(new_cnode); | |||
| } | |||
| } | |||
| } | |||