From 038bd6cbf14996d16b0db189868c74ade33cdfad Mon Sep 17 00:00:00 2001 From: lvliang Date: Sat, 8 Aug 2020 14:11:27 +0800 Subject: [PATCH] enable-the-identity-op-can-be-eliminated-in-pynative-mode --- .../optimizer/irpass/arithmetic_simplify.cc | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc index 7e2c989f49..ecdc44d25a 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc @@ -20,9 +20,6 @@ namespace mindspore { namespace opt { namespace irpass { AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { - return nullptr; - } PatternNode x, y, z, xs; PConstant one_(node, false, 1); PConstant one_scalar_(node, false, 1, true); @@ -32,14 +29,16 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr PConstant const_2(node); PConstant any_const(node); - MATCH_REPLACE(node, x + zero_, x); // Add by zero - MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero - MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero - MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node)); // Multiply by one - MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x); // Scalar Mul by one + if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { + MATCH_REPLACE(node, x + zero_, x); // Add by zero + MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero + MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero + MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node)); // Multiply by one + MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x); // Scalar Mul by one - // Scalar Mul by zero - MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue()); + // Scalar Mul by zero + MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue()); + } // Prim Eliminate (identity) MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);