Browse Source

enable-the-identity-op-can-be-eliminated-in-pynative-mode

tags/v0.7.0-beta
lvliang 5 years ago
parent
commit
038bd6cbf1
1 changed files with 9 additions and 10 deletions
  1. +9
    -10
      mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc

+ 9
- 10
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc View File

@@ -20,9 +20,6 @@ namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) { AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
return nullptr;
}
PatternNode x, y, z, xs; PatternNode x, y, z, xs;
PConstant one_(node, false, 1); PConstant one_(node, false, 1);
PConstant one_scalar_(node, false, 1, true); PConstant one_scalar_(node, false, 1, true);
@@ -32,14 +29,16 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
PConstant const_2(node); PConstant const_2(node);
PConstant any_const(node); PConstant any_const(node);


MATCH_REPLACE(node, x + zero_, x); // Add by zero
MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero
MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node)); // Multiply by one
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x); // Scalar Mul by one
if (MsContext::GetInstance()->execution_mode() != kPynativeMode) {
MATCH_REPLACE(node, x + zero_, x); // Add by zero
MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero
MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node)); // Multiply by one
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x); // Scalar Mul by one


// Scalar Mul by zero
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue());
// Scalar Mul by zero
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue());
}
// Prim Eliminate (identity) // Prim Eliminate (identity)
MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x); MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);




Loading…
Cancel
Save