diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index cc5ef1bfe1..5daf080492 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -70,6 +70,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape); check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop); reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode); + depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend); // Env Item Eliminate env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem); diff --git a/mindspore/ccsrc/optimizer/irpass.h b/mindspore/ccsrc/optimizer/irpass.h index a388ccb7c8..ac0c6eda6f 100644 --- a/mindspore/ccsrc/optimizer/irpass.h +++ b/mindspore/ccsrc/optimizer/irpass.h @@ -48,6 +48,7 @@ class OptimizeIRPassLib { SubstitutionPtr same_eliminate_; SubstitutionPtr check_bprop_eliminate_; SubstitutionPtr reset_defer_inline_; + SubstitutionPtr depend_value_elim_; // Env Item Eliminate SubstitutionPtr env_get_item_eliminate_; diff --git a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h index cfefed40c7..ed4ac24148 100644 --- a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h @@ -24,9 +24,11 @@ #include "optimizer/optimizer.h" #include "optimizer/irpass.h" +#include "ir/optimizer_caller.h" #include "optimizer/irpass/prim_eliminate.h" #include "ir/visitor.h" #include "operator/ops.h" +#include "ir/pattern_matcher.h" namespace mindspore { namespace opt { @@ -191,6 +193,17 @@ class ZeroLikeFillZero : public AnfVisitor { AnfNodePtr y_{nullptr}; PrimitivePtr PrimFill_, PrimShape_, PrimDType_; }; + +// {prim::kPrimDepend, X, ValueCond}->X +class DependValueElim : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode x, cond; + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimDepend, x, cond), x, IsVNode(cond.GetNode(node))); + return nullptr; + } +}; + } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc index e6c1f95b7d..0ffaebac4c 100644 --- a/mindspore/ccsrc/pipeline/pass.cc +++ b/mindspore/ccsrc/pipeline/pass.cc @@ -108,6 +108,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.new_env_get_item_, + irpass.depend_value_elim_, }); opt::OptPassConfig a_3 = opt::OptPassConfig({ irpass.same_eliminate_, diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index 3febf049a6..037bcd75d1 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -257,6 +257,14 @@ TEST_F(TestOptLib, test_elim_transpose) { ASSERT_TRUE(CheckOpt(before, after, patterns)); } +TEST_F(TestOptLib, test_elim_depend_value) { + FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_depend_value", "before"); + FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_depend_value", "after"); + + auto patterns = std::vector({irpass.depend_value_elim_}); + ASSERT_TRUE(CheckOpt(before, after, patterns)); +} + TEST_F(TestOptLib, test_elim_tile_multiply_one) { FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "before"); FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "after"); diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py index a0102014d8..af8cab902c 100644 --- a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py +++ b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py @@ -494,6 +494,21 @@ def test_elim_transpose(tag): return fns[tag] +def test_elim_depend_value(tag): + """ test_elim_depend_value """ + fns = FnDict() + depend = P.Depend() + + @fns + def before(x): + return depend(x, None) + + @fns + def after(x): + return x + + return fns[tag] + def test_elim_tile_multiply_one(tag): """ test_elim_tile_multiply_one """