From b24943d496ab46df595596c6651a2afa49d03d71 Mon Sep 17 00:00:00 2001 From: wenchunjiang Date: Fri, 17 Jul 2020 15:39:08 +0800 Subject: [PATCH] adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65. --- .../backend/session/anf_runtime_algorithm.cc | 27 ++-- .../backend/session/anf_runtime_algorithm.h | 2 +- .../backend/session/ascend_control_parser.cc | 110 +++++++--------- .../backend/session/ascend_control_parser.h | 6 +- .../ccsrc/backend/session/ascend_session.cc | 4 +- .../ccsrc/backend/session/kernel_graph.cc | 19 ++- .../ccsrc/backend/session/kernel_graph.h | 1 + .../ccsrc/backend/session/session_basic.cc | 119 +++++++++++------- .../ccsrc/backend/session/session_basic.h | 10 +- mindspore/ccsrc/utils/utils.h | 5 +- 10 files changed, 162 insertions(+), 141 deletions(-) mode change 100755 => 100644 mindspore/ccsrc/backend/session/session_basic.h diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 34471537db..8a212be1f0 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -1031,31 +1031,29 @@ FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) return func_graph; } -std::vector AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) { - MS_EXCEPTION_IF_NULL(call_node); - if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared("call"))) { - MS_LOG(EXCEPTION) << "Anf node: " << call_node->DebugString() << "is not a call node."; +std::vector AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch))) { + MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a call or switch node."; } - auto input1 = call_node->input(1); - MS_EXCEPTION_IF_NULL(input1); - if (input1->isa()) { + if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) { + auto input1 = cnode->input(kCallKernelGraphIndex); + MS_EXCEPTION_IF_NULL(input1); auto value_node = input1->cast(); MS_EXCEPTION_IF_NULL(value_node); auto kernel_graph = value_node->value(); MS_EXCEPTION_IF_NULL(kernel_graph); return {kernel_graph->cast()}; - } else if (input1->isa() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) { - auto switch_node = input1->cast(); - MS_EXCEPTION_IF_NULL(switch_node); - auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr { - auto partial = switch_node->input(input_index); + } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { + auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr { + auto partial = cnode->input(input_index); MS_EXCEPTION_IF_NULL(partial); if (IsValueNode(partial)) { return GetValueNode(partial); } auto partial_cnode = partial->cast(); MS_EXCEPTION_IF_NULL(partial_cnode); - auto graph_node = partial_cnode->input(1); + auto graph_node = partial_cnode->input(kCallKernelGraphIndex); MS_EXCEPTION_IF_NULL(graph_node); auto graph_value_node = graph_node->cast(); MS_EXCEPTION_IF_NULL(graph_value_node); @@ -1064,7 +1062,8 @@ std::vector AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN auto child_graph = graph_value->cast(); return child_graph; }; - return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)}; + return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex), + get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)}; } return {}; } diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index c08819e2dc..89a5ecbce3 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -201,7 +201,7 @@ class AnfRuntimeAlgorithm { static bool IsCommunicationOp(const AnfNodePtr &node); static bool IsGetNext(const NotNull &node); static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); - static std::vector GetCallNodeKernelGraph(const CNodePtr &call_node); + static std::vector GetCallSwitchKernelGraph(const CNodePtr &cnode); static bool IsSwitchCall(const CNodePtr &call_node); static bool IsScalarInput(const CNodePtr &cnode, size_t index); static bool IsScalarOutput(const CNodePtr &cnode, size_t index); diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc index 81516d9481..e1d0a02a80 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc @@ -361,27 +361,22 @@ void AscendControlParser::ExecutorValidate(NotNull root_graph) { } } -std::vector>> AscendControlParser::ParseCallNode( - NotNull call_node) { +std::vector>> AscendControlParser::ParseCallSwitchNode( + NotNull cnode) { std::vector>> ret; - if (!IsPrimitiveCNode(call_node.get(), prim::kPrimCall)) { - MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " is not a call node."; - } - if (call_node->size() <= kCNodeCallArg) { - MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " has invalid inputs size " << call_node->size(); - } - const std::vector &call_node_inputs = call_node->inputs(); - auto call_arg = call_node_inputs[kCNodeCallArg]; - MS_EXCEPTION_IF_NULL(call_arg); - if (IsValueNode(call_arg)) { + + if (IsPrimitiveCNode(cnode.get(), prim::kPrimCall)) { + if (cnode->size() <= kCNodeCallArg) { + MS_LOG(EXCEPTION) << "Call node " << cnode->DebugString() << " has invalid inputs size " << cnode->size(); + } + auto call_arg = cnode->input(kCNodeCallArg); + MS_EXCEPTION_IF_NULL(call_arg); ret.emplace_back(GetValueNode(call_arg), - std::vector(call_node_inputs.begin() + kCNodeCallArg + 1, call_node_inputs.end())); - } else if (IsPrimitiveCNode(call_arg, prim::kPrimSwitch)) { - auto switch_cnode = call_arg->cast(); - MS_EXCEPTION_IF_NULL(switch_cnode); - const std::vector &switch_inputs = switch_cnode->inputs(); - if (switch_inputs.size() <= kCNodeSwitchCond) { - MS_LOG(EXCEPTION) << "Node " << switch_cnode->DebugString() << " has invalid inputs size " + std::vector(cnode->inputs().begin() + kCNodeCallArg + 1, cnode->inputs().end())); + } else if (IsPrimitiveCNode(cnode.get(), prim::kPrimSwitch)) { + const std::vector &switch_inputs = cnode->inputs(); + if (switch_inputs.size() < kCNodeSwitchLength) { + MS_LOG(EXCEPTION) << "Switch node " << cnode->DebugString() << " has invalid inputs size " << switch_inputs.size(); } for (auto iter = switch_inputs.begin() + kCNodeSwitchCond + 1; iter != switch_inputs.end(); ++iter) { @@ -389,7 +384,7 @@ std::vector>> AscendControlPar ret.emplace_back(target_graph, args); } } else { - MS_LOG(EXCEPTION) << "Unsupport call node: " << call_node->DebugString(5); + MS_LOG(EXCEPTION) << "Unsupport call node: " << cnode->DebugString(5); } return ret; } @@ -406,11 +401,11 @@ void AscendControlParser::ChildGraphDataAssign( const std::vector &nodes = kg->execution_order(); for (auto &node : nodes) { - if (!IsPrimitiveCNode(node, prim::kPrimCall)) { + if (!(IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch))) { continue; } - auto child_graph_list = ParseCallNode(NOT_NULL(node)); + auto child_graph_list = ParseCallSwitchNode(NOT_NULL(node)); for (auto &[child_graph, args] : child_graph_list) { MS_EXCEPTION_IF_NULL(child_graph); const std::vector ¶ms = child_graph->inputs(); @@ -425,7 +420,6 @@ void AscendControlParser::ChildGraphDataAssign( link_list->emplace_back(args[i], params[i]); continue; } - InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i])); } } @@ -475,30 +469,20 @@ NotNull AscendControlParser::ProcessKernelGraph(NotNullsize() < kCNodePrim + 1) { - MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; - } - AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex); - if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) { - MS_LOG(DEBUG) << "Continue node " << cnode->DebugString(); + if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || + AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) || + AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer))) { continue; } - AnfNodePtr arg = cnode->input(kFirstDataInputIndex); - MS_EXCEPTION_IF_NULL(arg); - if (IsValueNode(arg)) { + + if (IsPrimitiveCNode(cnode, prim::kPrimCall)) { RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); - } else if (!arg->isa()) { - MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString(); - } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitch)) { - auto arg_cnode = arg->cast(); - MS_EXCEPTION_IF_NULL(arg_cnode); - cnode->set_inputs(arg_cnode->inputs()); + } else if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) { RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); - } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitchLayer)) { - auto arg_cnode = arg->cast(); - MS_EXCEPTION_IF_NULL(arg_cnode); - cnode->set_inputs(arg_cnode->inputs()); + } else if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) { RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); + } else { + MS_LOG(EXCEPTION) << "Unexpected node: " << cnode->DebugString(); } } kg->SetExecOrderByDefault(); @@ -699,31 +683,22 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull fr continue; } const auto &from_graph_exe_order = from_graph->execution_order(); - std::vector real_exe_order(from_graph_exe_order.size()); - size_t real_exe_order_size = 0; - std::copy_if(from_graph_exe_order.begin(), from_graph_exe_order.end(), real_exe_order.begin(), - [&real_exe_order_size](const CNodePtr &node) -> bool { - return (IsPrimitiveCNode(node, prim::kPrimSwitch) || IsPrimitiveCNode(node, prim::kPrimPartial)) - ? false - : (++real_exe_order_size, true); - }); - real_exe_order.resize(real_exe_order_size); if (jump_node == nullptr) { - if (!real_exe_order.empty()) { - InsertControlDependToGraph(from_graph, NOT_NULL(*(real_exe_order.rbegin())), NOT_NULL(assign_node)); + if (!from_graph_exe_order.empty()) { + InsertControlDependToGraph(from_graph, NOT_NULL(*(from_graph_exe_order.rbegin())), NOT_NULL(assign_node)); } else { InsertDependToGraph(from_graph, NOT_NULL(assign_node)); } continue; } - auto jump_node_iter = std::find(real_exe_order.begin(), real_exe_order.end(), jump_node); - if (jump_node_iter == real_exe_order.end()) { + auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node); + if (jump_node_iter == from_graph_exe_order.end()) { MS_LOG(EXCEPTION) << "Cannot find jump node " << jump_node->DebugString() << " in graph " << from_graph->ToString(); } // insert assign between jump_node -1 and jump_node - if (jump_node_iter != real_exe_order.begin()) { + if (jump_node_iter != from_graph_exe_order.begin()) { InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); } InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); @@ -772,6 +747,7 @@ std::vector AscendControlParser::RecurseGraph(NotNull std::vector execution_order; uint32_t child_order_index = 0; for (auto &node : cnodes) { + uint32_t child_graph_index = 0; execution_order.push_back(node); if (node == graph->get_end_goto()) { continue; @@ -779,7 +755,7 @@ std::vector AscendControlParser::RecurseGraph(NotNull if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { std::vector label_switch_list = AnfAlgo::GetNodeAttr>(node, kAttrLabelSwitchList); for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { - if (!CheckLabelIndex(child_order_index, *iter, node, graph)) { + if (!CheckLabelIndex(child_graph_index++, *iter, node)) { MS_LOG(EXCEPTION) << "Check label index fail"; } if (child_order_index >= graph->child_graph_order().size()) { @@ -791,9 +767,12 @@ std::vector AscendControlParser::RecurseGraph(NotNull } } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { uint32_t label_index = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); - if (!CheckLabelIndex(child_order_index, label_index, node, graph)) { + if (!CheckLabelIndex(child_graph_index, label_index, node)) { MS_LOG(EXCEPTION) << "Check label index fail"; } + if (child_order_index >= graph->child_graph_order().size()) { + MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size(); + } auto child_graph = graph->child_graph_order()[child_order_index++]; auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); @@ -804,15 +783,14 @@ std::vector AscendControlParser::RecurseGraph(NotNull return execution_order; } -bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, - NotNull graph) { - const std::vector> &child_graph_order = graph->child_graph_order(); +bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label) { + auto child_graphs = AnfAlgo::GetNodeAttr>(cur_label, kAttrChildGraph); // check index and child order size - if (child_graph_order.size() <= IntToSize(order_index)) { - MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " - << child_graph_order.size() << " goto index " << order_index; + if (child_graphs.size() <= IntToSize(index)) { + MS_LOG(EXCEPTION) << "Child graph index is wrong, current node " << cur_label->ToString() << " child graph size " + << child_graphs.size() << " goto index " << index; } - auto child_graph = child_graph_order[order_index]; + auto child_graph = child_graphs[index]; MS_EXCEPTION_IF_NULL(child_graph); // get start_label_set_index of child graph @@ -822,7 +800,7 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i MS_EXCEPTION_IF_NULL(cur_label); MS_EXCEPTION_IF_NULL(start_label_set); MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() - << " index " << start_label_set_index << " current child graph order : " << order_index; + << " index " << start_label_set_index; return false; } else { return true; diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.h b/mindspore/ccsrc/backend/session/ascend_control_parser.h index 3c370fa500..ae64bd8d3a 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.h +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.h @@ -64,13 +64,13 @@ class AscendControlParser { const CNodePtr &last_label); static AnfNodePtr InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); - static std::vector>> ParseCallNode(NotNull call_node); + static std::vector>> ParseCallSwitchNode( + NotNull call_node); static std::tuple> ParsePartial(NotNull node); static void AttachChildGraphToReturnNode(NotNull graph, const NotNull *> memo); // root graph order - static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, - NotNull graph); + static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode); static std::vector RecurseGraph(NotNull graph, const NotNull *> memo); }; diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 59fc846759..84fc142516 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -885,7 +885,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNu std::map need_replace_list; auto node_list = GetCNodes(TopoSort(graph->get_return())); for (auto &node : node_list) { - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) { // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); MS_EXCEPTION_IF_NULL(graph->MutableInputs()); @@ -898,7 +898,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNu MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << node->DebugString() << ", depend node is " << depend->DebugString(); // insert assign in order to transfer child graph output to parameter - auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node); + auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node); for (auto &child_graph : child_graphs) { MS_EXCEPTION_IF_NULL(child_graph); // If graph has no output, the graph is the true graph of while and will call condition graph, no need insert diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 6b49b4b878..5f7c37fa18 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -67,7 +67,7 @@ std::vector GetCallRealOutputs(const AnfNodePtr &call_node) { return {node}; } std::vector real_inputs; - auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node->cast()); + auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node->cast()); for (const auto &child_graph : child_graphs) { if (child_graph->get_output_null()) { continue; @@ -931,6 +931,18 @@ std::vector KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi return result; } +std::vector KernelGraph::FindNodeByPrimitive(const std::vector &primitive_list) const { + std::vector result; + for (const auto &anf : execution_order_) { + for (const auto &primitive : primitive_list) { + if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) { + result.push_back(anf->cast()); + } + } + } + return result; +} + void KernelGraph::PrintGraphExecuteOrder() const { MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order"; for (size_t i = 0; i < execution_order_.size(); i++) { @@ -1078,11 +1090,12 @@ bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, int outpu void KernelGraph::UpdateChildGraphOrder() { MS_LOG(INFO) << "Update " << ToString() << " child graph order."; SetExecOrderByDefault(); - auto call_nodes = FindNodeByPrimitive(std::make_shared(prim::kPrimCall->name())); + auto call_nodes = FindNodeByPrimitive( + {std::make_shared(prim::kPrimCall->name()), std::make_shared(prim::kPrimSwitch->name())}); std::vector child_graph_order; for (auto &call_node : call_nodes) { MS_EXCEPTION_IF_NULL(call_node); - auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast()); + auto call_child_graphs = AnfAlgo::GetCallSwitchKernelGraph(call_node->cast()); for (const auto &child_graph : call_child_graphs) { MS_EXCEPTION_IF_NULL(child_graph); if (child_graph != parent_graph_) { diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 047c21ea20..6f22aff3e2 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -131,6 +131,7 @@ class KernelGraph : public FuncGraph { void set_parent_graph(const std::shared_ptr &parent_graph) { parent_graph_ = parent_graph; } // find anf node in graph std::vector FindNodeByPrimitive(const PrimitivePtr &primitive) const; + std::vector FindNodeByPrimitive(const std::vector &primitive_list) const; // used to dump ir std::string ToString() const override; diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index b8e2e40df3..e32bb4baef 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -547,45 +547,26 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra MS_EXCEPTION_IF_NULL(node_input); MS_EXCEPTION_IF_NULL(graph); // switch input generalizes partial - if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial) || - AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimCall)) { - return node_input->cast(); - } - if (node_input->isa()) { - MS_LOG(EXCEPTION) << "If switch input is " << node_input->DebugString() << ", it mast be partial or call."; - } std::vector partial_inputs = {NewValueNode(std::make_shared(prim::kPrimPartial->name()))}; - if (node_input->isa() && IsValueNode(node_input)) { - partial_inputs.emplace_back(node_input); - auto partial_node = graph->NewCNode(partial_inputs); - return partial_node; + if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) { + auto partial_node = graph->GetBackendAnfByFrontAnf(node_input); + return partial_node->cast(); + } else if (node_input->isa() && IsValueNode(node_input)) { + partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input)); + } else { + KernelGraphPtr kernel_graph = NewKernelGraph(); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), true, kernel_graph.get()); + auto primitive = NewValueNode(std::make_shared(prim::kPrimReturn->name())); + auto return_node = kernel_graph->NewCNode({primitive, parameter}); + kernel_graph->set_return(return_node); + partial_inputs.emplace_back(std::make_shared(kernel_graph)); + partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input)); } - KernelGraphPtr kernel_graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(kernel_graph); - kernel_graph->set_output(graph->GetBackendAnfByFrontAnf(node_input)); - partial_inputs.emplace_back(std::make_shared(kernel_graph)); auto partial_node = graph->NewCNode(partial_inputs); return partial_node; } -CNodePtr SessionBasic::HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(graph); - auto node = anf_node->cast(); - MS_EXCEPTION_IF_NULL(node); - if (node->inputs().size() < kSwitchInputSize) { - MS_LOG(EXCEPTION) << "Switch input size less than " << kSwitchInputSize; - } - auto primitive = NewValueNode(std::make_shared(prim::kPrimSwitch->name())); - std::vector switch_inputs = {primitive, node->input(1)}; - for (size_t index = 2; index < node->inputs().size(); index++) { - auto input = CreateSwitchInput(node->input(index), graph); - switch_inputs.emplace_back(input); - } - auto switch_node = graph->NewCNode(switch_inputs); - return switch_node; -} - std::vector SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(graph); @@ -611,14 +592,33 @@ std::vector SessionBasic::CreateSwitchOrPartialNode(const CNodePtr & }); return cnode_inputs; } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { - auto switch_node = HandleSwitchInputs(cnode_input, graph); + auto switch_cnode = cnode_input->cast(); + MS_EXCEPTION_IF_NULL(switch_cnode); + std::vector switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex), + switch_cnode->input(kFirstDataInputIndex)}; + for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) { + auto node = switch_cnode->input(index); + // there is real input in call, should put it to true and false branch in switch + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { + auto partial_node = node->cast(); + MS_EXCEPTION_IF_NULL(partial_node); + std::vector partial_inputs = partial_node->inputs(); + partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); + auto new_partial = graph->NewCNode(partial_inputs); + switch_inputs.emplace_back(new_partial); + } + } + if (switch_inputs.size() < kSwitchInputSize) { + MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize; + } + auto switch_node = graph->NewCNode(switch_inputs); cnode_inputs.emplace_back(switch_node); return cnode_inputs; } MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; } -CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { +CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(graph); std::vector cnode_inputs; @@ -642,7 +642,22 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) } } } else if (attr_input->isa()) { - cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); + auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); + if (cnode->inputs().size() < 2 && AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { + auto switch_cnode = cnode_input->cast(); + MS_EXCEPTION_IF_NULL(switch_cnode); + cnode_inputs = switch_cnode->inputs(); + } else { + cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); + } + } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { + cnode_inputs = {graph->GetBackendAnfByFrontAnf(cnode->input(kAnfPrimitiveIndex)), + graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))}; + for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) { + auto node_input = cnode->input(index); + auto switch_input = CreateSwitchInput(node_input, graph); + cnode_inputs.emplace_back(switch_input); + } } else { // get primitive of old node auto prim = AnfAlgo::GetCNodePrimitive(cnode); @@ -651,21 +666,33 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared(*prim)))}; } - for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { - auto anf = cnode->input(input_idx); - MS_EXCEPTION_IF_NULL(anf); - // anf has been created before - if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { - cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); - continue; - } else if (IsValueNode(anf)) { - continue; + if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { + for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->inputs().size(); input_idx++) { + auto anf = cnode->input(input_idx); + MS_EXCEPTION_IF_NULL(anf); + // anf has been created before + if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { + cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); + continue; + } else if (IsValueNode(anf)) { + continue; + } + MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; } - MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; } TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); auto new_cnode = graph->NewCNode(cnode_inputs); TraceManager::EndTrace(); + + // if the cnode is call switch, remove call + if (new_cnode->inputs().size() > 1) { + auto first_input = new_cnode->input(kFirstDataInputIndex); + if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) && + AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) { + new_cnode = first_input->cast(); + } + } + return new_cnode; } diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h old mode 100755 new mode 100644 index f20c27473e..d3e94473a4 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -86,11 +86,7 @@ class SessionBasic { CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, std::unordered_map *other_graph_cnode); - CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph); - - CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph); - CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph); - std::vector CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); + CNodePtr CreateNewCNode(CNodePtr cnode, KernelGraph *graph); // get graph id in child graphs by ME front anf node pointer virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } @@ -112,6 +108,10 @@ class SessionBasic { } #endif + private: + CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph); + std::vector CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); + protected: virtual void SetSummaryNodes(KernelGraph *graph); // Get graph by graph id ,if not exist return null ptr diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index c8c9fc2290..ba47f607cf 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -277,11 +277,14 @@ const int kValueNodeTensorMask = 2; // define special index in special node constexpr auto kAnfPrimitiveIndex = 0; constexpr auto kFirstDataInputIndex = 1; -constexpr auto kAnfPartialFuncGraphIndex = 1; constexpr auto kRealInputNodeIndexInTupleGetItem = 1; constexpr auto kInputNodeOutputIndexInTupleGetItem = 2; constexpr auto kTupleGetItemInputSize = 3; constexpr auto kSwitchInputSize = 4; +constexpr auto kFirstBranchInSwitch = 2; +constexpr auto kCallKernelGraphIndex = 1; +constexpr auto kSwitchTrueKernelGraphIndex = 2; +constexpr auto kSwitchFalseKernelGraphIndex = 3; // index define of control depend constexpr auto kControlDependPriorIndex = 1; constexpr auto kControlDependBehindIndex = 2;