|
|
|
@@ -35,9 +35,6 @@ namespace irpass { |
|
|
|
// {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}} |
|
|
|
class MergeAddN : public AnfVisitor { |
|
|
|
public: |
|
|
|
MergeAddN() : PrimAddN_(prim::GetPythonOps("AddN", "mindspore.ops.operations")) {} |
|
|
|
~MergeAddN() override = default; |
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { |
|
|
|
Reset(); |
|
|
|
optimizer_ = optimizer; |
|
|
|
@@ -47,15 +44,15 @@ class MergeAddN : public AnfVisitor { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto fg = node->func_graph(); |
|
|
|
// {PrimAddNClass} |
|
|
|
auto addn_node = fg->NewCNode({NewValueNode(PrimAddN_)}); |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
auto addn = NewValueNode(GetValueNode(cnode->input(0))); |
|
|
|
|
|
|
|
// {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs} |
|
|
|
(void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple)); |
|
|
|
auto fg = node->func_graph(); |
|
|
|
auto make_node = fg->NewCNode(args_); |
|
|
|
|
|
|
|
return fg->NewCNode({addn_node, make_node}); |
|
|
|
return fg->NewCNode({addn, make_node}); |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const CNodePtr &cnode) override { |
|
|
|
@@ -127,7 +124,6 @@ class MergeAddN : public AnfVisitor { |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
ValuePtr PrimAddN_; |
|
|
|
OptimizerPtr optimizer_{nullptr}; |
|
|
|
std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{}; |
|
|
|
bool is_inner_{false}, is_outer_{false}, is_match_{false}; |
|
|
|
@@ -136,9 +132,6 @@ class MergeAddN : public AnfVisitor { |
|
|
|
// {PrimAddN, {kPrimMakeTuple, Xs}} |
|
|
|
class AddNZeroFilter : public AnfVisitor { |
|
|
|
public: |
|
|
|
AddNZeroFilter() : PrimAddN_(prim::GetPythonOps("AddN", "mindspore.ops.operations")) {} |
|
|
|
~AddNZeroFilter() override = default; |
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
Reset(); |
|
|
|
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); |
|
|
|
@@ -161,8 +154,9 @@ class AddNZeroFilter : public AnfVisitor { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
auto addn = NewValueNode(GetValueNode(cnode->input(0))); |
|
|
|
auto fg = node->func_graph(); |
|
|
|
auto addn = fg->NewCNode({NewValueNode(PrimAddN_)}); |
|
|
|
auto make_tuple = fg->NewCNode(filtered_Xs_); |
|
|
|
return fg->NewCNode({addn, make_tuple}); |
|
|
|
} |
|
|
|
@@ -193,7 +187,6 @@ class AddNZeroFilter : public AnfVisitor { |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
ValuePtr PrimAddN_; |
|
|
|
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{}; |
|
|
|
bool has_zero_like_{false}; |
|
|
|
}; |
|
|
|
|