|
|
|
@@ -892,8 +892,8 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno |
|
|
|
return cnode_inputs; |
|
|
|
} |
|
|
|
|
|
|
|
void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, KernelGraph *graph, |
|
|
|
const std::vector<AnfNodePtr> &real_inputs) { |
|
|
|
void SessionBasic::ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, |
|
|
|
const std::vector<AnfNodePtr> &real_inputs) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
// func1 =switch(branch1, branch2) |
|
|
|
// func2 = func1(param1) |
|
|
|
@@ -997,7 +997,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr |
|
|
|
MS_EXCEPTION_IF_NULL(ret); |
|
|
|
auto return_input = ret->input(kFirstDataInputIndex); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa<ValueNode>()) { |
|
|
|
CreateCallNodeReturnFunction(cnode, partial_kernel_graph.get(), real_inputs); |
|
|
|
ProcessNodeRetFunc(cnode, partial_kernel_graph.get(), real_inputs); |
|
|
|
} |
|
|
|
// partial node add input args |
|
|
|
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end()); |
|
|
|
@@ -1006,7 +1006,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr |
|
|
|
new_make_tuple_inputs.emplace_back(new_partial); |
|
|
|
} |
|
|
|
auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs); |
|
|
|
new_make_tuple->set_abstract(make_tuple_node->abstract()); |
|
|
|
auto abstract = make_tuple_node->abstract(); |
|
|
|
if (abstract == nullptr) { |
|
|
|
abstract = std::make_shared<abstract::AbstractTuple>(AbstractBasePtrList()); |
|
|
|
} |
|
|
|
new_make_tuple->set_abstract(abstract); |
|
|
|
switch_layer_inputs.emplace_back(new_make_tuple); |
|
|
|
auto new_switch_layer = graph->NewCNode(switch_layer_inputs); |
|
|
|
cnode_inputs.emplace_back(new_switch_layer); |
|
|
|
|