You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pattern_engine.cc 9.9 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "backend/optimizer/common/pattern_engine.h"
  19. #include "frontend/optimizer/opt.h"
  20. #include "ir/anf.h"
  21. #include "utils/convert_utils_base.h"
  22. #include "utils/overload.h"
  23. #include "backend/optimizer/common/helper.h"
  24. namespace mindspore {
  // Returns the next tag id. Ids start at 0 and grow by one per call; the counter
  // is a function-local static, so it is shared by every caller in this file.
  static int GetNextTag() {
    static int next_id = 0;
    const int issued = next_id;
    ++next_id;
    return issued;
  }
  29. void Var::EnsureTag() {
  30. if (tag_.length() == 0) {
  31. std::ostringstream buffer;
  32. buffer << "_" << GetNextTag();
  33. tag_ = buffer.str();
  34. }
  35. }
  36. bool operator==(const VarPtr &lhs, const VarPtr &rhs) {
  37. if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) {
  38. CondVarPtr v1 = dyn_cast<CondVar>(lhs);
  39. CondVarPtr v2 = dyn_cast<CondVar>(rhs);
  40. return *v1 == *v2;
  41. }
  42. if (lhs->isa<SeqVar>() && rhs->isa<SeqVar>()) {
  43. SVarPtr v1 = dyn_cast<SeqVar>(lhs);
  44. SVarPtr v2 = dyn_cast<SeqVar>(rhs);
  45. return *v1 == *v2;
  46. }
  47. return (*lhs == *rhs);
  48. }
  49. std::string SeqVar::ToString() const {
  50. std::ostringstream buffer;
  51. buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")";
  52. return buffer.str();
  53. }
  54. std::ostream &operator<<(std::ostream &os, const VarPtr &var) {
  55. if (var == nullptr) {
  56. os << "";
  57. } else {
  58. os << var->ToString();
  59. }
  60. return os;
  61. }
  62. template <>
  63. std::ostream &operator<<<VarPtr, BaseRef>(std::ostream &os, const Equiv &equiv) {
  64. os << "[Equiv]"
  65. << "\n";
  66. for (auto &equiv_item : equiv) {
  67. auto k = equiv_item.first;
  68. os << k << ":";
  69. BaseRef x = equiv_item.second;
  70. if (utils::isa<AnfNodePtr>(x)) {
  71. auto node = utils::cast<AnfNodePtr>(x);
  72. os << "TypeString[" << node->type_name() << "]";
  73. if (IsValueNode<FuncGraph>(node)) {
  74. os << "IsValueNodeGraph ";
  75. }
  76. os << "type " << node->type_name();
  77. if (node->isa<ValueNode>()) {
  78. os << " value " << GetValueNode(node);
  79. }
  80. os << " addr: " << node;
  81. } else if (utils::isa<Named>(x)) {
  82. os << "Named " << x.ToString().c_str();
  83. } else if (utils::isa<VarPtr>(x)) {
  84. os << "TypeString[Var]";
  85. os << (utils::cast<VarPtr>(x));
  86. } else if (utils::isa<FuncGraphPtr>(x)) {
  87. os << "TypeString[Graph]";
  88. }
  89. os << "\n";
  90. }
  91. return os;
  92. }
  93. static BaseRef GetVar(const BaseRef &x) {
  94. MS_LOG(DEBUG) << "getVar start :%s" + x.ToString();
  95. if (utils::isa<AnfNodePtr>(x)) {
  96. auto node = utils::cast<AnfNodePtr>(x);
  97. MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]";
  98. if (node->isa<VarNode>()) {
  99. MS_LOG(DEBUG) << "IsVarNode " + node->cast<VarNodePtr>()->var_->ToString();
  100. return node->cast<VarNodePtr>()->var_;
  101. }
  102. if (node->isa<ValueNode>()) {
  103. MS_LOG(DEBUG) << "value " + GetValueNode(node)->ToString() + " addr: " + node->ToString();
  104. } else {
  105. MS_LOG(DEBUG) << "type " + node->type_name();
  106. }
  107. } else if (utils::isa<Named>(x)) {
  108. MS_LOG(DEBUG) << "Named " + x.ToString();
  109. } else if (utils::isa<VectorRef>(x)) {
  110. MS_LOG(DEBUG) << "VectorRef";
  111. } else if (utils::isa<VarPtr>(x)) {
  112. MS_LOG(DEBUG) << "TypeString[Var] " + x.ToString();
  113. }
  114. MS_LOG(DEBUG) << "GetVar end: " + x.ToString();
  115. return x;
  116. }
  117. EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
  118. MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();
  119. MS_EXCEPTION_IF_NULL(equiv);
  120. if (utils::isa<VarPtr>(pattern)) {
  121. VarPtr var = utils::cast<VarPtr>(pattern);
  122. if (var->matches(expr)) {
  123. (*equiv)[var] = expr;
  124. MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString();
  125. return equiv;
  126. }
  127. }
  128. return nullptr;
  129. }
  130. bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
  131. VectorRef *const values_expr) const {
  132. MS_EXCEPTION_IF_NULL(values_expr);
  133. if (utils::isa<SeqPtr>(pattern_ref)) {
  134. *values_pattern = pattern_ref;
  135. *values_expr = expr_ref;
  136. return true;
  137. }
  138. return false;
  139. }
  140. bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern,
  141. VectorRef *const values_expr) const {
  142. MS_EXCEPTION_IF_NULL(values_expr);
  143. MS_LOG(DEBUG) << "visit pattern_ref";
  144. bool success = visitor_->Visit(pattern_ref, values_pattern, nullptr);
  145. if (!success) {
  146. return false;
  147. }
  148. MS_LOG(DEBUG) << "visit expr_ref";
  149. return visitor_->Visit(expr_ref, values_expr, nullptr);
  150. }
  151. static int GetSVarStartIndex(const VectorRef &values) {
  152. int index = -1;
  153. int count = 0;
  154. for (auto &value : values) {
  155. if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) {
  156. if (index != -1) {
  157. MS_LOG(DEBUG) << "Multiple SVars in sequence";
  158. return kInvalidVarIndex;
  159. }
  160. index = count;
  161. }
  162. count++;
  163. }
  164. return index;
  165. }
  166. void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars,
  167. const EquivPtr &equiv) {
  168. if (equiv == nullptr || values_pattern.empty() || !utils::isa<AnfNodePtr>(values_pattern[0]) ||
  169. !utils::isa<AnfNodePtr>(expr_ref)) {
  170. return;
  171. }
  172. auto real_node = utils::cast<AnfNodePtr>(expr_ref);
  173. MS_EXCEPTION_IF_NULL(real_node);
  174. if (!real_node->isa<CNode>()) {
  175. return;
  176. }
  177. auto prim_node = utils::cast<AnfNodePtr>(values_pattern[0]);
  178. MS_EXCEPTION_IF_NULL(prim_node);
  179. if (!IsValueNode<Primitive>(prim_node)) {
  180. return;
  181. }
  182. ValuePtr value = GetValueNode(prim_node);
  183. MS_EXCEPTION_IF_NULL(value);
  184. auto prim = value->cast<PrimitivePtr>();
  185. MS_EXCEPTION_IF_NULL(prim);
  186. auto iter = primitive_vars.find(prim);
  187. if (iter == primitive_vars.end()) {
  188. return;
  189. }
  190. (*equiv)[iter->second] = real_node;
  191. }
  192. EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
  193. const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const {
  194. int svar_index = GetSVarStartIndex(values_pattern);
  195. if (svar_index == kInvalidVarIndex) {
  196. return nullptr;
  197. }
  198. size_t values_pattern_len = values_pattern.size();
  199. size_t values_expr_len = values_expr.size();
  200. if (svar_index == -1) {
  201. if (values_pattern_len != values_expr_len) {
  202. MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ", expr len "
  203. << values_expr_len;
  204. return nullptr;
  205. }
  206. }
  207. if (values_expr_len < values_pattern_len - 1) {
  208. MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len;
  209. return nullptr;
  210. }
  211. size_t diff = values_expr_len - values_pattern_len + 1;
  212. for (size_t i = 0; i < values_pattern_len; i++) {
  213. size_t expr_i = i;
  214. if (svar_index != -1 && i == IntToSize(svar_index)) {
  215. auto seq =
  216. std::vector<BaseRef>(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff));
  217. equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv);
  218. } else {
  219. if (svar_index != -1 && i > IntToSize(svar_index)) {
  220. expr_i = i + diff - 1;
  221. }
  222. equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv);
  223. }
  224. if (equiv == nullptr) {
  225. return nullptr;
  226. }
  227. }
  228. return equiv;
  229. }
  230. EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
  231. EquivPtr equiv) const {
  232. MS_LOG(DEBUG) << "-----[in Match]";
  233. MS_LOG(DEBUG) << "GetVar w";
  234. BaseRef pattern_ref = GetVar(pattern);
  235. MS_LOG(DEBUG) << "GetVar v";
  236. BaseRef expr_ref = expr;
  237. if (equiv == nullptr) {
  238. MS_LOG(EXCEPTION) << "Equiv pointer is null";
  239. }
  240. MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString();
  241. // 1. if pattern_ref is var and already in equiv, replace it.
  242. if (utils::isa<VarPtr>(pattern_ref)) {
  243. VarPtr var = utils::cast<VarPtr>(pattern_ref);
  244. auto iter = equiv->find(var);
  245. if (iter != equiv->end()) {
  246. pattern_ref = iter->second;
  247. }
  248. }
  249. // 2. check equal
  250. if (opt::AnfEqual(pattern_ref, expr_ref)) {
  251. return equiv;
  252. }
  253. // 3. match var
  254. EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv);
  255. if (ret_equiv) {
  256. return ret_equiv;
  257. }
  258. // 4. here the type can be std:vector, std:list,
  259. // or cnode.
  260. if (!PatternEngine::CNodeTypeEqual(pattern_ref, expr_ref)) {
  261. MS_LOG(DEBUG) << "Type mismatch";
  262. return nullptr;
  263. }
  264. // 5. transfer the Containers by visitor to std::vector
  265. VectorRef values_pattern;
  266. VectorRef values_expr;
  267. if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) {
  268. return nullptr;
  269. }
  270. // 6. if any svar in both side, find the SeqVar index,
  271. // try to pack the Var s in std::vector to a Seq and match elements one by one.
  272. // check svar
  273. equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv);
  274. UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv);
  275. return equiv;
  276. }
  277. bool PatternEngine::CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  278. // To matchCNode and Kernel's type
  279. if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
  280. return true;
  281. }
  282. return a.type() == b.type();
  283. }
  284. } // namespace mindspore