You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pattern_engine.h 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_
  19. #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_
  #include <algorithm>
  #include <functional>
  #include <initializer_list>
  #include <iostream>
  #include <list>
  #include <map>
  #include <memory>
  #include <sstream>
  #include <stdexcept>
  #include <string>
  #include <unordered_map>
  #include <unordered_set>
  #include <utility>
  #include <vector>
  #include "backend/optimizer/common/visit.h"
  #include "base/base.h"
  #include "utils/log_adapter.h"
  #include "base/base_ref.h"
  37. namespace mindspore {
  38. class CondVar;
  39. class SeqVar;
  40. using CondVarPtr = std::shared_ptr<CondVar>;
  41. using SVarPtr = std::shared_ptr<SeqVar>;
  42. const int kInvalidVarIndex = -2;
  43. using ConditionFunc = std::function<bool(const BaseRef &)>;
  44. // Base wildcard variable which could match any anf node.
  45. class Var : public Base {
  46. friend class VarHasher;
  47. public:
  48. explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); }
  49. explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) {
  50. EnsureTag();
  51. }
  52. Var(const Var &other) : Base(other), tag_(other.tag_) {}
  53. virtual Var &operator=(const Var &other) {
  54. if (&other == this) {
  55. return *this;
  56. }
  57. this->tag_ = other.tag_;
  58. return *this;
  59. }
  60. ~Var() override = default;
  61. MS_DECLARE_PARENT(Var, Base);
  62. virtual bool matches(const BaseRef &) { return true; }
  63. virtual bool operator==(const Var &other) const { return tag_ == other.tag_; }
  64. bool operator!=(const Var &other) const { return !(&other == this); }
  65. std::string tag() const { return tag_; }
  66. PrimitivePtr primitive() const { return primitive_; }
  67. std::string ToString() const override {
  68. std::ostringstream buffer;
  69. buffer << "Var(" << tag_ << ")";
  70. return buffer.str();
  71. }
  72. std::size_t hash() const override { return std::hash<std::string>()(tag_); }
  73. protected:
  74. void EnsureTag();
  75. std::string tag_;
  76. PrimitivePtr primitive_;
  77. };
  78. // VarNode means variable node, a subclass of AnfNode
  79. class VarNode : public AnfNode {
  80. public:
  81. VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {}
  82. ~VarNode() override = default;
  83. MS_DECLARE_PARENT(VarNode, AnfNode);
  84. const VarPtr var_;
  85. };
  86. using VarNodePtr = std::shared_ptr<VarNode>;
  87. class VarHasher {
  88. public:
  89. std::size_t operator()(const Var &var) const { return var.hash(); }
  90. };
  91. // Condition Var, match an anf node when condition function return true.
  92. class CondVar : public Var {
  93. public:
  94. explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {}
  95. ~CondVar() override = default;
  96. MS_DECLARE_PARENT(CondVar, Var);
  97. bool matches(const BaseRef &value) override {
  98. MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString();
  99. if (utils::isa<Var>(value)) {
  100. return false;
  101. }
  102. return cond_fn_(value);
  103. }
  104. ConditionFunc cond_fn_;
  105. };
  106. using Seq = VectorRef;
  107. using SeqPtr = std::shared_ptr<Seq>;
  108. // Sequence Var which could match multiple consecutive input nodes of a CNode.
  109. class SeqVar : public Var {
  110. public:
  111. SeqVar() { subvar_ = std::make_shared<Var>(); }
  112. ~SeqVar() override = default;
  113. MS_DECLARE_PARENT(SeqVar, Var);
  114. explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; }
  115. bool matches(const BaseRef &value) override {
  116. // match Seq.
  117. if (utils::isa<Seq>(value)) {
  118. const Seq &seq = utils::cast<Seq>(value);
  119. return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) {
  120. auto eq = subvar_->matches(v);
  121. return eq;
  122. });
  123. }
  124. return false;
  125. }
  126. bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; }
  127. std::string ToString() const override;
  128. private:
  129. VarPtr subvar_;
  130. };
  131. bool operator==(const VarPtr &lhs, const VarPtr &rhs);
  132. inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); }
  133. std::ostream &operator<<(std::ostream &os, const VarPtr &var);
  134. using Equiv = std::map<VarPtr, BaseRef>;
  135. using EquivPtr = std::shared_ptr<Equiv>;
  136. using PrimitiveVarMap = std::unordered_map<PrimitivePtr, VarPtr>;
  137. using PrimitiveVarMapPtr = std::shared_ptr<PrimitiveVarMap>;
  138. inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); }
  139. class PatternEngine {
  140. public:
  141. PatternEngine(const std::shared_ptr<Visitor> &visitor,
  142. const std::function<bool(const BaseRef &, const BaseRef &)> &eq,
  143. const std::function<bool(const BaseRef &, const BaseRef &)> &type_eq = DefaultTypeEq)
  144. : visitor_(visitor), eq_(eq), type_eq_(type_eq) {}
  145. ~PatternEngine() = default;
  146. EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
  147. EquivPtr equiv) const;
  148. // Replace pattern with equivalent
  149. BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const;
  150. private:
  151. EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
  152. const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const;
  153. bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern,
  154. VectorRef *const values_expr) const;
  155. bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
  156. VectorRef *const values_expr) const;
  157. std::shared_ptr<Visitor> visitor_;
  158. std::function<bool(const BaseRef &, const BaseRef &)> eq_;
  159. std::function<bool(const BaseRef &, const BaseRef &)> type_eq_;
  160. };
  161. } // namespace mindspore
  162. namespace std {
  163. using mindspore::ERROR;
  164. using mindspore::LogStream;
  165. using mindspore::NoExceptionType;
  166. template <>
  167. struct hash<mindspore::VarPtr> {
  168. std::size_t operator()(const mindspore::VarPtr var) const {
  169. if (var == nullptr) {
  170. MS_LOG(ERROR) << "Invalid var ptr";
  171. return 0;
  172. }
  173. return std::hash<std::string>{}(var->tag());
  174. }
  175. };
  176. } // namespace std
  177. #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_