| @@ -70,6 +70,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape); | same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape); | ||||
| check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop); | check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop); | ||||
| reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>); | reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>); | ||||
| depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend); | |||||
| // Env Item Eliminate | // Env Item Eliminate | ||||
| env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem); | env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem); | ||||
| @@ -48,6 +48,7 @@ class OptimizeIRPassLib { | |||||
| SubstitutionPtr same_eliminate_; | SubstitutionPtr same_eliminate_; | ||||
| SubstitutionPtr check_bprop_eliminate_; | SubstitutionPtr check_bprop_eliminate_; | ||||
| SubstitutionPtr reset_defer_inline_; | SubstitutionPtr reset_defer_inline_; | ||||
| SubstitutionPtr depend_value_elim_; | |||||
| // Env Item Eliminate | // Env Item Eliminate | ||||
| SubstitutionPtr env_get_item_eliminate_; | SubstitutionPtr env_get_item_eliminate_; | ||||
| @@ -24,9 +24,11 @@ | |||||
| #include "optimizer/optimizer.h" | #include "optimizer/optimizer.h" | ||||
| #include "optimizer/irpass.h" | #include "optimizer/irpass.h" | ||||
| #include "ir/optimizer_caller.h" | |||||
| #include "optimizer/irpass/prim_eliminate.h" | #include "optimizer/irpass/prim_eliminate.h" | ||||
| #include "ir/visitor.h" | #include "ir/visitor.h" | ||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "ir/pattern_matcher.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -191,6 +193,17 @@ class ZeroLikeFillZero : public AnfVisitor { | |||||
| AnfNodePtr y_{nullptr}; | AnfNodePtr y_{nullptr}; | ||||
| PrimitivePtr PrimFill_, PrimShape_, PrimDType_; | PrimitivePtr PrimFill_, PrimShape_, PrimDType_; | ||||
| }; | }; | ||||
| // {prim::kPrimDepend, X, ValueCond}->X | |||||
| class DependValueElim : public OptimizerCaller { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| PatternNode<AnfNodePtr> x, cond; | |||||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimDepend, x, cond), x, IsVNode(cond.GetNode(node))); | |||||
| return nullptr; | |||||
| } | |||||
| }; | |||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -108,6 +108,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| irpass.incorporate_env_getitem_, | irpass.incorporate_env_getitem_, | ||||
| irpass.incorporate_env_getitem_switch_, | irpass.incorporate_env_getitem_switch_, | ||||
| irpass.new_env_get_item_, | irpass.new_env_get_item_, | ||||
| irpass.depend_value_elim_, | |||||
| }); | }); | ||||
| opt::OptPassConfig a_3 = opt::OptPassConfig({ | opt::OptPassConfig a_3 = opt::OptPassConfig({ | ||||
| irpass.same_eliminate_, | irpass.same_eliminate_, | ||||
| @@ -257,6 +257,14 @@ TEST_F(TestOptLib, test_elim_transpose) { | |||||
| ASSERT_TRUE(CheckOpt(before, after, patterns)); | ASSERT_TRUE(CheckOpt(before, after, patterns)); | ||||
| } | } | ||||
| TEST_F(TestOptLib, test_elim_depend_value) { | |||||
| FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_depend_value", "before"); | |||||
| FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_depend_value", "after"); | |||||
| auto patterns = std::vector<SubstitutionPtr>({irpass.depend_value_elim_}); | |||||
| ASSERT_TRUE(CheckOpt(before, after, patterns)); | |||||
| } | |||||
| TEST_F(TestOptLib, test_elim_tile_multiply_one) { | TEST_F(TestOptLib, test_elim_tile_multiply_one) { | ||||
| FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "before"); | FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "before"); | ||||
| FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "after"); | FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "after"); | ||||
| @@ -494,6 +494,21 @@ def test_elim_transpose(tag): | |||||
| return fns[tag] | return fns[tag] | ||||
| def test_elim_depend_value(tag): | |||||
| """ test_elim_depend_value """ | |||||
| fns = FnDict() | |||||
| depend = P.Depend() | |||||
| @fns | |||||
| def before(x): | |||||
| return depend(x, None) | |||||
| @fns | |||||
| def after(x): | |||||
| return x | |||||
| return fns[tag] | |||||
| def test_elim_tile_multiply_one(tag): | def test_elim_tile_multiply_one(tag): | ||||
| """ test_elim_tile_multiply_one """ | """ test_elim_tile_multiply_one """ | ||||