diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index 36083ecc6a..e14d5ba16a 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -48,6 +48,7 @@ using kernel::KernelModPtr;
 namespace {
 constexpr size_t kNopNodeInputSize = 2;
 constexpr size_t kNopNodeRealInputIndex = 1;
+constexpr size_t kReturnDataIndex = 1;
 
 using PrimitiveSet = std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual>;
 
@@ -1919,5 +1920,16 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) {
   auto eval_result = opt::CppInferShape(primitive, args_spec_list);
   node->set_abstract(eval_result);
 }
+
+void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) {
+  auto return_node = root_graph->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  if (return_node->size() <= kReturnDataIndex) {
+    return;
+  }
+  auto make_tuple = root_graph->NewCNode(
+    {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
+  root_graph->set_output(make_tuple);
+}
 }  // namespace session
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
index 57eb49e8b4..ad36329985 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
@@ -266,6 +266,7 @@ class AnfRuntimeAlgorithm {
   // Find real input nodes.
   static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
                                    std::set<AnfNodePtr> *visited);
+  static void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph);
 };
 }  // namespace session
 using AnfAlgo = session::AnfRuntimeAlgorithm;
diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc
index 3f6105eca1..79657959bd 100644
--- a/mindspore/ccsrc/backend/session/ascend_session.cc
+++ b/mindspore/ccsrc/backend/session/ascend_session.cc
@@ -85,7 +85,6 @@ static constexpr uint32_t kLabelSwitchLabelId = 2;
 namespace mindspore {
 namespace session {
 const size_t kInvalidIndex = SIZE_MAX;
-constexpr size_t kReturnDataIndex = 1;
 constexpr char SR_TAG[] = "sr_tag";
 constexpr char BACKWARD[] = "backward";
 namespace {
@@ -143,16 +142,6 @@ std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
   }
   return cnodes;
 }
-void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) {
-  auto return_node = root_graph->get_return();
-  MS_EXCEPTION_IF_NULL(return_node);
-  if (return_node->size() <= kReturnDataIndex) {
-    return;
-  }
-  auto make_tuple = root_graph->NewCNode(
-    {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
-  root_graph->set_output(make_tuple);
-}
 
 TensorPtr GetCNodeOutputStubTensor(const KernelWithIndex &kernel_with_index,
                                    const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,
@@ -483,7 +472,7 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
   // empty graph dont entry to backend
   if (root_graph->execution_order().empty()) {
     MS_LOG(INFO) << root_graph->ToString() << " is empty graph.";
-    InsertMakeTupleForOutput(NOT_NULL(root_graph));
+    AnfAlgo::InsertMakeTupleForOutput(NOT_NULL(root_graph));
     root_graph->set_executable(false);
     InitRuntimeResource();
     return root_graph->graph_id();
@@ -511,7 +500,7 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
   UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo));
   memo.clear();
   // add make_tuple to the output graph
-  InsertMakeTupleForOutput(NOT_NULL(root_graph));
+  AnfAlgo::InsertMakeTupleForOutput(NOT_NULL(root_graph));
   // root root_graph valiate,include genearte execute order and so on
   RootGraphExecutorValidate(NOT_NULL(root_graph));
   // dump graph before remove nop nodes
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index 0b19137fb6..b537bf63da 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -314,7 +314,9 @@ GraphId GPUSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
   if (all_graphs.size() != 1) {
     MS_LOG(EXCEPTION) << "Gpu backend does not support multi-graph schedule. graph num" << all_graphs.size();
   }
-
+  // Insert a MakeTuple at the graph output in case of multiple outputs.
+  // The ConvertTupleOutputToMaketuple pass will then insert TupleGetItem nodes.
+  AnfAlgo::InsertMakeTupleForOutput(NOT_NULL(root_graph));
   opt::BackendCommonOptimization(root_graph);
   return CompileGraphImpl(root_graph);
 }
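
Note (not part of the patch): InsertMakeTupleForOutput wraps the root graph's data output in a MakeTuple node whenever the Return node carries one, so single- and multi-output graphs look the same to later passes such as ConvertTupleOutputToMaketuple. The sketch below is a minimal standalone illustration of that transformation; the Node/Graph types and input layout are toy assumptions, not MindSpore's real AnfNode/KernelGraph API.

    // Toy stand-ins for MindSpore's AnfNode/KernelGraph (illustrative only).
    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    struct Node {
      std::string op;
      std::vector<std::shared_ptr<Node>> inputs;  // inputs[0] models a CNode's primitive slot
    };
    using NodePtr = std::shared_ptr<Node>;

    constexpr size_t kReturnDataIndex = 1;  // index of the data operand on the Return node

    struct Graph {
      NodePtr return_node;
      NodePtr output() const { return return_node->inputs[kReturnDataIndex]; }
      void set_output(const NodePtr &out) { return_node->inputs[kReturnDataIndex] = out; }
    };

    // Mirrors the patch's logic: if the Return node carries a data output,
    // wrap that output in a MakeTuple so every graph output is tuple-shaped.
    void InsertMakeTupleForOutput(Graph *graph) {
      if (graph->return_node->inputs.size() <= kReturnDataIndex) {
        return;  // no data output to wrap
      }
      auto make_tuple = std::make_shared<Node>();
      make_tuple->op = "MakeTuple";
      make_tuple->inputs = {nullptr /* primitive slot */, graph->output()};
      graph->set_output(make_tuple);
    }

    int main() {
      auto add = std::make_shared<Node>();
      add->op = "Add";
      auto ret = std::make_shared<Node>();
      ret->op = "Return";
      ret->inputs = {nullptr /* primitive slot */, add};
      Graph graph{ret};
      InsertMakeTupleForOutput(&graph);
      std::cout << graph.output()->op << std::endl;  // prints "MakeTuple"
      return 0;
    }

Moving the helper from an anonymous namespace in ascend_session.cc onto AnfRuntimeAlgorithm lets the GPU session reuse the same normalization instead of duplicating it per backend.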