From ae675c5cf8cbc670a469d74359c2f542a3e56c34 Mon Sep 17 00:00:00 2001 From: kswang Date: Tue, 14 Apr 2020 14:15:44 +0800 Subject: [PATCH] fix nopnode output bug --- mindspore/ccsrc/device/kernel_runtime.cc | 2 +- mindspore/ccsrc/session/anf_runtime_algorithm.cc | 12 +++++++----- mindspore/ccsrc/session/anf_runtime_algorithm.h | 1 + 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc index eebc650347..e68ad22bbd 100644 --- a/mindspore/ccsrc/device/kernel_runtime.cc +++ b/mindspore/ccsrc/device/kernel_runtime.cc @@ -250,7 +250,7 @@ void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) MS_EXCEPTION_IF_NULL(graph); auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}); for (const auto &node : nodes) { - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0); + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); MS_EXCEPTION_IF_NULL(item_with_index.first); if (!item_with_index.first->isa() || !AnfAlgo::IsRealKernel(item_with_index.first)) { continue; diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 893c379a07..e355c7885d 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -84,6 +84,7 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz } KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index, + bool visit_nop_node, const std::vector &return_types) { MS_EXCEPTION_IF_NULL(anf_node); for (const auto &prim_type : return_types) { @@ -109,12 +110,13 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr auto value_node = input2->cast(); MS_EXCEPTION_IF_NULL(value_node); int item_idx = GetValue(value_node->value()); - return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx)); + return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx), + visit_nop_node); } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { - return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0); - } else if (opt::IsNopNode(cnode)) { + return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node); + } else if (opt::IsNopNode(cnode) && visit_nop_node) { if (cnode->inputs().size() == 2) { - return VisitKernelWithReturnType(cnode->input(1), 0); + return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node); } else { MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node"; } @@ -132,7 +134,7 @@ std::vector AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node auto return_prim_type = return_types; // if visited make_tuple should return back return_prim_type.push_back(prim::kPrimMakeTuple); - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, return_prim_type); + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type); if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) { MS_EXCEPTION_IF_NULL(item_with_index.first); auto make_tuple = item_with_index.first->cast(); diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h index 1a1d471b84..233f86410c 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h @@ -41,6 +41,7 @@ class AnfRuntimeAlgorithm { // get input_anf_node's real kernel by recurse static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index); static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index, + bool visit_nop_node = false, const std::vector &return_types = { prim::kPrimMakeTuple}); static std::vector GetAllOutput(const AnfNodePtr &node,