|
|
|
@@ -49,6 +49,7 @@ using PatternNodeMap = std::unordered_map<PatternPtr, AnfNodePtr, PatternHasher, |
|
|
|
class Pattern : public Base { |
|
|
|
public: |
|
|
|
Pattern() : unique_name_(std::to_string(g_id_++)) {} |
|
|
|
~Pattern() = default; |
|
|
|
virtual MatchResultPtr match(const AnfNodePtr &node) { return nullptr; } |
|
|
|
virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; } |
|
|
|
string unique_name() const { return unique_name_; } |
|
|
|
@@ -82,6 +83,7 @@ struct PatternHasher { |
|
|
|
class IsPrimTypeOf : public Pattern { |
|
|
|
public: |
|
|
|
IsPrimTypeOf() { unique_name_ = std::to_string(g_id_++); } |
|
|
|
~IsPrimTypeOf() = default; |
|
|
|
IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace) |
|
|
|
: primitives_(prims), name_(name), matched_prim_(nullptr) { |
|
|
|
unique_name_ = std::to_string(g_id_++) + "_" + name; |
|
|
|
@@ -120,6 +122,7 @@ class IsPrimTypeOf : public Pattern { |
|
|
|
class CallWith : public Pattern { |
|
|
|
public: |
|
|
|
CallWith() { unique_name_ = std::to_string(g_id_++); } |
|
|
|
~CallWith() = default; |
|
|
|
CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) { |
|
|
|
// NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting |
|
|
|
prim_pattern_ = prim_pattern; |
|
|
|
@@ -154,6 +157,7 @@ class CallWith : public Pattern { |
|
|
|
class IsIn : public Pattern { |
|
|
|
public: |
|
|
|
IsIn() { unique_name_ = std::to_string(g_id_++); } |
|
|
|
~IsIn() = default; |
|
|
|
explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) { |
|
|
|
unique_name_ = std::to_string(g_id_++); |
|
|
|
for (auto &iter : patterns) { |
|
|
|
@@ -170,6 +174,7 @@ class IsIn : public Pattern { |
|
|
|
class IsNot : public Pattern { |
|
|
|
public: |
|
|
|
IsNot() { unique_name_ = std::to_string(g_id_++); } |
|
|
|
~IsNot() = default; |
|
|
|
explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) { |
|
|
|
unique_name_ = std::to_string(g_id_++); |
|
|
|
for (auto &iter : patterns) { |
|
|
|
@@ -186,6 +191,7 @@ class IsNot : public Pattern { |
|
|
|
class AnyPattern : public Pattern { |
|
|
|
public: |
|
|
|
AnyPattern() { unique_name_ = std::to_string(g_id_++) + "_AnyPattern"; } |
|
|
|
~AnyPattern() = default; |
|
|
|
MS_DECLARE_PARENT(AnyPattern, Pattern); |
|
|
|
MatchResultPtr match(const AnfNodePtr &node) override; |
|
|
|
}; |
|
|
|
@@ -193,6 +199,7 @@ class AnyPattern : public Pattern { |
|
|
|
class NewTensor : public Pattern { |
|
|
|
public: |
|
|
|
NewTensor() { unique_name_ = std::to_string(g_id_++); } |
|
|
|
~NewTensor() = default; |
|
|
|
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { should_replace_ = false; } |
|
|
|
MS_DECLARE_PARENT(NewTensor, Pattern); |
|
|
|
MatchResultPtr match(const AnfNodePtr &node) override { |
|
|
|
@@ -207,6 +214,7 @@ class NewTensor : public Pattern { |
|
|
|
class MatchResult { |
|
|
|
public: |
|
|
|
MatchResult() {} |
|
|
|
~MatchResult() = default; |
|
|
|
void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; } |
|
|
|
PatternNodeMap _result() { return match_result_; } |
|
|
|
AnfNodePtr get_node(const PatternPtr &pattern); |
|
|
|
|