Browse Source

!14832 Add abstract for maketuple

From: @youui
Reviewed-by: @liangchenghui,@zhoufeng54
Signed-off-by: @liangchenghui
pull/14832/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
9545e96ee9
2 changed files with 9 additions and 6 deletions
  1. +8
    -4
      mindspore/ccsrc/backend/session/session_basic.cc
  2. +1
    -2
      mindspore/ccsrc/backend/session/session_basic.h

+ 8
- 4
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -892,8 +892,8 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
return cnode_inputs;
}

void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, KernelGraph *graph,
const std::vector<AnfNodePtr> &real_inputs) {
void SessionBasic::ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph,
const std::vector<AnfNodePtr> &real_inputs) {
MS_EXCEPTION_IF_NULL(cnode);
// func1 =switch(branch1, branch2)
// func2 = func1(param1)
@@ -997,7 +997,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
MS_EXCEPTION_IF_NULL(ret);
auto return_input = ret->input(kFirstDataInputIndex);
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa<ValueNode>()) {
CreateCallNodeReturnFunction(cnode, partial_kernel_graph.get(), real_inputs);
ProcessNodeRetFunc(cnode, partial_kernel_graph.get(), real_inputs);
}
// partial node add input args
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
@@ -1006,7 +1006,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
new_make_tuple_inputs.emplace_back(new_partial);
}
auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs);
new_make_tuple->set_abstract(make_tuple_node->abstract());
auto abstract = make_tuple_node->abstract();
if (abstract == nullptr) {
abstract = std::make_shared<abstract::AbstractTuple>(AbstractBasePtrList());
}
new_make_tuple->set_abstract(abstract);
switch_layer_inputs.emplace_back(new_make_tuple);
auto new_switch_layer = graph->NewCNode(switch_layer_inputs);
cnode_inputs.emplace_back(new_switch_layer);


+ 1
- 2
mindspore/ccsrc/backend/session/session_basic.h View File

@@ -155,8 +155,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
void CreateCallNodeReturnFunction(const CNodePtr &cnode, KernelGraph *graph,
const std::vector<AnfNodePtr> &real_inputs);
void ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, const std::vector<AnfNodePtr> &real_inputs);

protected:
friend class Executor;


Loading…
Cancel
Save