diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc index 58e1301c52..ef06307ce4 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc @@ -64,11 +64,13 @@ AnfNodePtr SimplifyAdd(const AnfNodePtr &node) { // A + 0 = A MATCH_REPLACE(node, x + zero_num, x); // A*C + B*C = (A + B)*C - MATCH_REPLACE_LAMBDA(node, (x * any_const) + (y * any_const), add_distri_lambda); + MATCH_REPLACE_LAMBDA_IF(node, (x * any_const) + (y * any_const_2), add_distri_lambda, + PIsEqual()(any_const.GetNode(node), any_const_2.GetNode(node))); // (A + C1) + C2 = A + (C1 + C2) MATCH_REPLACE_LAMBDA(node, (x + any_const) + any_const_2, add_union_lambda); // A + (-A) = 0 - MATCH_REPLACE(node, x + PUnaryOperation(prim::kPrimNeg, x), zero_scalar.NewValue()); + MATCH_REPLACE_IF(node, x + PUnaryOperation(prim::kPrimNeg, y), zero_scalar.NewValue(), + PIsEqual()(x.GetNode(node), y.GetNode(node))); return nullptr; }