Browse Source

add pass to eliminate depend value

tags/v0.5.0-beta
BowenK 5 years ago
parent
commit
8f29e7242f
6 changed files with 39 additions and 0 deletions
  1. +1
    -0
      mindspore/ccsrc/optimizer/irpass.cc
  2. +1
    -0
      mindspore/ccsrc/optimizer/irpass.h
  3. +13
    -0
      mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
  4. +1
    -0
      mindspore/ccsrc/pipeline/pass.cc
  5. +8
    -0
      tests/ut/cpp/optimizer/lib_test.cc
  6. +15
    -0
      tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py

+ 1
- 0
mindspore/ccsrc/optimizer/irpass.cc View File

@@ -70,6 +70,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape);
check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>);
depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend);

// Env Item Eliminate
env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem);


+ 1
- 0
mindspore/ccsrc/optimizer/irpass.h View File

@@ -48,6 +48,7 @@ class OptimizeIRPassLib {
SubstitutionPtr same_eliminate_;
SubstitutionPtr check_bprop_eliminate_;
SubstitutionPtr reset_defer_inline_;
SubstitutionPtr depend_value_elim_;

// Env Item Eliminate
SubstitutionPtr env_get_item_eliminate_;


+ 13
- 0
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h View File

@@ -24,9 +24,11 @@

#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/optimizer_caller.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "ir/pattern_matcher.h"

namespace mindspore {
namespace opt {
@@ -191,6 +193,17 @@ class ZeroLikeFillZero : public AnfVisitor {
AnfNodePtr y_{nullptr};
PrimitivePtr PrimFill_, PrimShape_, PrimDType_;
};

// {prim::kPrimDepend, X, ValueCond}->X
class DependValueElim : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> x, cond;
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimDepend, x, cond), x, IsVNode(cond.GetNode(node)));
return nullptr;
}
};

} // namespace irpass
} // namespace opt
} // namespace mindspore


+ 1
- 0
mindspore/ccsrc/pipeline/pass.cc View File

@@ -108,6 +108,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.incorporate_env_getitem_,
irpass.incorporate_env_getitem_switch_,
irpass.new_env_get_item_,
irpass.depend_value_elim_,
});
opt::OptPassConfig a_3 = opt::OptPassConfig({
irpass.same_eliminate_,


+ 8
- 0
tests/ut/cpp/optimizer/lib_test.cc View File

@@ -257,6 +257,14 @@ TEST_F(TestOptLib, test_elim_transpose) {
ASSERT_TRUE(CheckOpt(before, after, patterns));
}

TEST_F(TestOptLib, test_elim_depend_value) {
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_depend_value", "before");
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_depend_value", "after");

auto patterns = std::vector<SubstitutionPtr>({irpass.depend_value_elim_});
ASSERT_TRUE(CheckOpt(before, after, patterns));
}

TEST_F(TestOptLib, test_elim_tile_multiply_one) {
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "before");
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "after");


+ 15
- 0
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py View File

@@ -494,6 +494,21 @@ def test_elim_transpose(tag):

return fns[tag]

def test_elim_depend_value(tag):
""" test_elim_depend_value """
fns = FnDict()
depend = P.Depend()

@fns
def before(x):
return depend(x, None)

@fns
def after(x):
return x

return fns[tag]


def test_elim_tile_multiply_one(tag):
""" test_elim_tile_multiply_one """


Loading…
Cancel
Save