diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 17c91f61a3..bd8da5a5e3 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -892,8 +892,8 @@ std::vector SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno return cnode_inputs; } -void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, KernelGraph *graph, - const std::vector &real_inputs) { +void SessionBasic::ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, + const std::vector &real_inputs) { MS_EXCEPTION_IF_NULL(cnode); // func1 =switch(branch1, branch2) // func2 = func1(param1) @@ -997,7 +997,7 @@ std::vector SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr MS_EXCEPTION_IF_NULL(ret); auto return_input = ret->input(kFirstDataInputIndex); if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa()) { - CreateCallNodeReturnFunction(cnode, partial_kernel_graph.get(), real_inputs); + ProcessNodeRetFunc(cnode, partial_kernel_graph.get(), real_inputs); } // partial node add input args new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end()); @@ -1006,7 +1006,11 @@ std::vector SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr new_make_tuple_inputs.emplace_back(new_partial); } auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs); - new_make_tuple->set_abstract(make_tuple_node->abstract()); + auto abstract = make_tuple_node->abstract(); + if (abstract == nullptr) { + abstract = std::make_shared(AbstractBasePtrList()); + } + new_make_tuple->set_abstract(abstract); switch_layer_inputs.emplace_back(new_make_tuple); auto new_switch_layer = graph->NewCNode(switch_layer_inputs); cnode_inputs.emplace_back(new_switch_layer); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index e15f54dda0..945f249d01 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -155,8 +155,7 @@ class SessionBasic : public std::enable_shared_from_this { void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector *cnode_inputs, std::unordered_map *other_graph_cnode); std::vector CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph); - void CreateCallNodeReturnFunction(const CNodePtr &cnode, KernelGraph *graph, - const std::vector &real_inputs); + void ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, const std::vector &real_inputs); protected: friend class Executor;