You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pattern.h 9.8 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
  17. #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
  18. #include <string>
  19. #include <memory>
  20. #include <vector>
  21. #include <unordered_map>
  22. #include "base/base.h"
  23. #include "ir/anf.h"
  24. #include "ir/tensor.h"
  25. #include "pybind_api/ir/primitive_py.h"
  26. #include "pybind_api/ir/tensor_py.h"
  27. namespace mindspore {
  28. namespace opt {
  29. namespace python_pass {
  30. using std::string;
  31. using std::vector;
  32. class MatchResult;
  33. using MatchResultPtr = std::shared_ptr<MatchResult>;
  34. class Pattern;
  35. using PatternPtr = std::shared_ptr<Pattern>;
  36. class Prim;
  37. using PrimPtr = std::shared_ptr<Prim>;
  38. class Call;
  39. using CallPtr = std::shared_ptr<Call>;
  40. class NewTensor;
  41. using NewTensorPtr = std::shared_ptr<NewTensor>;
  42. class NewParameter;
  43. using NewParameterPtr = std::shared_ptr<NewParameter>;
  44. class Imm;
  45. using ImmPtr = std::shared_ptr<Imm>;
  46. struct PatternHasher;
  47. struct PatternEqual;
  48. using PatternNodeMap = std::unordered_map<PatternPtr, AnfNodePtr, PatternHasher, PatternEqual>;
  49. class Pattern : public Base {
  50. public:
  51. Pattern() : unique_name_(std::to_string(g_id_++)) {}
  52. ~Pattern() = default;
  53. MS_DECLARE_PARENT(Pattern, Base);
  54. virtual MatchResultPtr match(const AnfNodePtr &node) { return nullptr; }
  55. virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; }
  56. string unique_name() const { return unique_name_; }
  57. vector<PatternPtr> inputs() { return inputs_; }
  58. virtual void reset() {}
  59. static void reset_gid() { g_id_ = 0; }
  60. protected:
  61. static int64_t g_id_;
  62. // NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed
  63. string unique_name_;
  64. vector<PatternPtr> inputs_;
  65. };
  66. struct PatternEqual {
  67. bool operator()(PatternPtr const &p1, PatternPtr const &p2) const {
  68. MS_EXCEPTION_IF_NULL(p1);
  69. MS_EXCEPTION_IF_NULL(p2);
  70. return p1->unique_name() == p2->unique_name();
  71. }
  72. };
  73. struct PatternHasher {
  74. std::size_t operator()(PatternPtr const &p) const {
  75. MS_EXCEPTION_IF_NULL(p);
  76. return std::hash<string>()(p->unique_name());
  77. }
  78. };
  79. class Prim final : public Pattern {
  80. public:
  81. Prim() { unique_name_ = std::to_string(g_id_++); }
  82. ~Prim() = default;
  83. Prim(vector<py::object> prim_objs, string name) : name_(name) {
  84. unique_name_ = std::to_string(g_id_++) + "Prim_" + name;
  85. for (auto &prim_obj : prim_objs) {
  86. if (py::isinstance<PrimitivePyAdapter>(prim_obj)) {
  87. auto prim_adapter = prim_obj.cast<PrimitivePyAdapterPtr>();
  88. primitives_.push_back(std::make_shared<PrimitivePy>(prim_obj, prim_adapter));
  89. } else if (py::isinstance<py::str>(prim_obj)) {
  90. std::string prim_name = prim_obj.cast<py::str>();
  91. primitives_.push_back(std::make_shared<PrimitivePy>(prim_name));
  92. } else {
  93. MS_LOG(EXCEPTION) << "Parameter of Prim::__init__ must be Primitive_ type or Prim name, please check input.";
  94. }
  95. }
  96. // Default using the first prim to build target
  97. matched_prim_ = primitives_[0];
  98. }
  99. MS_DECLARE_PARENT(Prim, Pattern);
  100. MatchResultPtr match(const AnfNodePtr &node) override;
  101. PrimitivePyPtr matched_primitive() { return matched_prim_; }
  102. void reset() override {
  103. // Init before reset
  104. MS_EXCEPTION_IF_NULL(matched_prim_);
  105. matched_prim_ = primitives_[0];
  106. }
  107. private:
  108. vector<PrimitivePyPtr> primitives_;
  109. string name_;
  110. PrimitivePyPtr matched_prim_{nullptr};
  111. };
  112. class Call final : public Pattern {
  113. public:
  114. Call() { unique_name_ = std::to_string(g_id_++); }
  115. ~Call() = default;
  116. Call(PatternPtr prim_pattern, vector<PatternPtr> inputs) {
  117. // NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
  118. prim_pattern_ = prim_pattern;
  119. unique_name_ = std::to_string(g_id_++) + "Call_" + prim_pattern->unique_name();
  120. inputs_ = inputs;
  121. }
  122. Call(py::object prim_obj, vector<PatternPtr> inputs) {
  123. if (py::isinstance<PrimitivePyAdapter>(prim_obj)) {
  124. auto prim_adapter = prim_obj.cast<PrimitivePyAdapterPtr>();
  125. prim_ = std::make_shared<PrimitivePy>(prim_obj, prim_adapter);
  126. } else if (py::isinstance<py::str>(prim_obj)) {
  127. std::string prim_name = prim_obj.cast<py::str>();
  128. prim_ = std::make_shared<PrimitivePy>(prim_name);
  129. } else {
  130. MS_LOG(EXCEPTION) << "Parameter of Call::__init__ must be Primitive_ type or Prim name, please check input.";
  131. }
  132. unique_name_ = std::to_string(g_id_++) + "Call_" + prim_->ToString();
  133. inputs_ = inputs;
  134. }
  135. MS_DECLARE_PARENT(Call, Pattern);
  136. MatchResultPtr match(const AnfNodePtr &node) override;
  137. PrimitivePtr prim_value() { return prim_; }
  138. PatternPtr prim_pattern() { return prim_pattern_; }
  139. private:
  140. PatternPtr prim_pattern_ = nullptr;
  141. PrimitivePtr prim_ = nullptr;
  142. vector<string> types_;
  143. string name_;
  144. };
  145. class OneOf final : public Pattern {
  146. public:
  147. OneOf() { unique_name_ = std::to_string(g_id_++); }
  148. ~OneOf() = default;
  149. explicit OneOf(vector<PatternPtr> patterns) : patterns_(patterns) {
  150. unique_name_ = std::to_string(g_id_++) + "OneOf";
  151. for (auto &iter : patterns) {
  152. unique_name_ = unique_name_ + "_" + iter->unique_name();
  153. }
  154. }
  155. MS_DECLARE_PARENT(OneOf, Pattern);
  156. MatchResultPtr match(const AnfNodePtr &node) override;
  157. private:
  158. vector<PatternPtr> patterns_;
  159. };
  160. class NoneOf final : public Pattern {
  161. public:
  162. NoneOf() { unique_name_ = std::to_string(g_id_++); }
  163. ~NoneOf() = default;
  164. explicit NoneOf(vector<PatternPtr> patterns) : patterns_(patterns) {
  165. unique_name_ = std::to_string(g_id_++) + "NoneOf";
  166. for (auto &iter : patterns) {
  167. unique_name_ = unique_name_ + "_" + iter->unique_name();
  168. }
  169. }
  170. MS_DECLARE_PARENT(NoneOf, Pattern);
  171. MatchResultPtr match(const AnfNodePtr &node) override;
  172. private:
  173. vector<PatternPtr> patterns_;
  174. };
  175. class Any final : public Pattern {
  176. public:
  177. Any() { unique_name_ = std::to_string(g_id_++) + "_Any"; }
  178. ~Any() = default;
  179. MS_DECLARE_PARENT(Any, Pattern);
  180. MatchResultPtr match(const AnfNodePtr &node) override;
  181. };
  182. class NewTensor final : public Pattern {
  183. public:
  184. NewTensor() { unique_name_ = std::to_string(g_id_++); }
  185. ~NewTensor() = default;
  186. explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) {
  187. unique_name_ = std::to_string(g_id_++) + "NewTensor";
  188. }
  189. MS_DECLARE_PARENT(NewTensor, Pattern);
  190. MatchResultPtr match(const AnfNodePtr &node) override {
  191. MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n";
  192. }
  193. tensor::TensorPtr input_tensor() { return input_tensor_; }
  194. private:
  195. tensor::TensorPtr input_tensor_;
  196. };
  197. class NewParameter final : public Pattern {
  198. public:
  199. NewParameter() { unique_name_ = std::to_string(g_id_++); }
  200. explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel)
  201. : para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) {
  202. unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name;
  203. default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get());
  204. built_ = false;
  205. }
  206. ~NewParameter() = default;
  207. MS_DECLARE_PARENT(NewParameter, Pattern);
  208. MatchResultPtr match(const AnfNodePtr &node) override {
  209. MS_LOG(EXCEPTION) << "Find NewParameter in pattern, NewParameter should only appear in the target.\n";
  210. }
  211. string para_name() { return para_name_; }
  212. tensor::TensorPtr default_tensor() { return default_tensor_; }
  213. bool requires_grad() { return requires_grad_; }
  214. bool layerwise_parallel() { return layerwise_parallel_; }
  215. bool built() { return built_; }
  216. void set_built(bool built) { built_ = built; }
  217. void reset() override { built_ = false; }
  218. bool should_last() { return last_across_passes_; }
  219. void set_last(bool last) { last_across_passes_ = last; }
  220. private:
  221. string para_name_;
  222. bool requires_grad_;
  223. bool layerwise_parallel_;
  224. bool last_across_passes_{false};
  225. bool built_;
  226. tensor::TensorPtr default_tensor_;
  227. };
  228. class Imm final : public Pattern {
  229. public:
  230. Imm() { unique_name_ = std::to_string(g_id_++); }
  231. explicit Imm(int value) : value_(value) { unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value); }
  232. ~Imm() = default;
  233. MS_DECLARE_PARENT(Imm, Pattern);
  234. MatchResultPtr match(const AnfNodePtr &node) override;
  235. int value() { return value_; }
  236. private:
  237. int64_t value_;
  238. };
  239. class MatchResult {
  240. public:
  241. MatchResult() {}
  242. ~MatchResult() = default;
  243. void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; }
  244. const PatternNodeMap &result() { return match_result_; }
  245. AnfNodePtr get_node(const PatternPtr &pattern);
  246. void merge(const MatchResultPtr &other_result);
  247. void clear() { match_result_.clear(); }
  248. void dump() {
  249. MS_LOG(DEBUG) << "match_result_.size: " + std::to_string(match_result_.size()) + "\n";
  250. for (auto &iter : match_result_) {
  251. MS_LOG(DEBUG) << "Pattern : " + iter.first->unique_name() + " , node : " + iter.second->ToString() + "\n";
  252. }
  253. }
  254. private:
  255. PatternNodeMap match_result_;
  256. };
  257. } // namespace python_pass
  258. } // namespace opt
  259. } // namespace mindspore
  260. #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_