diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 34471537db..8a212be1f0 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -1031,31 +1031,29 @@ FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) return func_graph; } -std::vector AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) { - MS_EXCEPTION_IF_NULL(call_node); - if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared("call"))) { - MS_LOG(EXCEPTION) << "Anf node: " << call_node->DebugString() << "is not a call node."; +std::vector AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch))) { + MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a call or switch node."; } - auto input1 = call_node->input(1); - MS_EXCEPTION_IF_NULL(input1); - if (input1->isa()) { + if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) { + auto input1 = cnode->input(kCallKernelGraphIndex); + MS_EXCEPTION_IF_NULL(input1); auto value_node = input1->cast(); MS_EXCEPTION_IF_NULL(value_node); auto kernel_graph = value_node->value(); MS_EXCEPTION_IF_NULL(kernel_graph); return {kernel_graph->cast()}; - } else if (input1->isa() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) { - auto switch_node = input1->cast(); - MS_EXCEPTION_IF_NULL(switch_node); - auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr { - auto partial = switch_node->input(input_index); + } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { + auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr { + auto partial = cnode->input(input_index); MS_EXCEPTION_IF_NULL(partial); if (IsValueNode(partial)) { return GetValueNode(partial); } auto partial_cnode = partial->cast(); MS_EXCEPTION_IF_NULL(partial_cnode); - auto graph_node = partial_cnode->input(1); + auto graph_node = partial_cnode->input(kCallKernelGraphIndex); MS_EXCEPTION_IF_NULL(graph_node); auto graph_value_node = graph_node->cast(); MS_EXCEPTION_IF_NULL(graph_value_node); @@ -1064,7 +1062,8 @@ std::vector AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN auto child_graph = graph_value->cast(); return child_graph; }; - return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)}; + return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex), + get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)}; } return {}; } diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index c08819e2dc..89a5ecbce3 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -201,7 +201,7 @@ class AnfRuntimeAlgorithm { static bool IsCommunicationOp(const AnfNodePtr &node); static bool IsGetNext(const NotNull &node); static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); - static std::vector GetCallNodeKernelGraph(const CNodePtr &call_node); + static std::vector GetCallSwitchKernelGraph(const CNodePtr &cnode); static bool IsSwitchCall(const CNodePtr &call_node); static bool IsScalarInput(const CNodePtr &cnode, size_t index); static bool IsScalarOutput(const CNodePtr &cnode, size_t index); diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc index 81516d9481..e1d0a02a80 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc @@ -361,27 +361,22 @@ void AscendControlParser::ExecutorValidate(NotNull root_graph) { } } -std::vector>> AscendControlParser::ParseCallNode( - NotNull call_node) { +std::vector>> AscendControlParser::ParseCallSwitchNode( + NotNull cnode) { std::vector>> ret; - if (!IsPrimitiveCNode(call_node.get(), prim::kPrimCall)) { - MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " is not a call node."; - } - if (call_node->size() <= kCNodeCallArg) { - MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " has invalid inputs size " << call_node->size(); - } - const std::vector &call_node_inputs = call_node->inputs(); - auto call_arg = call_node_inputs[kCNodeCallArg]; - MS_EXCEPTION_IF_NULL(call_arg); - if (IsValueNode(call_arg)) { + + if (IsPrimitiveCNode(cnode.get(), prim::kPrimCall)) { + if (cnode->size() <= kCNodeCallArg) { + MS_LOG(EXCEPTION) << "Call node " << cnode->DebugString() << " has invalid inputs size " << cnode->size(); + } + auto call_arg = cnode->input(kCNodeCallArg); + MS_EXCEPTION_IF_NULL(call_arg); ret.emplace_back(GetValueNode(call_arg), - std::vector(call_node_inputs.begin() + kCNodeCallArg + 1, call_node_inputs.end())); - } else if (IsPrimitiveCNode(call_arg, prim::kPrimSwitch)) { - auto switch_cnode = call_arg->cast(); - MS_EXCEPTION_IF_NULL(switch_cnode); - const std::vector &switch_inputs = switch_cnode->inputs(); - if (switch_inputs.size() <= kCNodeSwitchCond) { - MS_LOG(EXCEPTION) << "Node " << switch_cnode->DebugString() << " has invalid inputs size " + std::vector(cnode->inputs().begin() + kCNodeCallArg + 1, cnode->inputs().end())); + } else if (IsPrimitiveCNode(cnode.get(), prim::kPrimSwitch)) { + const std::vector &switch_inputs = cnode->inputs(); + if (switch_inputs.size() < kCNodeSwitchLength) { + MS_LOG(EXCEPTION) << "Switch node " << cnode->DebugString() << " has invalid inputs size " << switch_inputs.size(); } for (auto iter = switch_inputs.begin() + kCNodeSwitchCond + 1; iter != switch_inputs.end(); ++iter) { @@ -389,7 +384,7 @@ std::vector>> AscendControlPar ret.emplace_back(target_graph, args); } } else { - MS_LOG(EXCEPTION) << "Unsupport call node: " << call_node->DebugString(5); + MS_LOG(EXCEPTION) << "Unsupport call node: " << cnode->DebugString(5); } return ret; } @@ -406,11 +401,11 @@ void AscendControlParser::ChildGraphDataAssign( const std::vector &nodes = kg->execution_order(); for (auto &node : nodes) { - if (!IsPrimitiveCNode(node, prim::kPrimCall)) { + if (!(IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch))) { continue; } - auto child_graph_list = ParseCallNode(NOT_NULL(node)); + auto child_graph_list = ParseCallSwitchNode(NOT_NULL(node)); for (auto &[child_graph, args] : child_graph_list) { MS_EXCEPTION_IF_NULL(child_graph); const std::vector ¶ms = child_graph->inputs(); @@ -425,7 +420,6 @@ void AscendControlParser::ChildGraphDataAssign( link_list->emplace_back(args[i], params[i]); continue; } - InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i])); } } @@ -475,30 +469,20 @@ NotNull AscendControlParser::ProcessKernelGraph(NotNullsize() < kCNodePrim + 1) { - MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; - } - AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex); - if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) { - MS_LOG(DEBUG) << "Continue node " << cnode->DebugString(); + if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || + AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) || + AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer))) { continue; } - AnfNodePtr arg = cnode->input(kFirstDataInputIndex); - MS_EXCEPTION_IF_NULL(arg); - if (IsValueNode(arg)) { + + if (IsPrimitiveCNode(cnode, prim::kPrimCall)) { RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); - } else if (!arg->isa()) { - MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString(); - } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitch)) { - auto arg_cnode = arg->cast(); - MS_EXCEPTION_IF_NULL(arg_cnode); - cnode->set_inputs(arg_cnode->inputs()); + } else if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) { RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); - } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitchLayer)) { - auto arg_cnode = arg->cast(); - MS_EXCEPTION_IF_NULL(arg_cnode); - cnode->set_inputs(arg_cnode->inputs()); + } else if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) { RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); + } else { + MS_LOG(EXCEPTION) << "Unexpected node: " << cnode->DebugString(); } } kg->SetExecOrderByDefault(); @@ -699,31 +683,22 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull fr continue; } const auto &from_graph_exe_order = from_graph->execution_order(); - std::vector real_exe_order(from_graph_exe_order.size()); - size_t real_exe_order_size = 0; - std::copy_if(from_graph_exe_order.begin(), from_graph_exe_order.end(), real_exe_order.begin(), - [&real_exe_order_size](const CNodePtr &node) -> bool { - return (IsPrimitiveCNode(node, prim::kPrimSwitch) || IsPrimitiveCNode(node, prim::kPrimPartial)) - ? false - : (++real_exe_order_size, true); - }); - real_exe_order.resize(real_exe_order_size); if (jump_node == nullptr) { - if (!real_exe_order.empty()) { - InsertControlDependToGraph(from_graph, NOT_NULL(*(real_exe_order.rbegin())), NOT_NULL(assign_node)); + if (!from_graph_exe_order.empty()) { + InsertControlDependToGraph(from_graph, NOT_NULL(*(from_graph_exe_order.rbegin())), NOT_NULL(assign_node)); } else { InsertDependToGraph(from_graph, NOT_NULL(assign_node)); } continue; } - auto jump_node_iter = std::find(real_exe_order.begin(), real_exe_order.end(), jump_node); - if (jump_node_iter == real_exe_order.end()) { + auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node); + if (jump_node_iter == from_graph_exe_order.end()) { MS_LOG(EXCEPTION) << "Cannot find jump node " << jump_node->DebugString() << " in graph " << from_graph->ToString(); } // insert assign between jump_node -1 and jump_node - if (jump_node_iter != real_exe_order.begin()) { + if (jump_node_iter != from_graph_exe_order.begin()) { InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); } InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); @@ -772,6 +747,7 @@ std::vector AscendControlParser::RecurseGraph(NotNull std::vector execution_order; uint32_t child_order_index = 0; for (auto &node : cnodes) { + uint32_t child_graph_index = 0; execution_order.push_back(node); if (node == graph->get_end_goto()) { continue; @@ -779,7 +755,7 @@ std::vector AscendControlParser::RecurseGraph(NotNull if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { std::vector label_switch_list = AnfAlgo::GetNodeAttr>(node, kAttrLabelSwitchList); for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { - if (!CheckLabelIndex(child_order_index, *iter, node, graph)) { + if (!CheckLabelIndex(child_graph_index++, *iter, node)) { MS_LOG(EXCEPTION) << "Check label index fail"; } if (child_order_index >= graph->child_graph_order().size()) { @@ -791,9 +767,12 @@ std::vector AscendControlParser::RecurseGraph(NotNull } } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { uint32_t label_index = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); - if (!CheckLabelIndex(child_order_index, label_index, node, graph)) { + if (!CheckLabelIndex(child_graph_index, label_index, node)) { MS_LOG(EXCEPTION) << "Check label index fail"; } + if (child_order_index >= graph->child_graph_order().size()) { + MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size(); + } auto child_graph = graph->child_graph_order()[child_order_index++]; auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); @@ -804,15 +783,14 @@ std::vector AscendControlParser::RecurseGraph(NotNull return execution_order; } -bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, - NotNull graph) { - const std::vector> &child_graph_order = graph->child_graph_order(); +bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label) { + auto child_graphs = AnfAlgo::GetNodeAttr>(cur_label, kAttrChildGraph); // check index and child order size - if (child_graph_order.size() <= IntToSize(order_index)) { - MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " - << child_graph_order.size() << " goto index " << order_index; + if (child_graphs.size() <= IntToSize(index)) { + MS_LOG(EXCEPTION) << "Child graph index is wrong, current node " << cur_label->ToString() << " child graph size " + << child_graphs.size() << " goto index " << index; } - auto child_graph = child_graph_order[order_index]; + auto child_graph = child_graphs[index]; MS_EXCEPTION_IF_NULL(child_graph); // get start_label_set_index of child graph @@ -822,7 +800,7 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i MS_EXCEPTION_IF_NULL(cur_label); MS_EXCEPTION_IF_NULL(start_label_set); MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() - << " index " << start_label_set_index << " current child graph order : " << order_index; + << " index " << start_label_set_index; return false; } else { return true; diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.h b/mindspore/ccsrc/backend/session/ascend_control_parser.h index 3c370fa500..ae64bd8d3a 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.h +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.h @@ -64,13 +64,13 @@ class AscendControlParser { const CNodePtr &last_label); static AnfNodePtr InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); - static std::vector>> ParseCallNode(NotNull call_node); + static std::vector>> ParseCallSwitchNode( + NotNull call_node); static std::tuple> ParsePartial(NotNull node); static void AttachChildGraphToReturnNode(NotNull graph, const NotNull *> memo); // root graph order - static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, - NotNull graph); + static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode); static std::vector RecurseGraph(NotNull graph, const NotNull *> memo); }; diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 59fc846759..84fc142516 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -885,7 +885,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNu std::map need_replace_list; auto node_list = GetCNodes(TopoSort(graph->get_return())); for (auto &node : node_list) { - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) { // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); MS_EXCEPTION_IF_NULL(graph->MutableInputs()); @@ -898,7 +898,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNu MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << node->DebugString() << ", depend node is " << depend->DebugString(); // insert assign in order to transfer child graph output to parameter - auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node); + auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node); for (auto &child_graph : child_graphs) { MS_EXCEPTION_IF_NULL(child_graph); // If graph has no output, the graph is the true graph of while and will call condition graph, no need insert diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 6b49b4b878..5f7c37fa18 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -67,7 +67,7 @@ std::vector GetCallRealOutputs(const AnfNodePtr &call_node) { return {node}; } std::vector real_inputs; - auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node->cast()); + auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node->cast()); for (const auto &child_graph : child_graphs) { if (child_graph->get_output_null()) { continue; @@ -931,6 +931,18 @@ std::vector KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi return result; } +std::vector KernelGraph::FindNodeByPrimitive(const std::vector &primitive_list) const { + std::vector result; + for (const auto &anf : execution_order_) { + for (const auto &primitive : primitive_list) { + if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) { + result.push_back(anf->cast()); + } + } + } + return result; +} + void KernelGraph::PrintGraphExecuteOrder() const { MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order"; for (size_t i = 0; i < execution_order_.size(); i++) { @@ -1078,11 +1090,12 @@ bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, int outpu void KernelGraph::UpdateChildGraphOrder() { MS_LOG(INFO) << "Update " << ToString() << " child graph order."; SetExecOrderByDefault(); - auto call_nodes = FindNodeByPrimitive(std::make_shared(prim::kPrimCall->name())); + auto call_nodes = FindNodeByPrimitive( + {std::make_shared(prim::kPrimCall->name()), std::make_shared(prim::kPrimSwitch->name())}); std::vector child_graph_order; for (auto &call_node : call_nodes) { MS_EXCEPTION_IF_NULL(call_node); - auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast()); + auto call_child_graphs = AnfAlgo::GetCallSwitchKernelGraph(call_node->cast()); for (const auto &child_graph : call_child_graphs) { MS_EXCEPTION_IF_NULL(child_graph); if (child_graph != parent_graph_) { diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 047c21ea20..6f22aff3e2 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -131,6 +131,7 @@ class KernelGraph : public FuncGraph { void set_parent_graph(const std::shared_ptr &parent_graph) { parent_graph_ = parent_graph; } // find anf node in graph std::vector FindNodeByPrimitive(const PrimitivePtr &primitive) const; + std::vector FindNodeByPrimitive(const std::vector &primitive_list) const; // used to dump ir std::string ToString() const override; diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index b8e2e40df3..e32bb4baef 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -547,45 +547,26 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra MS_EXCEPTION_IF_NULL(node_input); MS_EXCEPTION_IF_NULL(graph); // switch input generalizes partial - if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial) || - AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimCall)) { - return node_input->cast(); - } - if (node_input->isa()) { - MS_LOG(EXCEPTION) << "If switch input is " << node_input->DebugString() << ", it mast be partial or call."; - } std::vector partial_inputs = {NewValueNode(std::make_shared(prim::kPrimPartial->name()))}; - if (node_input->isa() && IsValueNode(node_input)) { - partial_inputs.emplace_back(node_input); - auto partial_node = graph->NewCNode(partial_inputs); - return partial_node; + if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) { + auto partial_node = graph->GetBackendAnfByFrontAnf(node_input); + return partial_node->cast(); + } else if (node_input->isa() && IsValueNode(node_input)) { + partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input)); + } else { + KernelGraphPtr kernel_graph = NewKernelGraph(); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), true, kernel_graph.get()); + auto primitive = NewValueNode(std::make_shared(prim::kPrimReturn->name())); + auto return_node = kernel_graph->NewCNode({primitive, parameter}); + kernel_graph->set_return(return_node); + partial_inputs.emplace_back(std::make_shared(kernel_graph)); + partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input)); } - KernelGraphPtr kernel_graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(kernel_graph); - kernel_graph->set_output(graph->GetBackendAnfByFrontAnf(node_input)); - partial_inputs.emplace_back(std::make_shared(kernel_graph)); auto partial_node = graph->NewCNode(partial_inputs); return partial_node; } -CNodePtr SessionBasic::HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(graph); - auto node = anf_node->cast(); - MS_EXCEPTION_IF_NULL(node); - if (node->inputs().size() < kSwitchInputSize) { - MS_LOG(EXCEPTION) << "Switch input size less than " << kSwitchInputSize; - } - auto primitive = NewValueNode(std::make_shared(prim::kPrimSwitch->name())); - std::vector switch_inputs = {primitive, node->input(1)}; - for (size_t index = 2; index < node->inputs().size(); index++) { - auto input = CreateSwitchInput(node->input(index), graph); - switch_inputs.emplace_back(input); - } - auto switch_node = graph->NewCNode(switch_inputs); - return switch_node; -} - std::vector SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(graph); @@ -611,14 +592,33 @@ std::vector SessionBasic::CreateSwitchOrPartialNode(const CNodePtr & }); return cnode_inputs; } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { - auto switch_node = HandleSwitchInputs(cnode_input, graph); + auto switch_cnode = cnode_input->cast(); + MS_EXCEPTION_IF_NULL(switch_cnode); + std::vector switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex), + switch_cnode->input(kFirstDataInputIndex)}; + for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) { + auto node = switch_cnode->input(index); + // there is real input in call, should put it to true and false branch in switch + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { + auto partial_node = node->cast(); + MS_EXCEPTION_IF_NULL(partial_node); + std::vector partial_inputs = partial_node->inputs(); + partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); + auto new_partial = graph->NewCNode(partial_inputs); + switch_inputs.emplace_back(new_partial); + } + } + if (switch_inputs.size() < kSwitchInputSize) { + MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize; + } + auto switch_node = graph->NewCNode(switch_inputs); cnode_inputs.emplace_back(switch_node); return cnode_inputs; } MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; } -CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { +CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(graph); std::vector cnode_inputs; @@ -642,7 +642,22 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) } } } else if (attr_input->isa()) { - cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); + auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); + if (cnode->inputs().size() < 2 && AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { + auto switch_cnode = cnode_input->cast(); + MS_EXCEPTION_IF_NULL(switch_cnode); + cnode_inputs = switch_cnode->inputs(); + } else { + cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); + } + } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { + cnode_inputs = {graph->GetBackendAnfByFrontAnf(cnode->input(kAnfPrimitiveIndex)), + graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))}; + for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) { + auto node_input = cnode->input(index); + auto switch_input = CreateSwitchInput(node_input, graph); + cnode_inputs.emplace_back(switch_input); + } } else { // get primitive of old node auto prim = AnfAlgo::GetCNodePrimitive(cnode); @@ -651,21 +666,33 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared(*prim)))}; } - for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { - auto anf = cnode->input(input_idx); - MS_EXCEPTION_IF_NULL(anf); - // anf has been created before - if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { - cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); - continue; - } else if (IsValueNode(anf)) { - continue; + if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { + for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->inputs().size(); input_idx++) { + auto anf = cnode->input(input_idx); + MS_EXCEPTION_IF_NULL(anf); + // anf has been created before + if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { + cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); + continue; + } else if (IsValueNode(anf)) { + continue; + } + MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; } - MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; } TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); auto new_cnode = graph->NewCNode(cnode_inputs); TraceManager::EndTrace(); + + // if the cnode is call switch, remove call + if (new_cnode->inputs().size() > 1) { + auto first_input = new_cnode->input(kFirstDataInputIndex); + if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) && + AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) { + new_cnode = first_input->cast(); + } + } + return new_cnode; } diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h old mode 100755 new mode 100644 index f20c27473e..d3e94473a4 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -86,11 +86,7 @@ class SessionBasic { CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, std::unordered_map *other_graph_cnode); - CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph); - - CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph); - CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph); - std::vector CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); + CNodePtr CreateNewCNode(CNodePtr cnode, KernelGraph *graph); // get graph id in child graphs by ME front anf node pointer virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } @@ -112,6 +108,10 @@ class SessionBasic { } #endif + private: + CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph); + std::vector CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); + protected: virtual void SetSummaryNodes(KernelGraph *graph); // Get graph by graph id ,if not exist return null ptr diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index c8c9fc2290..ba47f607cf 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -277,11 +277,14 @@ const int kValueNodeTensorMask = 2; // define special index in special node constexpr auto kAnfPrimitiveIndex = 0; constexpr auto kFirstDataInputIndex = 1; -constexpr auto kAnfPartialFuncGraphIndex = 1; constexpr auto kRealInputNodeIndexInTupleGetItem = 1; constexpr auto kInputNodeOutputIndexInTupleGetItem = 2; constexpr auto kTupleGetItemInputSize = 3; constexpr auto kSwitchInputSize = 4; +constexpr auto kFirstBranchInSwitch = 2; +constexpr auto kCallKernelGraphIndex = 1; +constexpr auto kSwitchTrueKernelGraphIndex = 2; +constexpr auto kSwitchFalseKernelGraphIndex = 3; // index define of control depend constexpr auto kControlDependPriorIndex = 1; constexpr auto kControlDependBehindIndex = 2;