From 6e99eab0887fe21216f0e8dac69064a66c66ff94 Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Tue, 11 Aug 2020 21:18:22 +0800 Subject: [PATCH] Erase the nodes without primitive value node input from the internal outputs --- mindspore/ccsrc/backend/session/session_basic.cc | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index a357b6f712..26a53b45a5 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1007,7 +1007,7 @@ void SessionBasic::Summary(KernelGraph *graph) { } namespace { -bool CNodePrimIsValueNode(const AnfNodePtr &node) { +bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) { if (node == nullptr) { return false; } @@ -1016,7 +1016,7 @@ bool CNodePrimIsValueNode(const AnfNodePtr &node) { return false; } auto prim = cnode->input(kAnfPrimitiveIndex); - if (prim == nullptr || !prim->isa()) { + if (prim == nullptr || !IsValueNode(prim)) { return false; } return true; @@ -1032,7 +1032,7 @@ void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backen auto front_real_kernel = front_real_kernel_pair.first; std::string kernel_target = GetCNodeTarget(front_real_kernel); - bool internal_output = CNodePrimIsValueNode(front_real_kernel); + bool internal_output = CNodeFirstInputIsPrimitive(front_real_kernel); bool unique_target = true; if (internal_output && opt::IsNopNode(front_real_kernel)) { auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0); @@ -1043,13 +1043,7 @@ void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backen } if (internal_output) { for (auto user : users) { - auto cnode = user.first->cast(); - if (cnode == nullptr) { - internal_output = false; - break; - } - auto prim = cnode->input(kAnfPrimitiveIndex); - if (prim == nullptr || !prim->isa()) { + if (!CNodeFirstInputIsPrimitive(user.first)) { internal_output = false; break; }