| @@ -98,12 +98,19 @@ class ConvertSwitchReplacement : public OptimizerCaller { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto cnode_ = node->cast<CNodePtr>(); | |||||
| if (cnode_->size() < 1) { | |||||
| return nullptr; | |||||
| } | |||||
| auto node_ = cnode_->input(0); | |||||
| PatternNode<AnfNodePtr> cond, true_br, false_br; | PatternNode<AnfNodePtr> cond, true_br, false_br; | ||||
| auto ConvertSwitchLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { | |||||
| auto g1_ = GetValueNode<FuncGraphPtr>(true_br.GetNode(node)); | |||||
| auto g2_ = GetValueNode<FuncGraphPtr>(false_br.GetNode(node)); | |||||
| auto x_ = cond.GetNode(node); | |||||
| auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr { | |||||
| auto g1_ = GetValueNode<FuncGraphPtr>(true_br.GetNode(node_)); | |||||
| auto g2_ = GetValueNode<FuncGraphPtr>(false_br.GetNode(node_)); | |||||
| auto x_ = cond.GetNode(node_); | |||||
| // for switch replace method, only graphs without graph inside can be replaced | // for switch replace method, only graphs without graph inside can be replaced | ||||
| for (auto &item : g1_->value_nodes()) { | for (auto &item : g1_->value_nodes()) { | ||||
| @@ -126,7 +133,7 @@ class ConvertSwitchReplacement : public OptimizerCaller { | |||||
| auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); | auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); | ||||
| std::vector<AnfNodePtr> params; | std::vector<AnfNodePtr> params; | ||||
| auto fg = node->func_graph(); | |||||
| auto fg = node_->func_graph(); | |||||
| auto cloned_g1 = InlineClone(trans_g1, fg, params); | auto cloned_g1 = InlineClone(trans_g1, fg, params); | ||||
| auto cloned_g2 = InlineClone(trans_g2, fg, params); | auto cloned_g2 = InlineClone(trans_g2, fg, params); | ||||
| auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); | auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); | ||||
| @@ -135,8 +142,8 @@ class ConvertSwitchReplacement : public OptimizerCaller { | |||||
| }; | }; | ||||
| MATCH_REPLACE_LAMBDA_IF( | MATCH_REPLACE_LAMBDA_IF( | ||||
| node, PCNode(PPrimitive(prim::kPrimSwitch, cond, true_br, false_br)).MinExtraNodes(0), ConvertSwitchLambda, | |||||
| true_br.CheckFunc(IsValueNode<FuncGraph>, node) && false_br.CheckFunc(IsValueNode<FuncGraph>, node)); | |||||
| node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda, | |||||
| true_br.CheckFunc(IsValueNode<FuncGraph>, node_) && false_br.CheckFunc(IsValueNode<FuncGraph>, node_)); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||