diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc index 511bdbccaf..83df87c285 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc @@ -258,6 +258,11 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) { auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), new_lhs, new_rhs}, node); return new_cnode; }; + auto const_dup_lambda2 = [&node, &x, &const_1, &const_2]() -> AnfNodePtr { + auto new_rhs = const_1.MulByPatternConst(const_2, x.GetNode(node)); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_rhs}, node); + return new_cnode; + }; auto exp_merge_lambda = [&node, &x, &y]() -> AnfNodePtr { auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), y.GetNode(node)}, node); auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node); @@ -283,6 +288,8 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) { }; // (x*C1)*(y*C2) ==> (x*y)*(C1*C2) MATCH_REPLACE_LAMBDA(node, (const_1 * x) * (const_2 * y), const_dup_lambda); + // (x*C1)*C2 ==> x*(C1*C2) + MATCH_REPLACE_LAMBDA(node, (const_1 * x) * const_2, const_dup_lambda2); // exp(x)*exp(y) ==> exp(x+y) MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) * PUnaryOperation(prim::kPrimExp, y), exp_merge_lambda); // sqrt(x)*sqrt(x) ==> x