|
|
|
@@ -64,11 +64,13 @@ AnfNodePtr SimplifyAdd(const AnfNodePtr &node) { |
|
|
|
// A + 0 = A |
|
|
|
MATCH_REPLACE(node, x + zero_num, x); |
|
|
|
// A*C + B*C = (A + B)*C |
|
|
|
MATCH_REPLACE_LAMBDA(node, (x * any_const) + (y * any_const), add_distri_lambda); |
|
|
|
MATCH_REPLACE_LAMBDA_IF(node, (x * any_const) + (y * any_const_2), add_distri_lambda, |
|
|
|
PIsEqual<AnfNodePtr>()(any_const.GetNode(node), any_const_2.GetNode(node))); |
|
|
|
// (A + C1) + C2 = A + (C1 + C2) |
|
|
|
MATCH_REPLACE_LAMBDA(node, (x + any_const) + any_const_2, add_union_lambda); |
|
|
|
// A + (-A) = 0 |
|
|
|
MATCH_REPLACE(node, x + PUnaryOperation(prim::kPrimNeg, x), zero_scalar.NewValue()); |
|
|
|
MATCH_REPLACE_IF(node, x + PUnaryOperation(prim::kPrimNeg, y), zero_scalar.NewValue(), |
|
|
|
PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node))); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
|