|
|
|
@@ -606,6 +606,10 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr & |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { |
|
|
|
auto switch_cnode = cnode_input->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(switch_cnode); |
|
|
|
if (cnode->inputs().size() < 2) { |
|
|
|
cnode_inputs = switch_cnode->inputs(); |
|
|
|
return cnode_inputs; |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex), |
|
|
|
switch_cnode->input(kFirstDataInputIndex)}; |
|
|
|
for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) { |
|
|
|
@@ -630,7 +634,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr & |
|
|
|
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { |
|
|
|
std::vector<AnfNodePtr> SessionBasic::CreateValueNode(const CNodePtr &cnode, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
std::vector<AnfNodePtr> cnode_inputs; |
|
|
|
@@ -641,7 +645,7 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(fg); |
|
|
|
auto new_fg = BasicClone(fg); |
|
|
|
cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg)); |
|
|
|
} else if (IsValueNode<FuncGraph>(attr_input)) { |
|
|
|
} else { |
|
|
|
// create primitive of cnode:call |
|
|
|
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))}; |
|
|
|
// create a ValueNode<KernelGraph> as input of cnode:call |
|
|
|
@@ -653,38 +657,27 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { |
|
|
|
cnode_inputs.emplace_back(new_value_node); |
|
|
|
} |
|
|
|
} |
|
|
|
} else if (attr_input->isa<CNode>()) { |
|
|
|
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); |
|
|
|
if (cnode->inputs().size() < 2 && AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { |
|
|
|
auto switch_cnode = cnode_input->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(switch_cnode); |
|
|
|
cnode_inputs = switch_cnode->inputs(); |
|
|
|
} else { |
|
|
|
cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); |
|
|
|
} |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { |
|
|
|
cnode_inputs = {graph->GetBackendAnfByFrontAnf(cnode->input(kAnfPrimitiveIndex)), |
|
|
|
graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))}; |
|
|
|
} |
|
|
|
return cnode_inputs; |
|
|
|
} |
|
|
|
|
|
|
|
void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { |
|
|
|
cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); |
|
|
|
for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) { |
|
|
|
auto node_input = cnode->input(index); |
|
|
|
auto switch_input = CreateSwitchInput(node_input, graph); |
|
|
|
cnode_inputs.emplace_back(switch_input); |
|
|
|
cnode_inputs->emplace_back(switch_input); |
|
|
|
} |
|
|
|
} else { |
|
|
|
// get primitive of old node |
|
|
|
auto prim = AnfAlgo::GetCNodePrimitive(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
// push attr to inputs[0] of new cnode |
|
|
|
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))}; |
|
|
|
} |
|
|
|
|
|
|
|
if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { |
|
|
|
for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->inputs().size(); input_idx++) { |
|
|
|
auto anf = cnode->input(input_idx); |
|
|
|
MS_EXCEPTION_IF_NULL(anf); |
|
|
|
// anf has been created before |
|
|
|
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { |
|
|
|
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); |
|
|
|
cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf)); |
|
|
|
continue; |
|
|
|
} else if (IsValueNode<None>(anf)) { |
|
|
|
continue; |
|
|
|
@@ -692,6 +685,32 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { |
|
|
|
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
std::vector<AnfNodePtr> cnode_inputs; |
|
|
|
auto attr_input = cnode->input(kAnfPrimitiveIndex); |
|
|
|
MS_EXCEPTION_IF_NULL(attr_input); |
|
|
|
if (IsValueNode<FuncGraph>(attr_input)) { |
|
|
|
// cnode is a graph or a call |
|
|
|
cnode_inputs = CreateValueNode(cnode, graph); |
|
|
|
} else if (attr_input->isa<CNode>()) { |
|
|
|
// cnode ia a call (partial/switch/switch_layer) |
|
|
|
// 1. take the args of call to the partial node, as the real_args to call switch's or switch_layer's child graph |
|
|
|
// 2. the call in frontend is map to the partial/switch/switch_layer in backend and haven't been created |
|
|
|
cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); |
|
|
|
} else { |
|
|
|
// get primitive of old node |
|
|
|
auto prim = AnfAlgo::GetCNodePrimitive(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
// push attr to inputs[0] of new cnode |
|
|
|
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))}; |
|
|
|
} |
|
|
|
// handle inputs of cnode except primitive |
|
|
|
CreateCNodeInputs(cnode, graph, &cnode_inputs); |
|
|
|
|
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info())); |
|
|
|
auto new_cnode = graph->NewCNode(cnode_inputs); |
|
|
|
TraceManager::EndTrace(); |
|
|
|
|