|
|
|
@@ -98,12 +98,19 @@ class ConvertSwitchReplacement : public OptimizerCaller { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto cnode_ = node->cast<CNodePtr>(); |
|
|
|
if (cnode_->size() < 1) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto node_ = cnode_->input(0); |
|
|
|
|
|
|
|
PatternNode<AnfNodePtr> cond, true_br, false_br; |
|
|
|
|
|
|
|
auto ConvertSwitchLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { |
|
|
|
auto g1_ = GetValueNode<FuncGraphPtr>(true_br.GetNode(node)); |
|
|
|
auto g2_ = GetValueNode<FuncGraphPtr>(false_br.GetNode(node)); |
|
|
|
auto x_ = cond.GetNode(node); |
|
|
|
auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr { |
|
|
|
auto g1_ = GetValueNode<FuncGraphPtr>(true_br.GetNode(node_)); |
|
|
|
auto g2_ = GetValueNode<FuncGraphPtr>(false_br.GetNode(node_)); |
|
|
|
auto x_ = cond.GetNode(node_); |
|
|
|
|
|
|
|
// for switch replace method, only graphs without graph inside can be replaced |
|
|
|
for (auto &item : g1_->value_nodes()) { |
|
|
|
@@ -126,7 +133,7 @@ class ConvertSwitchReplacement : public OptimizerCaller { |
|
|
|
auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> params; |
|
|
|
auto fg = node->func_graph(); |
|
|
|
auto fg = node_->func_graph(); |
|
|
|
auto cloned_g1 = InlineClone(trans_g1, fg, params); |
|
|
|
auto cloned_g2 = InlineClone(trans_g2, fg, params); |
|
|
|
auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); |
|
|
|
@@ -135,8 +142,8 @@ class ConvertSwitchReplacement : public OptimizerCaller { |
|
|
|
}; |
|
|
|
|
|
|
|
MATCH_REPLACE_LAMBDA_IF( |
|
|
|
node, PCNode(PPrimitive(prim::kPrimSwitch, cond, true_br, false_br)).MinExtraNodes(0), ConvertSwitchLambda, |
|
|
|
true_br.CheckFunc(IsValueNode<FuncGraph>, node) && false_br.CheckFunc(IsValueNode<FuncGraph>, node)); |
|
|
|
node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda, |
|
|
|
true_br.CheckFunc(IsValueNode<FuncGraph>, node_) && false_br.CheckFunc(IsValueNode<FuncGraph>, node_)); |
|
|
|
|
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|