Browse Source

add make_tuple before reture as graph outputs in ConstructKernelGraph

tags/v0.3.0-alpha
wenchunjiang 5 years ago
parent
commit
245ab3199b
2 changed files with 11 additions and 5 deletions
  1. +10
    -5
      mindspore/ccsrc/session/session_basic.cc
  2. +1
    -0
      mindspore/ccsrc/utils/utils.h

+ 10
- 5
mindspore/ccsrc/session/session_basic.cc View File

@@ -646,6 +646,16 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
MS_EXCEPTION_IF_NULL(func_graph_node);
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node);
ConstructKernelGraph(sub_func_graph);
} else if (prim->name() == kReturnOpName) {
std::vector<AnfNodePtr> outputs;
auto inputs = cnode->inputs();
if (inputs.size() < 2) {
MS_LOG(EXCEPTION) << "CNode[return] must have two inputs at least, actual inputs size is " << inputs.size();
}
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outputs));
// add a make_tuple before return as graph output
graph->set_output(ConstructOutput(outputs, graph));
continue;
}
}

@@ -655,11 +665,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());
graph->FrontBackendlMapAdd(node, new_cnode);

// set original return to kernel_graph
if (IsPrimitive(new_cnode->input(kAnfPrimitiveIndex), prim::kPrimReturn)) {
graph->set_return(new_cnode);
}
}
}



+ 1
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -144,6 +144,7 @@ constexpr auto kBNInferGradOpName = "BNInferGrad";
constexpr auto kCallOpName = "call";
constexpr auto kPartialOpName = "partial";
constexpr auto kSwitchOpName = "switch";
constexpr auto kReturnOpName = "return";
constexpr auto kLarsV2OpName = "LarsV2";
constexpr auto kLarsV2UpdateOpName = "LarsV2Update";
constexpr auto kSquareSumAllOpName = "SquareSumAll";


Loading…
Cancel
Save