| @@ -980,8 +980,9 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: | |||
| if (context_ptr->execution_mode() == kPynativeMode) { | |||
| return backend_anf; | |||
| } | |||
| auto front_real_kernel = AnfAlgo::VisitKernel(out, 0); | |||
| auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0); | |||
| auto front_real_kernel_pair = AnfAlgo::VisitKernel(out, 0); | |||
| auto front_real_kernel = front_real_kernel_pair.first; | |||
| auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_anf, 0); | |||
| MS_EXCEPTION_IF_NULL(out); | |||
| auto out_func_graph = out->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(out_func_graph); | |||
| @@ -992,26 +993,47 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: | |||
| auto node_users = out_func_graph_manager->node_users(); | |||
| auto users = node_users[out]; | |||
| bool internal_output = true; | |||
| std::string kernel_target = GetCNodeTarget(front_real_kernel.first); | |||
| for (auto user : users) { | |||
| auto cnode = user.first->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| std::string kernel_target = GetCNodeTarget(front_real_kernel); | |||
| if (front_real_kernel != nullptr && front_real_kernel->isa<CNode>()) { | |||
| auto front_cnode = front_real_kernel->cast<CNodePtr>(); | |||
| if (front_cnode != nullptr) { | |||
| auto prim = front_cnode->input(kAnfPrimitiveIndex); | |||
| if (prim == nullptr || !prim->isa<ValueNode>()) { | |||
| internal_output = false; | |||
| } | |||
| } else { | |||
| internal_output = false; | |||
| break; | |||
| } | |||
| auto prim = cnode->input(kAnfPrimitiveIndex); | |||
| if (prim == nullptr || !prim->isa<ValueNode>()) { | |||
| } | |||
| if (internal_output && opt::IsNopNode(front_real_kernel)) { | |||
| auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0); | |||
| auto pre_node_target = GetCNodeTarget(pre_node_pair.first); | |||
| if (pre_node_target != kernel_target) { | |||
| internal_output = false; | |||
| break; | |||
| } | |||
| if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) { | |||
| internal_output = false; | |||
| break; | |||
| } | |||
| if (internal_output) { | |||
| for (auto user : users) { | |||
| auto cnode = user.first->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| internal_output = false; | |||
| break; | |||
| } | |||
| auto prim = cnode->input(kAnfPrimitiveIndex); | |||
| if (prim == nullptr || !prim->isa<ValueNode>()) { | |||
| internal_output = false; | |||
| break; | |||
| } | |||
| if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) { | |||
| internal_output = false; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (internal_output) { | |||
| MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString(); | |||
| graph->AddInternalOutput(out, backend_real_kernel.first); | |||
| MS_LOG(INFO) << "Internal output: " << out->DebugString() << "To " | |||
| << backend_real_kernel_pair.first->DebugString(); | |||
| graph->AddInternalOutput(out, backend_real_kernel_pair.first); | |||
| } | |||
| return backend_anf; | |||
| } | |||