Browse Source

!4248 Update RowTensorEliminater and IndexedSliceEliminate to Pattern Matcher

Merge pull request !4248 from Giancarlo/pm_update_sparsetensor
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
64923214f1
3 changed files with 17 additions and 69 deletions
  1. +8
    -35
      mindspore/ccsrc/frontend/optimizer/irpass/row_tensor_eliminate.h
  2. +7
    -34
      mindspore/ccsrc/frontend/optimizer/irpass/sparse_tensor_eliminate.h
  3. +2
    -0
      mindspore/core/ir/pattern_matcher.h

+ 8
- 35
mindspore/ccsrc/frontend/optimizer/irpass/row_tensor_eliminate.h View File

@@ -20,10 +20,10 @@
#include <vector>
#include <algorithm>

#include "frontend/operator/ops.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"

namespace mindspore {
namespace opt {
@@ -31,43 +31,16 @@ namespace irpass {
// {prim::kPrimRowTensorGetIndices, {prim::kPrimMakeRowTensor, Xs}}
// {prim::kPrimRowTensorGetValues, {prim::kPrimMakeRowTensor, Xs}}
// {prim::kPrimRowTensorGetDenseShape, {prim::kPrimMakeRowTensor, Xs}}
class RowTensorEliminater : public AnfVisitor {
class RowTensorEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimRowTensorGetIndices, {IsCNode})(node);

if (is_match_) {
return tuple_->input(1);
}
AnfVisitor::Match(prim::kPrimRowTensorGetValues, {IsCNode})(node);

if (is_match_) {
return tuple_->input(2);
}
AnfVisitor::Match(prim::kPrimRowTensorGetDenseShape, {IsCNode})(node);

if (is_match_) {
return tuple_->input(3);
}
PatternNode x, y, z;
auto slices = PPrimitive(prim::kPrimMakeRowTensor, x, y, z).MinExtraNodes(0);
MATCH_REPLACE(node, PPrimitive(prim::kPrimRowTensorGetIndices, slices), x);
MATCH_REPLACE(node, PPrimitive(prim::kPrimRowTensorGetValues, slices), y);
MATCH_REPLACE(node, PPrimitive(prim::kPrimRowTensorGetDenseShape, slices), z);
return nullptr;
}

void Visit(const CNodePtr &cnode) override {
if (IsPrimitiveCNode(cnode, prim::kPrimMakeRowTensor)) {
tuple_ = cnode;
is_match_ = true;
}
}

void Reset() {
tuple_ = nullptr;
is_match_ = false;
}

private:
bool is_match_{false};
CNodePtr tuple_{nullptr};
};
} // namespace irpass
} // namespace opt


+ 7
- 34
mindspore/ccsrc/frontend/optimizer/irpass/sparse_tensor_eliminate.h View File

@@ -20,10 +20,10 @@
#include <vector>
#include <algorithm>

#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/operator/ops.h"

namespace mindspore {
namespace opt {
@@ -31,43 +31,16 @@ namespace irpass {
// {prim::kPrimSparseTensorGetIndices, {prim::kPrimMakeSparseTensor, Xs}}
// {prim::kPrimSparseTensorGetValues, {prim::kPrimMakeSparseTensor, Xs}}
// {prim::kPrimSparseTensorGetDenseShape, {prim::kPrimMakeSparseTensor, Xs}}
class SparseTensorEliminater : public AnfVisitor {
class SparseTensorEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimSparseTensorGetIndices, {IsCNode})(node);

if (is_match_) {
return tuple_->input(1);
}
AnfVisitor::Match(prim::kPrimSparseTensorGetValues, {IsCNode})(node);

if (is_match_) {
return tuple_->input(2);
}
AnfVisitor::Match(prim::kPrimSparseTensorGetDenseShape, {IsCNode})(node);

if (is_match_) {
return tuple_->input(3);
}
PatternNode x, y, z;
auto sparse = PPrimitive(prim::kPrimMakeSparseTensor, x, y, z).MinExtraNodes(0);
MATCH_REPLACE(node, PPrimitive(prim::kPrimSparseTensorGetIndices, sparse), x);
MATCH_REPLACE(node, PPrimitive(prim::kPrimSparseTensorGetValues, sparse), y);
MATCH_REPLACE(node, PPrimitive(prim::kPrimSparseTensorGetDenseShape, sparse), z);
return nullptr;
}

void Visit(const CNodePtr &cnode) override {
if (IsPrimitiveCNode(cnode, prim::kPrimMakeSparseTensor)) {
tuple_ = cnode;
is_match_ = true;
}
}

void Reset() {
tuple_ = nullptr;
is_match_ = false;
}

private:
bool is_match_{false};
CNodePtr tuple_{nullptr};
};
} // namespace irpass
} // namespace opt


+ 2
- 0
mindspore/core/ir/pattern_matcher.h View File

@@ -372,6 +372,8 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > {
return *this;
}

const AnfNodePtrList &GetCapturedExtraNodes() const { return extra_nodes_; }

/// Returns the FuncGraph of the original node captured by this Primitive Pattern.
/// Throws exception if a node was not captured before.
FuncGraphPtr GetFuncGraph() const {


Loading…
Cancel
Save