|
|
@@ -497,7 +497,50 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K |
|
|
return new_cnode; |
|
|
return new_cnode; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
static std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { |
|
|
|
|
|
|
|
|
CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node_input); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
|
|
// switch input generalizes partial |
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial) || |
|
|
|
|
|
AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimCall)) { |
|
|
|
|
|
return node_input->cast<CNodePtr>(); |
|
|
|
|
|
} |
|
|
|
|
|
if (node_input->isa<CNode>()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "If switch input is " << node_input->DebugString() << ", it mast be partial or call."; |
|
|
|
|
|
} |
|
|
|
|
|
std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))}; |
|
|
|
|
|
if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) { |
|
|
|
|
|
partial_inputs.emplace_back(node_input); |
|
|
|
|
|
auto partial_node = graph->NewCNode(partial_inputs); |
|
|
|
|
|
return partial_node; |
|
|
|
|
|
} |
|
|
|
|
|
KernelGraphPtr kernel_graph = NewKernelGraph(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
|
|
kernel_graph->set_output(graph->GetBackendAnfByFrontAnf(node_input)); |
|
|
|
|
|
partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph)); |
|
|
|
|
|
auto partial_node = graph->NewCNode(partial_inputs); |
|
|
|
|
|
return partial_node; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
CNodePtr SessionBasic::HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
|
|
auto node = anf_node->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
|
|
if (node->inputs().size() < kSwitchInputSize) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Switch input size less than " << kSwitchInputSize; |
|
|
|
|
|
} |
|
|
|
|
|
auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimSwitch->name())); |
|
|
|
|
|
std::vector<AnfNodePtr> switch_inputs = {primitive, node->input(1)}; |
|
|
|
|
|
for (size_t index = 2; index < node->inputs().size(); index++) { |
|
|
|
|
|
auto input = CreateSwitchInput(node->input(index), graph); |
|
|
|
|
|
switch_inputs.emplace_back(input); |
|
|
|
|
|
} |
|
|
|
|
|
auto switch_node = graph->NewCNode(switch_inputs); |
|
|
|
|
|
return switch_node; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
// create primitive of cnode:call(partial or switch) |
|
|
// create primitive of cnode:call(partial or switch) |
|
|
@@ -522,7 +565,8 @@ static std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, |
|
|
}); |
|
|
}); |
|
|
return cnode_inputs; |
|
|
return cnode_inputs; |
|
|
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { |
|
|
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { |
|
|
cnode_inputs.emplace_back(cnode_input); |
|
|
|
|
|
|
|
|
auto switch_node = HandleSwitchInputs(cnode_input, graph); |
|
|
|
|
|
cnode_inputs.emplace_back(switch_node); |
|
|
return cnode_inputs; |
|
|
return cnode_inputs; |
|
|
} |
|
|
} |
|
|
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; |
|
|
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; |
|
|
|