From e22dbfac79282c8e26918166e3aa2de7cd04a01a Mon Sep 17 00:00:00 2001 From: kswang Date: Tue, 27 Oct 2020 16:20:17 +0800 Subject: [PATCH] extend nod euser --- .../ccsrc/backend/session/session_basic.cc | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 1467ff8c6f..3094f027fa 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1186,6 +1186,25 @@ bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) { return true; } +std::vector ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager, + const AnfNodePtr &front_node) { + auto node_users = front_func_graph_manager->node_users(); + auto users = node_users[front_node]; + std::vector result; + for (auto user : users) { + if (IsPrimitiveCNode(user.first, prim::kPrimControlDepend)) { + continue; + } + if (IsPrimitiveCNode(user.first, prim::kPrimDepend)) { + auto res = ExtendNodeUsers(front_func_graph_manager, user.first); + result.insert(result.end(), res.begin(), res.end()); + continue; + } + result.emplace_back(user.first); + } + return result; +} + void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backend_node, const FuncGraphManagerPtr &front_func_graph_manager, const std::shared_ptr &backend_graph) { @@ -1193,8 +1212,6 @@ void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backen if (!AnfAlgo::IsRealKernel(front_node) && !AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) { return; } - auto node_users = front_func_graph_manager->node_users(); - auto users = node_users[front_node]; auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0); auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0); @@ -1210,16 +1227,17 @@ void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backen } } if (internal_output) { + auto users = ExtendNodeUsers(front_func_graph_manager, front_node); for (auto user : users) { - if (!CNodeFirstInputIsPrimitive(user.first)) { + if (!CNodeFirstInputIsPrimitive(user)) { internal_output = false; break; } - if (!AnfAlgo::IsRealKernel(user.first)) { + if (!AnfAlgo::IsRealKernel(user)) { internal_output = false; break; } - if (kernel_target != GetCNodeTarget(user.first)) { + if (kernel_target != GetCNodeTarget(user)) { unique_target = false; } }