| @@ -1186,6 +1186,25 @@ bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager, | |||||
| const AnfNodePtr &front_node) { | |||||
| auto node_users = front_func_graph_manager->node_users(); | |||||
| auto users = node_users[front_node]; | |||||
| std::vector<AnfNodePtr> result; | |||||
| for (auto user : users) { | |||||
| if (IsPrimitiveCNode(user.first, prim::kPrimControlDepend)) { | |||||
| continue; | |||||
| } | |||||
| if (IsPrimitiveCNode(user.first, prim::kPrimDepend)) { | |||||
| auto res = ExtendNodeUsers(front_func_graph_manager, user.first); | |||||
| result.insert(result.end(), res.begin(), res.end()); | |||||
| continue; | |||||
| } | |||||
| result.emplace_back(user.first); | |||||
| } | |||||
| return result; | |||||
| } | |||||
| void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backend_node, | void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backend_node, | ||||
| const FuncGraphManagerPtr &front_func_graph_manager, | const FuncGraphManagerPtr &front_func_graph_manager, | ||||
| const std::shared_ptr<KernelGraph> &backend_graph) { | const std::shared_ptr<KernelGraph> &backend_graph) { | ||||
| @@ -1193,8 +1212,6 @@ void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backen | |||||
| if (!AnfAlgo::IsRealKernel(front_node) && !AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) { | if (!AnfAlgo::IsRealKernel(front_node) && !AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) { | ||||
| return; | return; | ||||
| } | } | ||||
| auto node_users = front_func_graph_manager->node_users(); | |||||
| auto users = node_users[front_node]; | |||||
| auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0); | auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0); | ||||
| auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0); | auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0); | ||||
| @@ -1210,16 +1227,17 @@ void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backen | |||||
| } | } | ||||
| } | } | ||||
| if (internal_output) { | if (internal_output) { | ||||
| auto users = ExtendNodeUsers(front_func_graph_manager, front_node); | |||||
| for (auto user : users) { | for (auto user : users) { | ||||
| if (!CNodeFirstInputIsPrimitive(user.first)) { | |||||
| if (!CNodeFirstInputIsPrimitive(user)) { | |||||
| internal_output = false; | internal_output = false; | ||||
| break; | break; | ||||
| } | } | ||||
| if (!AnfAlgo::IsRealKernel(user.first)) { | |||||
| if (!AnfAlgo::IsRealKernel(user)) { | |||||
| internal_output = false; | internal_output = false; | ||||
| break; | break; | ||||
| } | } | ||||
| if (kernel_target != GetCNodeTarget(user.first)) { | |||||
| if (kernel_target != GetCNodeTarget(user)) { | |||||
| unique_target = false; | unique_target = false; | ||||
| } | } | ||||
| } | } | ||||