Merge pull request !4248 from Giancarlo/pm_update_sparsetensortags/v0.7.0-beta
| @@ -20,10 +20,10 @@ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "frontend/operator/ops.h" | |||
| #include "frontend/optimizer/anf_visitor.h" | |||
| #include "frontend/optimizer/irpass.h" | |||
| #include "frontend/optimizer/optimizer.h" | |||
| #include "frontend/optimizer/anf_visitor.h" | |||
| #include "frontend/operator/ops.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -31,43 +31,16 @@ namespace irpass { | |||
| // {prim::kPrimRowTensorGetIndices, {prim::kPrimMakeRowTensor, Xs}} | |||
| // {prim::kPrimRowTensorGetValues, {prim::kPrimMakeRowTensor, Xs}} | |||
| // {prim::kPrimRowTensorGetDenseShape, {prim::kPrimMakeRowTensor, Xs}} | |||
| class RowTensorEliminater : public AnfVisitor { | |||
| class RowTensorEliminater : public OptimizerCaller { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| Reset(); | |||
| AnfVisitor::Match(prim::kPrimRowTensorGetIndices, {IsCNode})(node); | |||
| if (is_match_) { | |||
| return tuple_->input(1); | |||
| } | |||
| AnfVisitor::Match(prim::kPrimRowTensorGetValues, {IsCNode})(node); | |||
| if (is_match_) { | |||
| return tuple_->input(2); | |||
| } | |||
| AnfVisitor::Match(prim::kPrimRowTensorGetDenseShape, {IsCNode})(node); | |||
| if (is_match_) { | |||
| return tuple_->input(3); | |||
| } | |||
| PatternNode x, y, z; | |||
| auto slices = PPrimitive(prim::kPrimMakeRowTensor, x, y, z).MinExtraNodes(0); | |||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimRowTensorGetIndices, slices), x); | |||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimRowTensorGetValues, slices), y); | |||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimRowTensorGetDenseShape, slices), z); | |||
| return nullptr; | |||
| } | |||
| void Visit(const CNodePtr &cnode) override { | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimMakeRowTensor)) { | |||
| tuple_ = cnode; | |||
| is_match_ = true; | |||
| } | |||
| } | |||
| void Reset() { | |||
| tuple_ = nullptr; | |||
| is_match_ = false; | |||
| } | |||
| private: | |||
| bool is_match_{false}; | |||
| CNodePtr tuple_{nullptr}; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| @@ -20,10 +20,10 @@ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "frontend/operator/ops.h" | |||
| #include "frontend/optimizer/irpass.h" | |||
| #include "frontend/optimizer/optimizer.h" | |||
| #include "ir/visitor.h" | |||
| #include "frontend/operator/ops.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -31,43 +31,16 @@ namespace irpass { | |||
| // {prim::kPrimSparseTensorGetIndices, {prim::kPrimMakeSparseTensor, Xs}} | |||
| // {prim::kPrimSparseTensorGetValues, {prim::kPrimMakeSparseTensor, Xs}} | |||
| // {prim::kPrimSparseTensorGetDenseShape, {prim::kPrimMakeSparseTensor, Xs}} | |||
| class SparseTensorEliminater : public AnfVisitor { | |||
| class SparseTensorEliminater : public OptimizerCaller { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| Reset(); | |||
| AnfVisitor::Match(prim::kPrimSparseTensorGetIndices, {IsCNode})(node); | |||
| if (is_match_) { | |||
| return tuple_->input(1); | |||
| } | |||
| AnfVisitor::Match(prim::kPrimSparseTensorGetValues, {IsCNode})(node); | |||
| if (is_match_) { | |||
| return tuple_->input(2); | |||
| } | |||
| AnfVisitor::Match(prim::kPrimSparseTensorGetDenseShape, {IsCNode})(node); | |||
| if (is_match_) { | |||
| return tuple_->input(3); | |||
| } | |||
| PatternNode x, y, z; | |||
| auto sparse = PPrimitive(prim::kPrimMakeSparseTensor, x, y, z).MinExtraNodes(0); | |||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimSparseTensorGetIndices, sparse), x); | |||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimSparseTensorGetValues, sparse), y); | |||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimSparseTensorGetDenseShape, sparse), z); | |||
| return nullptr; | |||
| } | |||
| void Visit(const CNodePtr &cnode) override { | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimMakeSparseTensor)) { | |||
| tuple_ = cnode; | |||
| is_match_ = true; | |||
| } | |||
| } | |||
| void Reset() { | |||
| tuple_ = nullptr; | |||
| is_match_ = false; | |||
| } | |||
| private: | |||
| bool is_match_{false}; | |||
| CNodePtr tuple_{nullptr}; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| @@ -372,6 +372,8 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > { | |||
| return *this; | |||
| } | |||
| const AnfNodePtrList &GetCapturedExtraNodes() const { return extra_nodes_; } | |||
| /// Returns the FuncGraph of the original node captured by this Primitive Pattern. | |||
| /// Throws exception if a node was not captured before. | |||
| FuncGraphPtr GetFuncGraph() const { | |||