Browse Source

!7750 Add a simplification pattern to GraphKernel's ArithSimplify.

Merge pull request !7750 from DeshiChen/1026_simplify_mul
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ac3a82006c
1 changed files with 7 additions and 0 deletions
  1. +7
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc

+ 7
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc View File

@@ -258,6 +258,11 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) {
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), new_lhs, new_rhs}, node);
return new_cnode;
};
auto const_dup_lambda2 = [&node, &x, &const_1, &const_2]() -> AnfNodePtr {
auto new_rhs = const_1.MulByPatternConst(const_2, x.GetNode(node));
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_rhs}, node);
return new_cnode;
};
auto exp_merge_lambda = [&node, &x, &y]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node);
@@ -283,6 +288,8 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) {
};
// (x*C1)*(y*C2) ==> (x*y)*(C1*C2)
MATCH_REPLACE_LAMBDA(node, (const_1 * x) * (const_2 * y), const_dup_lambda);
// (x*C1)*C2 ==> x*(C1*C2)
MATCH_REPLACE_LAMBDA(node, (const_1 * x) * const_2, const_dup_lambda2);
// exp(x)*exp(y) ==> exp(x+y)
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) * PUnaryOperation(prim::kPrimExp, y), exp_merge_lambda);
// sqrt(x)*sqrt(x) ==> x


Loading…
Cancel
Save