diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index f1962739cf..2cf7cff113 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -360,7 +360,7 @@ size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
     MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero"
                       << " trace: " << trace::DumpSourceLines(node);
   }
-  // exclude intputs[0],which is value_node storing attr,inputs left are real input
+  // exclude inputs[0], which is a value_node storing attrs; the remaining inputs are the real inputs
   return input_num - 1;
 }
@@ -1191,10 +1191,28 @@ FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node)
 std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
-  if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch))) {
-    MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a call or switch node."
+  if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
+        AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer))) {
+    MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << " is not a call or switch or switch_layer node."
                       << " trace: " << trace::DumpSourceLines(cnode);
   }
+  auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr {
+    auto partial = cnode->input(input_index);
+    MS_EXCEPTION_IF_NULL(partial);
+    if (IsValueNode<KernelGraph>(partial)) {
+      return GetValueNode<KernelGraphPtr>(partial);
+    }
+    auto partial_cnode = partial->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(partial_cnode);
+    auto graph_node = partial_cnode->input(kCallKernelGraphIndex);
+    MS_EXCEPTION_IF_NULL(graph_node);
+    auto graph_value_node = graph_node->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(graph_value_node);
+    auto graph_value = graph_value_node->value();
+    MS_EXCEPTION_IF_NULL(graph_value);
+    auto child_graph = graph_value->cast<KernelGraphPtr>();
+    return child_graph;
+  };
   if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
     auto input1 = cnode->input(kCallKernelGraphIndex);
     MS_EXCEPTION_IF_NULL(input1);
@@ -1204,25 +1222,15 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const
     MS_EXCEPTION_IF_NULL(kernel_graph);
     return {kernel_graph->cast<KernelGraphPtr>()};
   } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
-    auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr {
-      auto partial = cnode->input(input_index);
-      MS_EXCEPTION_IF_NULL(partial);
-      if (IsValueNode<KernelGraph>(partial)) {
-        return GetValueNode<KernelGraphPtr>(partial);
-      }
-      auto partial_cnode = partial->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(partial_cnode);
-      auto graph_node = partial_cnode->input(kCallKernelGraphIndex);
-      MS_EXCEPTION_IF_NULL(graph_node);
-      auto graph_value_node = graph_node->cast<ValueNodePtr>();
-      MS_EXCEPTION_IF_NULL(graph_value_node);
-      auto graph_value = graph_value_node->value();
-      MS_EXCEPTION_IF_NULL(graph_value);
-      auto child_graph = graph_value->cast<KernelGraphPtr>();
-      return child_graph;
-    };
     return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex),
             get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)};
+  } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
+    std::vector<KernelGraphPtr> child_graphs;
+    for (size_t idx = kMakeTupleInSwitchLayerIndex; idx < cnode->inputs().size(); idx++) {
+      auto child_graph = get_switch_kernel_graph(idx);
+      child_graphs.emplace_back(child_graph);
+    }
+    return child_graphs;
   }
   return {};
 }
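
Note: GetCallSwitchKernelGraph now accepts all three control primitives, and the branch-graph extraction is hoisted out of the switch branch so switch and switch_layer share it. Each branch input is either a KernelGraph value node or a Partial cnode whose input 1 (kCallKernelGraphIndex) holds the graph. Below is a minimal standalone model of that extraction; Graph and Partial are illustrative stand-ins, not MindSpore's AnfNode/KernelGraph API.

    // branch_extraction_sketch.cc -- hypothetical, self-contained model
    #include <cassert>
    #include <memory>
    #include <string>
    #include <variant>
    #include <vector>

    struct Graph {
      std::string name;
    };
    using GraphPtr = std::shared_ptr<Graph>;

    // A branch is either a bare graph value node or a "partial" whose first
    // data input (index 1, like kCallKernelGraphIndex) holds the graph.
    struct Partial {
      GraphPtr graph;  // bound arguments omitted
    };
    using BranchInput = std::variant<GraphPtr, Partial>;

    GraphPtr GetBranchGraph(const BranchInput &input) {
      if (auto graph = std::get_if<GraphPtr>(&input)) {
        return *graph;  // bare kernel-graph value node
      }
      return std::get<Partial>(input).graph;  // partial: graph sits at input 1
    }

    int main() {
      // switch_layer inputs after lowering: [prim, index, branch0, branch1, ...],
      // with branches starting at offset 2 (kMakeTupleInSwitchLayerIndex).
      std::vector<BranchInput> branches = {std::make_shared<Graph>(Graph{"g0"}),
                                           Partial{std::make_shared<Graph>(Graph{"g1"})}};
      assert(GetBranchGraph(branches[0])->name == "g0");
      assert(GetBranchGraph(branches[1])->name == "g1");
      return 0;
    }
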
@@ -1627,7 +1635,7 @@ void AnfRuntimeAlgorithm::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::
   MS_EXCEPTION_IF_NULL(result);
   MS_EXCEPTION_IF_NULL(visited);
   if (visited->find(anf_node) != visited->end()) {
-    MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
+    MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has already been visited";
     return;
   }
   visited->insert(anf_node);
diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc
index b02c39ac92..2b629b6626 100644
--- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc
+++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc
@@ -156,7 +156,7 @@ static std::vector<CNodePtr> GetTargetLabelSetNodes(NotNull<CNodePtr> jump_node,
   for (auto label_id : target_label_list) {
     auto iter = label_id_to_label_set.find(label_id);
    if (iter == label_id_to_label_set.end()) {
-      MS_LOG(EXCEPTION) << "Connot find LabelSet node has label id " << label_id;
+      MS_LOG(EXCEPTION) << "Cannot find LabelSet node with label id " << label_id;
     }
     target_labelset_nodes.push_back(iter->second);
   }
@@ -413,6 +413,16 @@ std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlPar
     const auto &[target_graph, args] = ParsePartial(NOT_NULL(*iter));
     ret.emplace_back(target_graph, args);
   }
+  } else if (IsPrimitiveCNode(cnode.get(), prim::kPrimSwitchLayer)) {
+    const std::vector<AnfNodePtr> &switch_layer_inputs = cnode->inputs();
+    if (switch_layer_inputs.size() <= kCNodeSwitchLayerBranch) {
+      MS_LOG(EXCEPTION) << "Switch layer node " << cnode->DebugString() << " has invalid inputs size "
+                        << switch_layer_inputs.size();
+    }
+    for (auto iter = switch_layer_inputs.begin() + kCNodeSwitchLayerBranch; iter != switch_layer_inputs.end(); ++iter) {
+      const auto &[target_graph, args] = ParsePartial(NOT_NULL(*iter));
+      ret.emplace_back(target_graph, args);
+    }
   } else {
     MS_LOG(EXCEPTION) << "Unsupported call node: " << cnode->DebugString(5);
   }
@@ -431,7 +441,8 @@ void AscendControlParser::ChildGraphDataAssign(
   const std::vector<CNodePtr> &nodes = kg->execution_order();
   for (auto &node : nodes) {
-    if (!(IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch))) {
+    if (!(IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch) ||
+          IsPrimitiveCNode(node, prim::kPrimSwitchLayer))) {
       continue;
     }
@@ -647,12 +658,10 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
     MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength;
   }
-  auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch);
-  MS_EXCEPTION_IF_NULL(branch_tuple);
-  if (!branch_tuple->isa<CNode>()) {
-    MS_LOG(EXCEPTION) << branch_tuple->DebugString() << " is not a CNode";
+  std::vector<AnfNodePtr> branch_partial;
+  for (size_t idx = kCNodeSwitchLayerBranch; idx < cur_node->inputs().size(); idx++) {
+    branch_partial.emplace_back(cur_node->input(idx));
   }
-  const std::vector<AnfNodePtr> &branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs();
   // 1 return label
   auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
   // 2 add depend relationship
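
Note: the branches of a switch_layer cnode are now direct tail inputs rather than a single make_tuple input, so every consumer walks inputs [kCNodeSwitchLayerBranch, size). A plain-C++ sketch of that walk follows; it assumes kCNodeSwitchLayerBranch == 2, which the [prim, index, branches...] layout implies, though the constant itself is defined in ascend_control_parser.cc and is not shown in this patch.

    // branch_walk_sketch.cc -- hypothetical, self-contained model
    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    // assumed offset: inputs are [prim, index, branch0, branch1, ...]
    constexpr std::size_t kCNodeSwitchLayerBranch = 2;

    int main() {
      std::vector<std::string> inputs = {"SwitchLayer", "index", "b0", "b1", "b2"};
      for (std::size_t idx = kCNodeSwitchLayerBranch; idx < inputs.size(); ++idx) {
        // each tail input is one branch partial, matching how RecurseSwitchLayer
        // and the ParsePartial loop above collect them
        std::cout << "branch " << (idx - kCNodeSwitchLayerBranch) << " -> " << inputs[idx] << '\n';
      }
      return 0;
    }
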
@@ -673,16 +682,17 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
     // 3.1 branch kernel graph and args
     KernelGraphPtr branch_fg;
     std::vector<AnfNodePtr> origin_inputs;
-    std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
+    std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i + kCNodeSwitchLayerBranch]));
     child_graphs.push_back(branch_fg);
     // 3.2 recurse sub graph
     CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
     new_switch_inputs.push_back(branch_label);
     AttachOriginalInputsToGraph(kg, origin_inputs);
   }
-  new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
   cur_node->set_inputs(new_switch_inputs);
-  cur_node->set_abstract(nullptr);
+  cur_node->set_abstract(std::make_shared<abstract::AbstractNone>());
+  // to stay consistent with how the true/false branches of switch are ordered, reverse the branch sequence
+  std::reverse(child_graphs.begin(), child_graphs.end());
   AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>(child_graphs), cur_node.get());
   MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString();
 }
diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc
index 497f91fe37..452542091c 100644
--- a/mindspore/ccsrc/backend/session/ascend_session.cc
+++ b/mindspore/ccsrc/backend/session/ascend_session.cc
@@ -875,7 +875,7 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) {
     // generate and load task info to device if it is sink mode
     Load(graph);
   }
-  // sync the inital const tensor to device
+  // sync the initial const tensor to device
   SyncInitialTenosrToDevice();
   DumpAllGraphs({graph});
   MS_LOG(INFO) << "End";
@@ -1634,7 +1634,8 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
   std::map<AnfNodePtr, AnfNodePtr> need_replace_list;
   auto node_list = GetCNodes(TopoSort(graph->get_return()));
   for (auto &node : node_list) {
-    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
+    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) ||
+        AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) {
       // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output
       auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
       MS_EXCEPTION_IF_NULL(graph->MutableInputs());
diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc
index 13802a87c8..f438f5e0ed 100644
--- a/mindspore/ccsrc/backend/session/kernel_graph.cc
+++ b/mindspore/ccsrc/backend/session/kernel_graph.cc
@@ -1186,8 +1186,9 @@ bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, int outpu
 void KernelGraph::UpdateChildGraphOrder() {
   MS_LOG(INFO) << "Update " << ToString() << " child graph order.";
   SetExecOrderByDefault();
-  auto call_nodes = FindNodeByPrimitive(
-      {std::make_shared<Primitive>(prim::kPrimCall->name()), std::make_shared<Primitive>(prim::kPrimSwitch->name())});
+  auto call_nodes = FindNodeByPrimitive({std::make_shared<Primitive>(prim::kPrimCall->name()),
+                                         std::make_shared<Primitive>(prim::kPrimSwitch->name()),
+                                         std::make_shared<Primitive>(prim::kPrimSwitchLayer->name())});
   std::vector<std::weak_ptr<KernelGraph>> child_graph_order;
   for (auto &call_node : call_nodes) {
     MS_EXCEPTION_IF_NULL(call_node);
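
Note: the last three hunks apply one rule in three places: every pass that scans for control-flow cnodes (ChildGraphDataAssign, CreateMultiBranchOutput, UpdateChildGraphOrder) must treat SwitchLayer exactly like Call and Switch, otherwise switch_layer branches would be skipped during data assignment, branch-output merging, and child-graph ordering. A trivial standalone sketch of that membership-scan pattern, with op names as strings rather than real node types:

    // control_scan_sketch.cc -- hypothetical, self-contained model
    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <set>
    #include <string>
    #include <vector>

    int main() {
      const std::set<std::string> control_ops = {"Call", "Switch", "SwitchLayer"};
      const std::vector<std::string> execution_order = {"Add", "Call", "Mul", "SwitchLayer"};
      std::vector<std::string> control_nodes;
      // keep only the nodes every control-flow pass must visit
      std::copy_if(execution_order.begin(), execution_order.end(), std::back_inserter(control_nodes),
                   [&](const std::string &op) { return control_ops.count(op) > 0; });
      for (const auto &op : control_nodes) {
        std::cout << op << '\n';  // prints: Call, SwitchLayer
      }
      return 0;
    }
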
diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index 1c5da17801..77008f3283 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -148,7 +148,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
     }
   }
   tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
-  // if in paynative mode,data only copyed to host when user want to print data
+  // in pynative mode, data is only copied to host when the user wants to print it
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
@@ -499,10 +499,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr
   auto graph_inputs = graph->MutableInputs();
   MS_EXCEPTION_IF_NULL(graph_inputs);
   auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
-    auto parameter = graph->NewParameter();
-    MS_EXCEPTION_IF_NULL(parameter);
-    parameter->set_abstract(abstract);
-    auto new_parameter = graph->NewParameter(parameter);
+    auto new_parameter = graph->NewParameter(abstract);
     parameters.push_back(new_parameter);
     valid_inputs->push_back(true);
     graph_inputs->push_back(new_parameter);
@@ -662,7 +659,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
   return new_cnode;
 }
 
-CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph) {
+CNodePtr SessionBasic::CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(node_input);
   MS_EXCEPTION_IF_NULL(graph);
   // switch input generalizes partial
@@ -675,9 +672,11 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra
   } else {
     KernelGraphPtr kernel_graph = NewKernelGraph();
     MS_EXCEPTION_IF_NULL(kernel_graph);
-    auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), kernel_graph.get());
+    auto parameter = CreateNewParameterFromCNode(cnode, kernel_graph.get());
+    parameter->set_abstract(cnode->abstract());
     auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
     auto return_node = kernel_graph->NewCNode({primitive, parameter});
+    return_node->set_abstract(cnode->abstract());
     kernel_graph->set_return(return_node);
     partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
     partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
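
Note: CreateSwitchInput now receives the whole switch cnode so that both the wrapper graph's parameter and its return node can carry the cnode's abstract; leaving the return node's abstract unset would deprive later passes of the branch's output type. A toy model of that invariant, with abstracts modeled as plain strings and WrapperGraph as a hypothetical stand-in:

    // abstract_propagation_sketch.cc -- hypothetical, self-contained model
    #include <cassert>
    #include <memory>
    #include <string>

    struct WrapperGraph {
      std::string parameter_abstract;  // set from the switch cnode's abstract
      std::string return_abstract;     // must match, or later type inference breaks
    };

    int main() {
      const std::string switch_abstract = "Tensor[float32]";
      auto wrapper = std::make_shared<WrapperGraph>();
      // both the parameter and the return node inherit the switch node's abstract,
      // mirroring the two set_abstract() calls added in the hunk above
      wrapper->parameter_abstract = switch_abstract;
      wrapper->return_abstract = switch_abstract;
      assert(wrapper->parameter_abstract == wrapper->return_abstract);
      return 0;
    }
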
@@ -722,10 +721,97 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
   return cnode_inputs;
 }
 
+void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(real_input);
+  if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial))) {
+    MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << " is not a partial node.";
+  }
+  auto partial_input = cnode->input(kFirstDataInputIndex);
+  KernelGraphPtr partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
+  MS_EXCEPTION_IF_NULL(partial_kernel_graph);
+  auto ret = partial_kernel_graph->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  auto return_input = ret->input(kFirstDataInputIndex);
+  // if the kernel graph's return node returns a function
+  if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
+    std::vector<AnfNodePtr> call_inputs = {
+      partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
+    auto return_input_cnode = return_input->cast<CNodePtr>();
+    auto partial_inputs = return_input_cnode->inputs();
+    call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end());
+    auto parameter_for_input = CreateNewParameterFromCNode(real_input, partial_kernel_graph.get());
+    call_inputs.emplace_back(parameter_for_input);
+    auto call_node = partial_kernel_graph->NewCNode(call_inputs);
+    // update abstract
+    KernelGraphPtr sub_partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_inputs[kFirstDataInputIndex]);
+    auto ret_partial = sub_partial_kernel_graph->get_return();
+    call_node->set_abstract(ret_partial->abstract());
+    // update return input
+    ret->set_input(kFirstDataInputIndex, call_node);
+  }
+}
+
+std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> cnode_inputs = {
+    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
+  auto attr_input = cnode->input(kAnfPrimitiveIndex);
+  MS_EXCEPTION_IF_NULL(attr_input);
+  auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
+  auto switch_layer_cnode = cnode_input->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(switch_layer_cnode);
+  std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex),
+                                                 switch_layer_cnode->input(kFirstDataInputIndex)};
+  auto make_tuple_node = switch_layer_cnode->input(kMakeTupleInSwitchLayerIndex);
+  MS_EXCEPTION_IF_NULL(make_tuple_node);
+  auto node = make_tuple_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(node);
+  auto make_tuple_inputs = node->inputs();
+  // the call carries a real input, which should be pushed into the make_tuple of switch_layer
+  auto real_input = cnode->input(kFirstDataInputIndex);
+  auto real_input_back = graph->GetBackendAnfByFrontAnf(real_input);
+  std::vector<AnfNodePtr> new_make_tuple_inputs = {
+    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())))};
+  for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) {
+    auto partial_idx = make_tuple_inputs[idx];
+    MS_EXCEPTION_IF_NULL(cnode->abstract());
+    // the switch_layer branch is a partial cnode
+    if (AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) {
+      auto partial_node = partial_idx->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(partial_node);
+      // update the kernel graph when the switch_layer branch returns a function
+      CreateCallNodeReturnFunction(partial_node, real_input_back);
+      std::vector<AnfNodePtr> new_partial_inputs = partial_node->inputs();
+      new_partial_inputs.emplace_back(real_input_back);
+      auto new_partial = graph->NewCNode(new_partial_inputs);
+      new_make_tuple_inputs.emplace_back(new_partial);
+    }
+    // the switch_layer branch is a kernel graph value node
+    if (IsValueNode<KernelGraph>(partial_idx)) {
+      // wrap the bare KernelGraph into a partial so the real input can be attached
+      std::vector<AnfNodePtr> new_partial_inputs;
+      new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name())));
+      new_partial_inputs.emplace_back(partial_idx);
+      new_partial_inputs.emplace_back(real_input_back);
+      auto new_partial = graph->NewCNode(new_partial_inputs);
+      new_make_tuple_inputs.emplace_back(new_partial);
+    }
+  }
+  auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs);
+  switch_layer_inputs.emplace_back(new_make_tuple);
+  auto new_switch_layer = graph->NewCNode(switch_layer_inputs);
+  cnode_inputs.emplace_back(new_switch_layer);
+  return cnode_inputs;
+}
+
 std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(cnode);
   MS_EXCEPTION_IF_NULL(graph);
-  // create primitive of cnode:call(partial or switch)
+  // create primitive of cnode: call(partial or switch or switch_layer)
   std::vector<AnfNodePtr> cnode_inputs = {
     graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
   auto attr_input = cnode->input(kAnfPrimitiveIndex);
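
Note: CreateCallSwitchLayerInputs rewrites call(switch_layer(index, make_tuple(b0, b1, ...)), real_input) so that the call argument is bound into every branch: an existing partial gains real_input as an extra bound argument, and a bare graph is wrapped into partial(graph, real_input). CreateCallNodeReturnFunction covers the branch-returns-a-function case by turning the inner graph's returned partial into a call that also receives the argument. A standalone model of the make_tuple rewrite, with strings standing in for ANF nodes:

    // switch_layer_rewrite_sketch.cc -- hypothetical, self-contained model
    #include <iostream>
    #include <string>
    #include <vector>

    struct Branch {
      std::string graph;
      std::vector<std::string> bound_args;  // empty -> bare graph, else a partial
    };

    int main() {
      std::vector<Branch> make_tuple = {{"g0", {}}, {"g1", {"x"}}};
      const std::string real_input = "arg";
      for (auto &branch : make_tuple) {
        // partial(g1, x) gains the call argument; bare g0 becomes partial(g0, arg)
        branch.bound_args.push_back(real_input);
      }
      for (const auto &branch : make_tuple) {
        std::cout << "partial(" << branch.graph;
        for (const auto &arg : branch.bound_args) std::cout << ", " << arg;
        std::cout << ")\n";  // prints: partial(g0, arg) and partial(g1, x, arg)
      }
      return 0;
    }
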
@@ -748,9 +834,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
     return cnode_inputs;
   } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
     return CreateCallSwitchInputs(cnode, graph);
+  } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitchLayer)) {
+    return CreateCallSwitchLayerInputs(cnode, graph);
   }
   MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]" << cnode_input->DebugString()
-                << "must be partial or switch.";
+                << " must be partial or switch or switch_layer.";
   return {};
 }
@@ -788,7 +876,7 @@ void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
     cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
     for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) {
       auto node_input = cnode->input(index);
-      auto switch_input = CreateSwitchInput(node_input, graph);
+      auto switch_input = CreateSwitchInput(cnode, node_input, graph);
       cnode_inputs->emplace_back(switch_input);
     }
   } else {
@@ -841,10 +929,17 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) {
   // if the cnode is call switch, remove call
   if (new_cnode->inputs().size() > 1) {
     auto first_input = new_cnode->input(kFirstDataInputIndex);
+    MS_EXCEPTION_IF_NULL(first_input);
     if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
         AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) {
       new_cnode = first_input->cast<CNodePtr>();
     }
+    if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
+        AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitchLayer)) {
+      auto abstract = cnode->abstract();
+      new_cnode = first_input->cast<CNodePtr>();
+      new_cnode->set_abstract(abstract);
+    }
   }
 
   return new_cnode;
@@ -1842,7 +1937,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   // PS embeddingLookup cache check.
   if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
-    MS_LOG(EXCEPTION) << "The other parameter cann't set ps mode when the embeddingLookup cache is enabled in "
+    MS_LOG(EXCEPTION) << "Other parameters can't be set to ps mode when the embeddingLookup cache is enabled in "
                         "parameter server training mode.";
   }
   std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());
diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h
index 7f7170e035..1d0b0c565e 100644
--- a/mindspore/ccsrc/backend/session/session_basic.h
+++ b/mindspore/ccsrc/backend/session/session_basic.h
@@ -125,7 +125,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
 #endif
 
  private:
-  CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph);
+  CNodePtr CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph);
   std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
   std::vector<AnfNodePtr> CreateValueNode(const CNodePtr &cnode, KernelGraph *graph);
   void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs);
@@ -133,6 +133,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   void GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs);
   void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
                          std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
+  std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
+  void CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input);
 
  protected:
   friend class Executor;
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 6540ab19d6..266d8221c6 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -407,6 +407,7 @@ constexpr auto kFirstBranchInSwitch = 2;
 constexpr auto kCallKernelGraphIndex = 1;
 constexpr auto kSwitchTrueKernelGraphIndex = 2;
 constexpr auto kSwitchFalseKernelGraphIndex = 3;
+constexpr auto kMakeTupleInSwitchLayerIndex = 2;
 // index define of control depend
 constexpr auto kControlDependPriorIndex = 1;
 constexpr auto kControlDependBehindIndex = 2;
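
Note: for quick reference, the index constants above encode the following input layouts, restated from the hunks in this patch. A self-checking sketch with stand-in strings (not real node types):

    // index_layout_sketch.cc -- hypothetical, self-contained model
    #include <cassert>
    #include <cstddef>
    #include <string>
    #include <vector>

    int main() {
      constexpr std::size_t kCallKernelGraphIndex = 1;
      constexpr std::size_t kSwitchTrueKernelGraphIndex = 2;
      constexpr std::size_t kSwitchFalseKernelGraphIndex = 3;
      constexpr std::size_t kMakeTupleInSwitchLayerIndex = 2;

      const std::vector<std::string> call = {"Call", "graph", "arg0"};
      const std::vector<std::string> switch_node = {"Switch", "cond", "true_graph", "false_graph"};
      // before the Ascend control parser flattens it, the branches sit in one
      // make_tuple; afterwards they occupy inputs 2..n directly
      const std::vector<std::string> switch_layer = {"SwitchLayer", "index", "make_tuple(b0, b1, ...)"};

      assert(call[kCallKernelGraphIndex] == "graph");
      assert(switch_node[kSwitchTrueKernelGraphIndex] == "true_graph");
      assert(switch_node[kSwitchFalseKernelGraphIndex] == "false_graph");
      assert(switch_layer[kMakeTupleInSwitchLayerIndex] == "make_tuple(b0, b1, ...)");
      return 0;
    }
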