|
|
|
@@ -928,9 +928,8 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno |
|
|
|
return cnode_inputs; |
|
|
|
} |
|
|
|
|
|
|
|
void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input) { |
|
|
|
void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const std::vector<AnfNodePtr> &real_inputs) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(real_input); |
|
|
|
if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial))) { |
|
|
|
MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a partial node."; |
|
|
|
} |
|
|
|
@@ -940,24 +939,37 @@ void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const Anf |
|
|
|
auto ret = partial_kernel_graph->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(ret); |
|
|
|
auto return_input = ret->input(kFirstDataInputIndex); |
|
|
|
// if kernel graph return node is a function |
|
|
|
// return node is a function |
|
|
|
std::vector<AnfNodePtr> call_inputs = { |
|
|
|
partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))}; |
|
|
|
AnfNodePtr real_kernel_graph; |
|
|
|
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) { |
|
|
|
std::vector<AnfNodePtr> call_inputs = { |
|
|
|
partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))}; |
|
|
|
auto return_input_cnode = return_input->cast<CNodePtr>(); |
|
|
|
|
|
|
|
auto partial_inputs = return_input_cnode->inputs(); |
|
|
|
call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end()); |
|
|
|
real_kernel_graph = partial_inputs[kFirstDataInputIndex]; |
|
|
|
} else { // return node is kernel graph |
|
|
|
call_inputs.emplace_back(return_input); |
|
|
|
real_kernel_graph = return_input; |
|
|
|
} |
|
|
|
|
|
|
|
// new call node inputs |
|
|
|
for (auto real_input : real_inputs) { |
|
|
|
auto parameter_for_input = CreateNewParameterFromCNode(real_input, partial_kernel_graph.get()); |
|
|
|
call_inputs.emplace_back(parameter_for_input); |
|
|
|
auto call_node = partial_kernel_graph->NewCNode(call_inputs); |
|
|
|
// update abstract |
|
|
|
KernelGraphPtr sub_partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_inputs[kFirstDataInputIndex]); |
|
|
|
} |
|
|
|
|
|
|
|
auto call_node = partial_kernel_graph->NewCNode(call_inputs); |
|
|
|
// update abstract |
|
|
|
MS_EXCEPTION_IF_NULL(real_kernel_graph); |
|
|
|
if (real_kernel_graph->isa<ValueNode>() && IsValueNode<FuncGraph>(real_kernel_graph)) { |
|
|
|
KernelGraphPtr sub_partial_kernel_graph = GetValueNode<KernelGraphPtr>(real_kernel_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(sub_partial_kernel_graph); |
|
|
|
auto ret_partial = sub_partial_kernel_graph->get_return(); |
|
|
|
call_node->set_abstract(ret_partial->abstract()); |
|
|
|
// update return input |
|
|
|
ret->set_input(kFirstDataInputIndex, call_node); |
|
|
|
} |
|
|
|
// update return input |
|
|
|
ret->set_input(kFirstDataInputIndex, call_node); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph) { |
|
|
|
@@ -977,9 +989,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr |
|
|
|
auto node = make_tuple_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto make_tuple_inputs = node->inputs(); |
|
|
|
// there is real input in call, should put it to make_tuple in switch_layer |
|
|
|
auto real_input = cnode->input(kFirstDataInputIndex); |
|
|
|
auto real_input_back = graph->GetBackendAnfByFrontAnf(real_input); |
|
|
|
// there are real inputs in call, should put it to make_tuple in switch_layer |
|
|
|
std::vector<AnfNodePtr> real_inputs; |
|
|
|
for (size_t idx = kFirstDataInputIndex; idx < cnode->inputs().size(); ++idx) { |
|
|
|
real_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(idx))); |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> new_make_tuple_inputs = { |
|
|
|
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())))}; |
|
|
|
for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) { |
|
|
|
@@ -990,10 +1004,18 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr |
|
|
|
auto partial_node = partial_idx->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(partial_node); |
|
|
|
// update kernel graph when switch_layer node return function |
|
|
|
CreateCallNodeReturnFunction(partial_node, real_input_back); |
|
|
|
auto partial_input = partial_node->input(kFirstDataInputIndex); |
|
|
|
KernelGraphPtr partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input); |
|
|
|
MS_EXCEPTION_IF_NULL(partial_kernel_graph); |
|
|
|
auto ret = partial_kernel_graph->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(ret); |
|
|
|
auto return_input = ret->input(kFirstDataInputIndex); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || IsValueNode<KernelGraph>(return_input)) { |
|
|
|
CreateCallNodeReturnFunction(partial_node, real_inputs); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> new_partial_inputs = partial_node->inputs(); |
|
|
|
new_partial_inputs.emplace_back(real_input_back); |
|
|
|
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end()); |
|
|
|
auto new_partial = graph->NewCNode(new_partial_inputs); |
|
|
|
new_make_tuple_inputs.emplace_back(new_partial); |
|
|
|
} |
|
|
|
@@ -1003,7 +1025,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr |
|
|
|
std::vector<AnfNodePtr> new_partial_inputs; |
|
|
|
new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))); |
|
|
|
new_partial_inputs.emplace_back(partial_idx); |
|
|
|
new_partial_inputs.emplace_back(real_input_back); |
|
|
|
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end()); |
|
|
|
auto new_partial = graph->NewCNode(new_partial_inputs); |
|
|
|
new_make_tuple_inputs.emplace_back(new_partial); |
|
|
|
} |
|
|
|
|