|
|
|
@@ -934,46 +934,60 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno |
|
|
|
return cnode_inputs; |
|
|
|
} |
|
|
|
|
|
|
|
void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const std::vector<AnfNodePtr> &real_inputs) { |
|
|
|
void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, KernelGraph *graph, |
|
|
|
const std::vector<AnfNodePtr> &real_inputs) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial))) { |
|
|
|
MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a partial node."; |
|
|
|
// func1 =switch(branch1, branch2) |
|
|
|
// func2 = func1(param1) |
|
|
|
// out = func2(param2) |
|
|
|
// process the last cnode(func2), not func1 which abstract is AbstractFunction |
|
|
|
if (cnode->abstract()->isa<abstract::AbstractFunction>()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto partial_input = cnode->input(kFirstDataInputIndex); |
|
|
|
KernelGraphPtr partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input); |
|
|
|
MS_EXCEPTION_IF_NULL(partial_kernel_graph); |
|
|
|
auto ret = partial_kernel_graph->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto ret = graph->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(ret); |
|
|
|
auto return_input = ret->input(kFirstDataInputIndex); |
|
|
|
// return node is a function |
|
|
|
std::vector<AnfNodePtr> call_inputs = { |
|
|
|
partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))}; |
|
|
|
AnfNodePtr real_kernel_graph; |
|
|
|
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))}; |
|
|
|
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) { |
|
|
|
auto return_input_cnode = return_input->cast<CNodePtr>(); |
|
|
|
auto partial_inputs = return_input_cnode->inputs(); |
|
|
|
call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end()); |
|
|
|
real_kernel_graph = partial_inputs[kFirstDataInputIndex]; |
|
|
|
} else { // return node is kernel graph |
|
|
|
} else if (IsValueNode<KernelGraph>(return_input)) { // return node is kernel graph |
|
|
|
call_inputs.emplace_back(return_input); |
|
|
|
real_kernel_graph = return_input; |
|
|
|
} else { // return node is value node |
|
|
|
KernelGraphPtr kernel_graph = NewKernelGraph(); |
|
|
|
auto valid_inputs = kernel_graph->MutableValidInputs(); |
|
|
|
MS_EXCEPTION_IF_NULL(valid_inputs); |
|
|
|
auto graph_inputs = kernel_graph->MutableInputs(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_inputs); |
|
|
|
std::vector<AnfNodePtr> cnode_inputs = {return_input}; |
|
|
|
for (auto real_input : real_inputs) { |
|
|
|
auto new_parameter = kernel_graph->NewParameter(real_input->abstract()); |
|
|
|
valid_inputs->push_back(true); |
|
|
|
graph_inputs->push_back(new_parameter); |
|
|
|
cnode_inputs.push_back(new_parameter); |
|
|
|
} |
|
|
|
auto new_cnode = kernel_graph->NewCNode(cnode_inputs); |
|
|
|
new_cnode->set_abstract(cnode->abstract()); |
|
|
|
std::vector<AnfNodePtr> return_inputs = { |
|
|
|
kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()))), new_cnode}; |
|
|
|
auto return_node = kernel_graph->NewCNode(return_inputs); |
|
|
|
return_node->set_abstract(cnode->abstract()); |
|
|
|
kernel_graph->set_return(return_node); |
|
|
|
call_inputs.push_back(std::make_shared<ValueNode>(kernel_graph)); |
|
|
|
} |
|
|
|
|
|
|
|
// new call node inputs |
|
|
|
for (auto real_input : real_inputs) { |
|
|
|
auto parameter_for_input = CreateNewParameterFromCNode(real_input, partial_kernel_graph.get()); |
|
|
|
auto parameter_for_input = CreateNewParameterFromCNode(real_input, graph); |
|
|
|
call_inputs.emplace_back(parameter_for_input); |
|
|
|
} |
|
|
|
|
|
|
|
auto call_node = partial_kernel_graph->NewCNode(call_inputs); |
|
|
|
// update abstract |
|
|
|
MS_EXCEPTION_IF_NULL(real_kernel_graph); |
|
|
|
if (real_kernel_graph->isa<ValueNode>() && IsValueNode<FuncGraph>(real_kernel_graph)) { |
|
|
|
KernelGraphPtr sub_partial_kernel_graph = GetValueNode<KernelGraphPtr>(real_kernel_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(sub_partial_kernel_graph); |
|
|
|
auto ret_partial = sub_partial_kernel_graph->get_return(); |
|
|
|
call_node->set_abstract(ret_partial->abstract()); |
|
|
|
} |
|
|
|
auto call_node = graph->NewCNode(call_inputs); |
|
|
|
call_node->set_abstract(cnode->abstract()); |
|
|
|
// update return input |
|
|
|
ret->set_input(kFirstDataInputIndex, call_node); |
|
|
|
} |
|
|
|
@@ -1005,36 +1019,33 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr |
|
|
|
for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) { |
|
|
|
auto partial_idx = make_tuple_inputs[idx]; |
|
|
|
MS_EXCEPTION_IF_NULL(cnode->abstract()); |
|
|
|
std::vector<AnfNodePtr> new_partial_inputs; |
|
|
|
KernelGraphPtr partial_kernel_graph; |
|
|
|
// switch_layer node input is partial cnode |
|
|
|
if (AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) { |
|
|
|
auto partial_node = partial_idx->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(partial_node); |
|
|
|
// update kernel graph when switch_layer node return function |
|
|
|
auto partial_input = partial_node->input(kFirstDataInputIndex); |
|
|
|
KernelGraphPtr partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input); |
|
|
|
MS_EXCEPTION_IF_NULL(partial_kernel_graph); |
|
|
|
auto ret = partial_kernel_graph->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(ret); |
|
|
|
auto return_input = ret->input(kFirstDataInputIndex); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || IsValueNode<KernelGraph>(return_input)) { |
|
|
|
CreateCallNodeReturnFunction(partial_node, real_inputs); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> new_partial_inputs = partial_node->inputs(); |
|
|
|
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end()); |
|
|
|
auto new_partial = graph->NewCNode(new_partial_inputs); |
|
|
|
new_make_tuple_inputs.emplace_back(new_partial); |
|
|
|
} |
|
|
|
// switch_layer node input is kernel graph value node |
|
|
|
if (IsValueNode<KernelGraph>(partial_idx)) { |
|
|
|
// make_tuple inputs is KernelGraph |
|
|
|
std::vector<AnfNodePtr> new_partial_inputs; |
|
|
|
partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input); |
|
|
|
new_partial_inputs = partial_node->inputs(); |
|
|
|
} else if (IsValueNode<KernelGraph>(partial_idx)) { // switch_layer node input is kernel graph value node |
|
|
|
new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))); |
|
|
|
new_partial_inputs.emplace_back(partial_idx); |
|
|
|
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end()); |
|
|
|
auto new_partial = graph->NewCNode(new_partial_inputs); |
|
|
|
new_make_tuple_inputs.emplace_back(new_partial); |
|
|
|
} |
|
|
|
partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_idx); |
|
|
|
} |
|
|
|
// when branch in swich_layer return function |
|
|
|
MS_EXCEPTION_IF_NULL(partial_kernel_graph); |
|
|
|
auto ret = partial_kernel_graph->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(ret); |
|
|
|
auto return_input = ret->input(kFirstDataInputIndex); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa<ValueNode>()) { |
|
|
|
CreateCallNodeReturnFunction(cnode, partial_kernel_graph.get(), real_inputs); |
|
|
|
} |
|
|
|
// partial node add input args |
|
|
|
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end()); |
|
|
|
// create new partial node |
|
|
|
auto new_partial = graph->NewCNode(new_partial_inputs); |
|
|
|
new_make_tuple_inputs.emplace_back(new_partial); |
|
|
|
} |
|
|
|
auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs); |
|
|
|
switch_layer_inputs.emplace_back(new_make_tuple); |
|
|
|
|