| @@ -20,9 +20,6 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| namespace irpass { | namespace irpass { | ||||
| AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | ||||
| if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { | |||||
| return nullptr; | |||||
| } | |||||
| PatternNode x, y, z, xs; | PatternNode x, y, z, xs; | ||||
| PConstant one_(node, false, 1); | PConstant one_(node, false, 1); | ||||
| PConstant one_scalar_(node, false, 1, true); | PConstant one_scalar_(node, false, 1, true); | ||||
| @@ -32,14 +29,16 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr | |||||
| PConstant const_2(node); | PConstant const_2(node); | ||||
| PConstant any_const(node); | PConstant any_const(node); | ||||
| MATCH_REPLACE(node, x + zero_, x); // Add by zero | |||||
| MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero | |||||
| MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero | |||||
| MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node)); // Multiply by one | |||||
| MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x); // Scalar Mul by one | |||||
| if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { | |||||
| MATCH_REPLACE(node, x + zero_, x); // Add by zero | |||||
| MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero | |||||
| MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero | |||||
| MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node)); // Multiply by one | |||||
| MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x); // Scalar Mul by one | |||||
| // Scalar Mul by zero | |||||
| MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue()); | |||||
| // Scalar Mul by zero | |||||
| MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue()); | |||||
| } | |||||
| // Prim Eliminate (identity) | // Prim Eliminate (identity) | ||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x); | MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x); | ||||