You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pattern.h 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
  17. #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
  18. #include <string>
  19. #include <memory>
  20. #include <vector>
  21. #include <unordered_map>
  22. #include "base/base.h"
  23. #include "ir/anf.h"
  24. #include "ir/tensor.h"
  25. #include "pybind_api/ir/primitive_py.h"
  26. #include "pybind_api/ir/tensor_py.h"
  27. namespace mindspore {
  28. namespace opt {
  29. namespace python_pass {
  30. using std::string;
  31. using std::vector;
  32. class MatchResult;
  33. using MatchResultPtr = std::shared_ptr<MatchResult>;
  34. class Pattern;
  35. using PatternPtr = std::shared_ptr<Pattern>;
  36. class IsPrimTypeOf;
  37. using IsPrimTypeOfPtr = std::shared_ptr<IsPrimTypeOf>;
  38. class CallWith;
  39. using CallWithPtr = std::shared_ptr<CallWith>;
  40. class NewTensor;
  41. using NewTensorPtr = std::shared_ptr<NewTensor>;
  42. struct PatternHasher;
  43. struct PatternEqual;
  44. using PatternNodeMap = std::unordered_map<PatternPtr, AnfNodePtr, PatternHasher, PatternEqual>;
  45. class Pattern : public Base {
  46. public:
  47. Pattern() : unique_name_(std::to_string(g_id_++)) {}
  48. ~Pattern() = default;
  49. virtual MatchResultPtr match(const AnfNodePtr &node) { return nullptr; }
  50. virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; }
  51. string unique_name() const { return unique_name_; }
  52. vector<PatternPtr> inputs() { return inputs_; }
  53. bool should_replace() { return should_replace_; }
  54. virtual void reset() {}
  55. protected:
  56. static int g_id_;
  57. // NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed
  58. string unique_name_;
  59. vector<PatternPtr> inputs_;
  60. bool should_replace_ = true;
  61. };
  62. struct PatternEqual {
  63. bool operator()(PatternPtr const &p1, PatternPtr const &p2) const {
  64. MS_EXCEPTION_IF_NULL(p1);
  65. MS_EXCEPTION_IF_NULL(p2);
  66. return p1->unique_name() == p2->unique_name();
  67. }
  68. };
  69. struct PatternHasher {
  70. std::size_t operator()(PatternPtr const &p) const {
  71. MS_EXCEPTION_IF_NULL(p);
  72. return std::hash<string>()(p->unique_name());
  73. }
  74. };
  75. class IsPrimTypeOf : public Pattern {
  76. public:
  77. IsPrimTypeOf() { unique_name_ = std::to_string(g_id_++); }
  78. ~IsPrimTypeOf() = default;
  79. IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace)
  80. : primitives_(prims), name_(name), matched_prim_(nullptr) {
  81. unique_name_ = std::to_string(g_id_++) + "_" + name;
  82. should_replace_ = should_replace;
  83. if (!should_replace) {
  84. matched_prim_ = prims[0];
  85. }
  86. }
  87. IsPrimTypeOf(vector<string> types, string name, bool should_replace) : types_(types), name_(name) {
  88. unique_name_ = std::to_string(g_id_++) + "_" + name;
  89. // Make primitives_
  90. for (auto &iter : types) {
  91. primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr)));
  92. }
  93. should_replace_ = should_replace;
  94. if (!should_replace) {
  95. matched_prim_ = primitives_[0];
  96. }
  97. }
  98. MS_DECLARE_PARENT(IsPrimTypeOf, Pattern);
  99. MatchResultPtr match(const AnfNodePtr &node) override;
  100. PrimitivePyPtr matched_primitive() { return matched_prim_; }
  101. void reset() override {
  102. if (should_replace_) {
  103. matched_prim_ = nullptr;
  104. }
  105. }
  106. private:
  107. vector<string> types_;
  108. vector<PrimitivePyPtr> primitives_;
  109. string name_;
  110. PrimitivePyPtr matched_prim_;
  111. };
  112. class CallWith : public Pattern {
  113. public:
  114. CallWith() { unique_name_ = std::to_string(g_id_++); }
  115. ~CallWith() = default;
  116. CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) {
  117. // NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
  118. prim_pattern_ = prim_pattern;
  119. unique_name_ = std::to_string(g_id_++) + prim_pattern->unique_name();
  120. inputs_ = inputs;
  121. should_replace_ = should_replace;
  122. }
  123. CallWith(PrimitivePyPtr prim, vector<PatternPtr> inputs, bool should_replace) {
  124. prim_ = prim;
  125. unique_name_ = std::to_string(g_id_++) + prim_->ToString();
  126. inputs_ = inputs;
  127. should_replace_ = should_replace;
  128. }
  129. CallWith(string prim_str, vector<PatternPtr> inputs, bool should_replace) {
  130. prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr));
  131. unique_name_ = std::to_string(g_id_++) + prim_->ToString();
  132. inputs_ = inputs;
  133. should_replace_ = should_replace;
  134. }
  135. MS_DECLARE_PARENT(CallWith, Pattern);
  136. MatchResultPtr match(const AnfNodePtr &node) override;
  137. PrimitivePtr prim_value() { return prim_; }
  138. PatternPtr prim_pattern() { return prim_pattern_; }
  139. private:
  140. PatternPtr prim_pattern_ = nullptr;
  141. PrimitivePtr prim_ = nullptr;
  142. vector<string> types_;
  143. string name_;
  144. };
  145. class IsIn : public Pattern {
  146. public:
  147. IsIn() { unique_name_ = std::to_string(g_id_++); }
  148. ~IsIn() = default;
  149. explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) {
  150. unique_name_ = std::to_string(g_id_++);
  151. for (auto &iter : patterns) {
  152. unique_name_ = unique_name_ + "_" + iter->unique_name();
  153. }
  154. }
  155. MS_DECLARE_PARENT(IsIn, Pattern);
  156. MatchResultPtr match(const AnfNodePtr &node) override;
  157. private:
  158. vector<PatternPtr> patterns_;
  159. };
  160. class IsNot : public Pattern {
  161. public:
  162. IsNot() { unique_name_ = std::to_string(g_id_++); }
  163. ~IsNot() = default;
  164. explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) {
  165. unique_name_ = std::to_string(g_id_++);
  166. for (auto &iter : patterns) {
  167. unique_name_ = "IsNot_" + unique_name_ + "_" + iter->unique_name();
  168. }
  169. }
  170. MS_DECLARE_PARENT(IsNot, Pattern);
  171. MatchResultPtr match(const AnfNodePtr &node) override;
  172. private:
  173. vector<PatternPtr> patterns_;
  174. };
  175. class AnyPattern : public Pattern {
  176. public:
  177. AnyPattern() { unique_name_ = std::to_string(g_id_++) + "_AnyPattern"; }
  178. ~AnyPattern() = default;
  179. MS_DECLARE_PARENT(AnyPattern, Pattern);
  180. MatchResultPtr match(const AnfNodePtr &node) override;
  181. };
  182. class NewTensor : public Pattern {
  183. public:
  184. NewTensor() { unique_name_ = std::to_string(g_id_++); }
  185. ~NewTensor() = default;
  186. explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { should_replace_ = false; }
  187. MS_DECLARE_PARENT(NewTensor, Pattern);
  188. MatchResultPtr match(const AnfNodePtr &node) override {
  189. MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n";
  190. }
  191. tensor::TensorPtr input_tensor() { return input_tensor_; }
  192. private:
  193. tensor::TensorPtr input_tensor_;
  194. };
  195. class MatchResult {
  196. public:
  197. MatchResult() {}
  198. ~MatchResult() = default;
  199. void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; }
  200. PatternNodeMap _result() { return match_result_; }
  201. AnfNodePtr get_node(const PatternPtr &pattern);
  202. void merge(const MatchResultPtr &other_result);
  203. void clear() { match_result_.clear(); }
  204. void dump() {
  205. MS_LOG(DEBUG) << "match_result_.size: " + std::to_string(match_result_.size()) + "\n";
  206. for (auto &iter : match_result_) {
  207. MS_LOG(DEBUG) << "Pattern : " + iter.first->unique_name() + " , node : " + iter.second->ToString() + "\n";
  208. }
  209. }
  210. private:
  211. PatternNodeMap match_result_;
  212. };
  213. } // namespace python_pass
  214. } // namespace opt
  215. } // namespace mindspore
  216. #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_