Browse Source

fix nopnode output bug

tags/v0.2.0-alpha
kswang 5 years ago
parent
commit
ae675c5cf8
3 changed files with 9 additions and 6 deletions
  1. +1
    -1
      mindspore/ccsrc/device/kernel_runtime.cc
  2. +7
    -5
      mindspore/ccsrc/session/anf_runtime_algorithm.cc
  3. +1
    -0
      mindspore/ccsrc/session/anf_runtime_algorithm.h

+ 1
- 1
mindspore/ccsrc/device/kernel_runtime.cc View File

@@ -250,7 +250,7 @@ void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph)
MS_EXCEPTION_IF_NULL(graph);
auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
for (const auto &node : nodes) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
MS_EXCEPTION_IF_NULL(item_with_index.first);
if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
continue;


+ 7
- 5
mindspore/ccsrc/session/anf_runtime_algorithm.cc View File

@@ -84,6 +84,7 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz
}

KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
bool visit_nop_node,
const std::vector<PrimitivePtr> &return_types) {
MS_EXCEPTION_IF_NULL(anf_node);
for (const auto &prim_type : return_types) {
@@ -109,12 +110,13 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
auto value_node = input2->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx));
return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx),
visit_nop_node);
} else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0);
} else if (opt::IsNopNode(cnode)) {
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node);
} else if (opt::IsNopNode(cnode) && visit_nop_node) {
if (cnode->inputs().size() == 2) {
return VisitKernelWithReturnType(cnode->input(1), 0);
return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node);
} else {
MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
}
@@ -132,7 +134,7 @@ std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node
auto return_prim_type = return_types;
// if visited make_tuple should return back
return_prim_type.push_back(prim::kPrimMakeTuple);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, return_prim_type);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type);
if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
MS_EXCEPTION_IF_NULL(item_with_index.first);
auto make_tuple = item_with_index.first->cast<CNodePtr>();


+ 1
- 0
mindspore/ccsrc/session/anf_runtime_algorithm.h View File

@@ -41,6 +41,7 @@ class AnfRuntimeAlgorithm {
// get input_anf_node's real kernel by recurse
static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index);
static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index,
bool visit_nop_node = false,
const std::vector<PrimitivePtr> &return_types = {
prim::kPrimMakeTuple});
static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,


Loading…
Cancel
Save