diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 066bc6682b..e718d03372 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -712,20 +712,6 @@ void AscendSession::SetSummaryNodes(KernelGraph *graph) { MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); } -GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const { - for (const auto &graph_item : graphs_) { - auto graph = graph_item.second; - MS_EXCEPTION_IF_NULL(graph); - // if front_anf is a parameter,the backend parameter may have two - if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) { - return graph_item.first; - } - } - MS_EXCEPTION_IF_NULL(front_anf); - MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph"; - return kInvalidGraphId; -} - void AscendSession::MergeGraphExecOrder() { MS_LOG(INFO) << "Start!"; // merge graph order diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h index 0738bbe163..cce006acce 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.h +++ b/mindspore/ccsrc/backend/session/ascend_session.h @@ -62,8 +62,6 @@ class AscendSession : public SessionBasic { } } - // get graph id in child graphs by ME front anf node pointer - GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; // get graph id of final graph GraphId GetFinalRunGraph() const override { return final_graph_id_; } diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index a3f2880355..da25eb0b10 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -341,6 +341,20 @@ void SessionBasic::InitDevice(const std::string &device_name, uint32_t device_id executor_ = ExecutorManager::Instance().GetExecutor(device_name, device_id); } +GraphId SessionBasic::GetGraphIdByNode(const AnfNodePtr &front_anf) const { + for (const auto &graph_item : graphs_) { + auto graph = graph_item.second; + MS_EXCEPTION_IF_NULL(graph); + // if front_anf is a parameter,the backend parameter may have two + if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) { + return graph_item.first; + } + } + MS_EXCEPTION_IF_NULL(front_anf); + MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph"; + return kInvalidGraphId; +} + KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const { auto it = graphs_.find(graph_id); if (it == graphs_.end()) { @@ -1216,16 +1230,41 @@ std::vector ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr return result; } -void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backend_node, +AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) { + MS_EXCEPTION_IF_NULL(front_node); + if (!front_node->isa()) { + return nullptr; + } + if (AnfAlgo::IsRealKernel(front_node)) { + return front_node; + } + if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) { + return front_node; + } + if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimDepend)) { + auto cnode = front_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto &inputs = cnode->inputs(); + if (inputs.size() > 2) { + return GetSupportedInternalNode(inputs[1]); + } + } + return nullptr; +} + +void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node, const FuncGraphManagerPtr &front_func_graph_manager, const std::shared_ptr &backend_graph) { - // When init parameter from cnode of other graphs, the cnode will not be real kernel except for tuple_getitem. - if (!AnfAlgo::IsRealKernel(front_node) && !AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) { + auto front_node = GetSupportedInternalNode(input_front_node); + if (front_node == nullptr) { return; } auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0); auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0); - + auto backend_real_kernel = backend_real_kernel_pair.first; + if (backend_real_kernel == nullptr || !backend_real_kernel->isa()) { + return; + } auto front_real_kernel = front_real_kernel_pair.first; std::string kernel_target = GetCNodeTarget(front_real_kernel); bool internal_output = CNodeFirstInputIsPrimitive(front_real_kernel); @@ -1254,10 +1293,9 @@ void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backen } } if (internal_output) { - MS_LOG(INFO) << "Internal output: " << front_node->DebugString() << " To " - << backend_real_kernel_pair.first->DebugString() << ", unique_target: " << unique_target; - backend_graph->AddInternalOutput(front_node, backend_real_kernel_pair.first, backend_real_kernel_pair.second, - unique_target); + MS_LOG(INFO) << "AddInternalOutput: " << front_node->DebugString() << " To " << backend_real_kernel->DebugString() + << ", unique_target: " << unique_target; + backend_graph->AddInternalOutput(front_node, backend_real_kernel, backend_real_kernel_pair.second, unique_target); } } } // namespace diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 158fe6c0a8..9553d2fcd7 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -90,7 +90,7 @@ class SessionBasic : public std::enable_shared_from_this { CNodePtr CreateNewCNode(CNodePtr cnode, KernelGraph *graph); // get graph id in child graphs by ME front anf node pointer - virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } + virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const; virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } void CheckPSModeConsistence(const KernelGraphPtr &Kernel_graph); void AssignParamKey(const KernelGraphPtr &kernel_graph);