You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

pattern_engine.cc 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "backend/optimizer/common/pattern_engine.h"
  19. #include <exception>
  20. #include <iostream>
  21. #include <functional>
  22. #include "frontend/optimizer/opt.h"
  23. #include "ir/anf.h"
  24. #include "utils/convert_utils_base.h"
  25. #include "utils/overload.h"
  26. namespace mindspore {
  // Returns a fresh integer used to build auto-generated Var tags: 0 on the first call, then 1, 2, ...
  // The counter is a plain function-local int with no synchronization.
  static int GetNextTag() {
    static int next_id = 0;
    const int issued = next_id;
    ++next_id;
    return issued;
  }
  31. void Var::EnsureTag() {
  32. if (tag_.length() == 0) {
  33. std::ostringstream buffer;
  34. buffer << "_" << GetNextTag();
  35. tag_ = buffer.str();
  36. }
  37. }
  38. bool operator==(const VarPtr &lhs, const VarPtr &rhs) {
  39. if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) {
  40. CondVarPtr v1 = dyn_cast<CondVar>(lhs);
  41. CondVarPtr v2 = dyn_cast<CondVar>(rhs);
  42. return *v1 == *v2;
  43. }
  44. if (lhs->isa<SeqVar>() && rhs->isa<SeqVar>()) {
  45. SVarPtr v1 = dyn_cast<SeqVar>(lhs);
  46. SVarPtr v2 = dyn_cast<SeqVar>(rhs);
  47. return *v1 == *v2;
  48. }
  49. return (*lhs == *rhs);
  50. }
  51. std::string SeqVar::ToString() const {
  52. std::ostringstream buffer;
  53. buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")";
  54. return buffer.str();
  55. }
  56. std::ostream &operator<<(std::ostream &os, const VarPtr &var) {
  57. if (var == nullptr) {
  58. os << "";
  59. } else {
  60. os << var->ToString();
  61. }
  62. return os;
  63. }
  64. template <>
  65. std::ostream &operator<<<VarPtr, BaseRef>(std::ostream &os, const Equiv &equiv) {
  66. os << "[Equiv]"
  67. << "\n";
  68. for (auto &equiv_item : equiv) {
  69. auto k = equiv_item.first;
  70. os << k << ":";
  71. BaseRef x = equiv_item.second;
  72. if (utils::isa<AnfNodePtr>(x)) {
  73. auto node = utils::cast<AnfNodePtr>(x);
  74. os << "TypeString[" << node->type_name() << "]";
  75. if (IsValueNode<FuncGraph>(node)) {
  76. os << "IsValueNodeGraph ";
  77. }
  78. os << "type " << node->type_name();
  79. if (node->isa<ValueNode>()) {
  80. os << " value " << GetValueNode(node);
  81. }
  82. os << " addr: " << node;
  83. } else if (utils::isa<Named>(x)) {
  84. os << "Named " << x.ToString().c_str();
  85. } else if (utils::isa<VarPtr>(x)) {
  86. os << "TypeString[Var]";
  87. os << utils::cast<VarPtr>(x);
  88. } else if (utils::isa<FuncGraphPtr>(x)) {
  89. os << "TypeString[Graph]";
  90. }
  91. os << "\n";
  92. }
  93. return os;
  94. }
  95. static BaseRef GetVar(const BaseRef &x) {
  96. MS_LOG(DEBUG) << "getVar start :%s" + x.ToString();
  97. if (utils::isa<AnfNodePtr>(x)) {
  98. auto node = utils::cast<AnfNodePtr>(x);
  99. MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]";
  100. if (node->isa<VarNode>()) {
  101. MS_LOG(DEBUG) << "IsVarNode " + node->cast<VarNodePtr>()->var_->ToString();
  102. return node->cast<VarNodePtr>()->var_;
  103. }
  104. if (node->isa<ValueNode>()) {
  105. MS_LOG(DEBUG) << "value " + GetValueNode(node)->ToString() + " addr: " + node->ToString();
  106. } else {
  107. MS_LOG(DEBUG) << "type " + node->type_name();
  108. }
  109. } else if (utils::isa<Named>(x)) {
  110. MS_LOG(DEBUG) << "Named " + x.ToString();
  111. } else if (utils::isa<VectorRef>(x)) {
  112. MS_LOG(DEBUG) << "VectorRef";
  113. } else if (utils::isa<VarPtr>(x)) {
  114. MS_LOG(DEBUG) << "TypeString[Var] " + x.ToString();
  115. }
  116. MS_LOG(DEBUG) << "GetVar end: " + x.ToString();
  117. return x;
  118. }
  119. EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
  120. MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();
  121. MS_EXCEPTION_IF_NULL(equiv);
  122. if (utils::isa<VarPtr>(pattern)) {
  123. VarPtr var = utils::cast<VarPtr>(pattern);
  124. if (var->matches(expr)) {
  125. (*equiv)[var] = expr;
  126. MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString();
  127. return equiv;
  128. }
  129. }
  130. return nullptr;
  131. }
  132. bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
  133. VectorRef *const values_expr) const {
  134. MS_EXCEPTION_IF_NULL(values_expr);
  135. if (utils::isa<SeqPtr>(pattern_ref)) {
  136. *values_pattern = pattern_ref;
  137. *values_expr = expr_ref;
  138. return true;
  139. }
  140. return false;
  141. }
  142. bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern,
  143. VectorRef *const values_expr) const {
  144. MS_EXCEPTION_IF_NULL(values_expr);
  145. // visitor to visite the list
  146. auto appender_pattern = [](VectorRef &values) {
  147. std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
  148. values.push_back(GetVar(u));
  149. return u;
  150. };
  151. return fn;
  152. };
  153. visitor_->SetFn(appender_pattern(*values_pattern));
  154. MS_LOG(DEBUG) << "visit pattern_ref";
  155. bool success = visitor_->Visit(pattern_ref, nullptr);
  156. if (!success) {
  157. return false;
  158. }
  159. auto appender_expr = [](VectorRef &values) {
  160. std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
  161. values.push_back(u);
  162. return u;
  163. };
  164. return fn;
  165. };
  166. visitor_->SetFn(appender_expr(*values_expr));
  167. MS_LOG(DEBUG) << "visit expr_ref";
  168. return visitor_->Visit(expr_ref, nullptr);
  169. }
  170. static int GetSVarStartIndex(const VectorRef &values) {
  171. int index = -1;
  172. int count = 0;
  173. for (auto &value : values) {
  174. if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) {
  175. if (index != -1) {
  176. MS_LOG(DEBUG) << "Multiple SVars in sequence";
  177. return kInvalidVarIndex;
  178. }
  179. index = count;
  180. }
  181. count++;
  182. }
  183. return index;
  184. }
  185. void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars,
  186. EquivPtr equiv) {
  187. if (equiv == nullptr || values_pattern.empty() || !utils::isa<AnfNodePtr>(values_pattern[0]) ||
  188. !utils::isa<AnfNodePtr>(expr_ref)) {
  189. return;
  190. }
  191. auto real_node = utils::cast<AnfNodePtr>(expr_ref);
  192. MS_EXCEPTION_IF_NULL(real_node);
  193. if (!real_node->isa<CNode>()) {
  194. return;
  195. }
  196. auto prim_node = utils::cast<AnfNodePtr>(values_pattern[0]);
  197. MS_EXCEPTION_IF_NULL(prim_node);
  198. if (!IsValueNode<Primitive>(prim_node)) {
  199. return;
  200. }
  201. ValuePtr value = GetValueNode(prim_node);
  202. MS_EXCEPTION_IF_NULL(value);
  203. auto prim = value->cast<PrimitivePtr>();
  204. MS_EXCEPTION_IF_NULL(prim);
  205. auto iter = primitive_vars.find(prim);
  206. if (iter == primitive_vars.end()) {
  207. return;
  208. }
  209. (*equiv)[iter->second] = real_node;
  210. }
  211. EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
  212. const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const {
  213. int svar_index = GetSVarStartIndex(values_pattern);
  214. if (svar_index == kInvalidVarIndex) {
  215. return nullptr;
  216. }
  217. size_t values_pattern_len = values_pattern.size();
  218. size_t values_expr_len = values_expr.size();
  219. if (svar_index == -1) {
  220. if (values_pattern_len != values_expr_len) {
  221. MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ", expr len "
  222. << values_expr_len;
  223. return nullptr;
  224. }
  225. }
  226. if (values_expr_len < values_pattern_len - 1) {
  227. MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len;
  228. return nullptr;
  229. }
  230. size_t diff = values_expr_len - values_pattern_len + 1;
  231. for (size_t i = 0; i < values_pattern_len; i++) {
  232. size_t expr_i = i;
  233. if (svar_index != -1 && i == IntToSize(svar_index)) {
  234. auto seq =
  235. std::vector<BaseRef>(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff));
  236. equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv);
  237. } else {
  238. if (svar_index != -1 && i > IntToSize(svar_index)) {
  239. expr_i = i + diff - 1;
  240. }
  241. equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv);
  242. }
  243. if (equiv == nullptr) {
  244. return nullptr;
  245. }
  246. }
  247. return equiv;
  248. }
  249. EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
  250. EquivPtr equiv) const {
  251. MS_LOG(DEBUG) << "-----[in Match]";
  252. MS_LOG(DEBUG) << "GetVar w";
  253. BaseRef pattern_ref = GetVar(pattern);
  254. MS_LOG(DEBUG) << "GetVar v";
  255. BaseRef expr_ref = expr;
  256. if (equiv == nullptr) {
  257. MS_LOG(EXCEPTION) << "Equiv pointer is null";
  258. }
  259. MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString();
  260. // 1. if pattern_ref is var and already in equiv, replace it.
  261. if (utils::isa<VarPtr>(pattern_ref)) {
  262. VarPtr var = utils::cast<VarPtr>(pattern_ref);
  263. auto iter = equiv->find(var);
  264. if (iter != equiv->end()) {
  265. pattern_ref = iter->second;
  266. }
  267. }
  268. // 2. check equal
  269. if (eq_(pattern_ref, expr_ref)) {
  270. return equiv;
  271. }
  272. // 3. match var
  273. EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv);
  274. if (ret_equiv) {
  275. return ret_equiv;
  276. }
  277. // 4. here the type can be std:vector, std:list,
  278. // or cnode.
  279. if (!type_eq_(pattern_ref, expr_ref)) {
  280. MS_LOG(DEBUG) << "Type mismatch";
  281. return nullptr;
  282. }
  283. // 5. transfer the Containers by visitor to std::vector
  284. VectorRef values_pattern;
  285. VectorRef values_expr;
  286. if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) {
  287. return nullptr;
  288. }
  289. // 6. if any svar in both side, find the SeqVar index,
  290. // try to pack the Var s in std::vector to a Seq and match elements one by one.
  291. // check svar
  292. equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv);
  293. UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv);
  294. return equiv;
  295. }
  296. BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const {
  297. MS_EXCEPTION_IF_NULL(equiv);
  298. MS_LOG(DEBUG) << "-----[in Replace]";
  299. BaseRef ref = GetVar(pattern);
  300. BaseRef out;
  301. bool is_match = false;
  302. // w is var
  303. if (utils::isa<VarPtr>(ref)) {
  304. const VarPtr &var = utils::cast<VarPtr>(ref);
  305. auto iter = equiv->find(var);
  306. if (iter != equiv->end()) {
  307. out = iter->second;
  308. is_match = true;
  309. }
  310. }
  311. if (is_match) {
  312. return out;
  313. }
  314. // visitor to visit the list
  315. std::function<BaseRef(BaseRef)> fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); };
  316. visitor_->SetFn(fn);
  317. BaseRef visit_out;
  318. if (!visitor_->Visit(pattern, &visit_out)) {
  319. return pattern;
  320. }
  321. return visit_out;
  322. }
  323. } // namespace mindspore