Browse Source

!2333 remove addnclass in opt

Merge pull request !2333 from xychow/remove-addnclass-in-opt
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
2711a628fb
2 changed files with 8 additions and 17 deletions
  1. +6
    -13
      mindspore/ccsrc/optimizer/irpass/merge_addn.h
  2. +2
    -4
      tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py

+ 6
- 13
mindspore/ccsrc/optimizer/irpass/merge_addn.h View File

@@ -35,9 +35,6 @@ namespace irpass {
// {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}}
class MergeAddN : public AnfVisitor {
public:
MergeAddN() : PrimAddN_(prim::GetPythonOps("AddN", "mindspore.ops.operations")) {}
~MergeAddN() override = default;

AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
Reset();
optimizer_ = optimizer;
@@ -47,15 +44,15 @@ class MergeAddN : public AnfVisitor {
return nullptr;
}

auto fg = node->func_graph();
// {PrimAddNClass}
auto addn_node = fg->NewCNode({NewValueNode(PrimAddN_)});
auto cnode = node->cast<CNodePtr>();
auto addn = NewValueNode(GetValueNode(cnode->input(0)));

// {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs}
(void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple));
auto fg = node->func_graph();
auto make_node = fg->NewCNode(args_);

return fg->NewCNode({addn_node, make_node});
return fg->NewCNode({addn, make_node});
}

void Visit(const CNodePtr &cnode) override {
@@ -127,7 +124,6 @@ class MergeAddN : public AnfVisitor {
}

private:
ValuePtr PrimAddN_;
OptimizerPtr optimizer_{nullptr};
std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{};
bool is_inner_{false}, is_outer_{false}, is_match_{false};
@@ -136,9 +132,6 @@ class MergeAddN : public AnfVisitor {
// {PrimAddN, {kPrimMakeTuple, Xs}}
class AddNZeroFilter : public AnfVisitor {
public:
AddNZeroFilter() : PrimAddN_(prim::GetPythonOps("AddN", "mindspore.ops.operations")) {}
~AddNZeroFilter() override = default;

AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
@@ -161,8 +154,9 @@ class AddNZeroFilter : public AnfVisitor {
return nullptr;
}

auto cnode = node->cast<CNodePtr>();
auto addn = NewValueNode(GetValueNode(cnode->input(0)));
auto fg = node->func_graph();
auto addn = fg->NewCNode({NewValueNode(PrimAddN_)});
auto make_tuple = fg->NewCNode(filtered_Xs_);
return fg->NewCNode({addn, make_tuple});
}
@@ -193,7 +187,6 @@ class AddNZeroFilter : public AnfVisitor {
}

private:
ValuePtr PrimAddN_;
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
bool has_zero_like_{false};
};


+ 2
- 4
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py View File

@@ -875,7 +875,6 @@ def test_merge_addn(tag):
""" test_merge_addn """
fns = FnDict()
addn = P.AddN()
AddN = P.AddN

@fns
def before(x, y, z, a):
@@ -883,7 +882,7 @@ def test_merge_addn(tag):

@fns
def after(x, y, z, a):
return AddN()((a, x, y, z))
return addn((a, x, y, z))

return fns[tag]

@@ -892,7 +891,6 @@ def test_addn_zero(tag):
""" test_addn_zero """
fns = FnDict()
addn = P.AddN()
AddN = P.AddN
zero_tensor = Primitive('ZerosLike')

@fns
@@ -901,7 +899,7 @@ def test_addn_zero(tag):

@fns
def after(x, y, z, a):
return AddN()((a, z))
return addn((a, z))

@fns
def before_2(x, y, z, a):


Loading…
Cancel
Save