You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pattern_engine.cc 12 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "backend/optimizer/common/pattern_engine.h"
  19. #include <exception>
  20. #include <iostream>
  21. #include <functional>
  22. #include "frontend/optimizer/opt.h"
  23. #include "ir/anf.h"
  24. #include "utils/convert_utils_base.h"
  25. #include "utils/overload.h"
  26. #include "backend/optimizer/common/helper.h"
  27. namespace mindspore {
  // Hands out 0, 1, 2, ... on successive calls; used to build default Var tags.
  // The counter is a plain function-local int with no synchronization.
  static int GetNextTag() {
    static int next_id = 0;
    const int issued = next_id;
    ++next_id;
    return issued;
  }
  32. void Var::EnsureTag() {
  33. if (tag_.length() == 0) {
  34. std::ostringstream buffer;
  35. buffer << "_" << GetNextTag();
  36. tag_ = buffer.str();
  37. }
  38. }
  39. bool operator==(const VarPtr &lhs, const VarPtr &rhs) {
  40. if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) {
  41. CondVarPtr v1 = dyn_cast<CondVar>(lhs);
  42. CondVarPtr v2 = dyn_cast<CondVar>(rhs);
  43. return *v1 == *v2;
  44. }
  45. if (lhs->isa<SeqVar>() && rhs->isa<SeqVar>()) {
  46. SVarPtr v1 = dyn_cast<SeqVar>(lhs);
  47. SVarPtr v2 = dyn_cast<SeqVar>(rhs);
  48. return *v1 == *v2;
  49. }
  50. return (*lhs == *rhs);
  51. }
  52. std::string SeqVar::ToString() const {
  53. std::ostringstream buffer;
  54. buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")";
  55. return buffer.str();
  56. }
  57. std::ostream &operator<<(std::ostream &os, const VarPtr &var) {
  58. if (var == nullptr) {
  59. os << "";
  60. } else {
  61. os << var->ToString();
  62. }
  63. return os;
  64. }
  65. template <>
  66. std::ostream &operator<<<VarPtr, BaseRef>(std::ostream &os, const Equiv &equiv) {
  67. os << "[Equiv]"
  68. << "\n";
  69. for (auto &equiv_item : equiv) {
  70. auto k = equiv_item.first;
  71. os << k << ":";
  72. BaseRef x = equiv_item.second;
  73. if (utils::isa<AnfNodePtr>(x)) {
  74. auto node = utils::cast<AnfNodePtr>(x);
  75. os << "TypeString[" << node->type_name() << "]";
  76. if (IsValueNode<FuncGraph>(node)) {
  77. os << "IsValueNodeGraph ";
  78. }
  79. os << "type " << node->type_name();
  80. if (node->isa<ValueNode>()) {
  81. os << " value " << GetValueNode(node);
  82. }
  83. os << " addr: " << node;
  84. } else if (utils::isa<Named>(x)) {
  85. os << "Named " << x.ToString().c_str();
  86. } else if (utils::isa<VarPtr>(x)) {
  87. os << "TypeString[Var]";
  88. os << (utils::cast<VarPtr>(x));
  89. } else if (utils::isa<FuncGraphPtr>(x)) {
  90. os << "TypeString[Graph]";
  91. }
  92. os << "\n";
  93. }
  94. return os;
  95. }
  96. static BaseRef GetVar(const BaseRef &x) {
  97. MS_LOG(DEBUG) << "getVar start :%s" + x.ToString();
  98. if (utils::isa<AnfNodePtr>(x)) {
  99. auto node = utils::cast<AnfNodePtr>(x);
  100. MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]";
  101. if (node->isa<VarNode>()) {
  102. MS_LOG(DEBUG) << "IsVarNode " + node->cast<VarNodePtr>()->var_->ToString();
  103. return node->cast<VarNodePtr>()->var_;
  104. }
  105. if (node->isa<ValueNode>()) {
  106. MS_LOG(DEBUG) << "value " + GetValueNode(node)->ToString() + " addr: " + node->ToString();
  107. } else {
  108. MS_LOG(DEBUG) << "type " + node->type_name();
  109. }
  110. } else if (utils::isa<Named>(x)) {
  111. MS_LOG(DEBUG) << "Named " + x.ToString();
  112. } else if (utils::isa<VectorRef>(x)) {
  113. MS_LOG(DEBUG) << "VectorRef";
  114. } else if (utils::isa<VarPtr>(x)) {
  115. MS_LOG(DEBUG) << "TypeString[Var] " + x.ToString();
  116. }
  117. MS_LOG(DEBUG) << "GetVar end: " + x.ToString();
  118. return x;
  119. }
  120. EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
  121. MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();
  122. MS_EXCEPTION_IF_NULL(equiv);
  123. if (utils::isa<VarPtr>(pattern)) {
  124. VarPtr var = utils::cast<VarPtr>(pattern);
  125. if (var->matches(expr)) {
  126. (*equiv)[var] = expr;
  127. MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString();
  128. return equiv;
  129. }
  130. }
  131. return nullptr;
  132. }
  133. bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
  134. VectorRef *const values_expr) const {
  135. MS_EXCEPTION_IF_NULL(values_expr);
  136. if (utils::isa<SeqPtr>(pattern_ref)) {
  137. *values_pattern = pattern_ref;
  138. *values_expr = expr_ref;
  139. return true;
  140. }
  141. return false;
  142. }
  143. bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern,
  144. VectorRef *const values_expr) const {
  145. MS_EXCEPTION_IF_NULL(values_expr);
  146. MS_LOG(DEBUG) << "visit pattern_ref";
  147. bool success = visitor_->Visit(pattern_ref, values_pattern, nullptr);
  148. if (!success) {
  149. return false;
  150. }
  151. MS_LOG(DEBUG) << "visit expr_ref";
  152. return visitor_->Visit(expr_ref, values_expr, nullptr);
  153. }
  154. static int GetSVarStartIndex(const VectorRef &values) {
  155. int index = -1;
  156. int count = 0;
  157. for (auto &value : values) {
  158. if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) {
  159. if (index != -1) {
  160. MS_LOG(DEBUG) << "Multiple SVars in sequence";
  161. return kInvalidVarIndex;
  162. }
  163. index = count;
  164. }
  165. count++;
  166. }
  167. return index;
  168. }
  169. void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars,
  170. const EquivPtr &equiv) {
  171. if (equiv == nullptr || values_pattern.empty() || !utils::isa<AnfNodePtr>(values_pattern[0]) ||
  172. !utils::isa<AnfNodePtr>(expr_ref)) {
  173. return;
  174. }
  175. auto real_node = utils::cast<AnfNodePtr>(expr_ref);
  176. MS_EXCEPTION_IF_NULL(real_node);
  177. if (!real_node->isa<CNode>()) {
  178. return;
  179. }
  180. auto prim_node = utils::cast<AnfNodePtr>(values_pattern[0]);
  181. MS_EXCEPTION_IF_NULL(prim_node);
  182. if (!IsValueNode<Primitive>(prim_node)) {
  183. return;
  184. }
  185. ValuePtr value = GetValueNode(prim_node);
  186. MS_EXCEPTION_IF_NULL(value);
  187. auto prim = value->cast<PrimitivePtr>();
  188. MS_EXCEPTION_IF_NULL(prim);
  189. auto iter = primitive_vars.find(prim);
  190. if (iter == primitive_vars.end()) {
  191. return;
  192. }
  193. (*equiv)[iter->second] = real_node;
  194. }
  195. EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
  196. const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const {
  197. int svar_index = GetSVarStartIndex(values_pattern);
  198. if (svar_index == kInvalidVarIndex) {
  199. return nullptr;
  200. }
  201. size_t values_pattern_len = values_pattern.size();
  202. size_t values_expr_len = values_expr.size();
  203. if (svar_index == -1) {
  204. if (values_pattern_len != values_expr_len) {
  205. MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ", expr len "
  206. << values_expr_len;
  207. return nullptr;
  208. }
  209. }
  210. if (values_expr_len < values_pattern_len - 1) {
  211. MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len;
  212. return nullptr;
  213. }
  214. size_t diff = values_expr_len - values_pattern_len + 1;
  215. for (size_t i = 0; i < values_pattern_len; i++) {
  216. size_t expr_i = i;
  217. if (svar_index != -1 && i == IntToSize(svar_index)) {
  218. auto seq =
  219. std::vector<BaseRef>(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff));
  220. equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv);
  221. } else {
  222. if (svar_index != -1 && i > IntToSize(svar_index)) {
  223. expr_i = i + diff - 1;
  224. }
  225. equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv);
  226. }
  227. if (equiv == nullptr) {
  228. return nullptr;
  229. }
  230. }
  231. return equiv;
  232. }
  233. EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
  234. EquivPtr equiv) const {
  235. MS_LOG(DEBUG) << "-----[in Match]";
  236. MS_LOG(DEBUG) << "GetVar w";
  237. BaseRef pattern_ref = GetVar(pattern);
  238. MS_LOG(DEBUG) << "GetVar v";
  239. BaseRef expr_ref = expr;
  240. if (equiv == nullptr) {
  241. MS_LOG(EXCEPTION) << "Equiv pointer is null";
  242. }
  243. MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString();
  244. // 1. if pattern_ref is var and already in equiv, replace it.
  245. if (utils::isa<VarPtr>(pattern_ref)) {
  246. VarPtr var = utils::cast<VarPtr>(pattern_ref);
  247. auto iter = equiv->find(var);
  248. if (iter != equiv->end()) {
  249. pattern_ref = iter->second;
  250. }
  251. }
  252. // 2. check equal
  253. if (PatternEngine::AnfNodeEqual(pattern_ref, expr_ref)) {
  254. return equiv;
  255. }
  256. // 3. match var
  257. EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv);
  258. if (ret_equiv) {
  259. return ret_equiv;
  260. }
  261. // 4. here the type can be std:vector, std:list,
  262. // or cnode.
  263. if (!PatternEngine::CNodeTypeEqual(pattern_ref, expr_ref)) {
  264. MS_LOG(DEBUG) << "Type mismatch";
  265. return nullptr;
  266. }
  267. // 5. transfer the Containers by visitor to std::vector
  268. VectorRef values_pattern;
  269. VectorRef values_expr;
  270. if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) {
  271. return nullptr;
  272. }
  273. // 6. if any svar in both side, find the SeqVar index,
  274. // try to pack the Var s in std::vector to a Seq and match elements one by one.
  275. // check svar
  276. equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv);
  277. UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv);
  278. return equiv;
  279. }
  280. bool PatternEngine::AnfNodeEqual(const BaseRef &a, const BaseRef &b) {
  281. if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
  282. auto a_node = utils::cast<AnfNodePtr>(a);
  283. auto b_node = utils::cast<AnfNodePtr>(b);
  284. MS_EXCEPTION_IF_NULL(a_node);
  285. MS_EXCEPTION_IF_NULL(b_node);
  286. if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
  287. auto a_value_node = a_node->cast<ValueNodePtr>();
  288. MS_EXCEPTION_IF_NULL(a_value_node);
  289. auto a_value = a_value_node->value();
  290. MS_EXCEPTION_IF_NULL(a_value);
  291. auto a_prim = a_value->cast<PrimitivePtr>();
  292. MS_EXCEPTION_IF_NULL(a_prim);
  293. auto b_value_node = b_node->cast<ValueNodePtr>();
  294. MS_EXCEPTION_IF_NULL(b_value_node);
  295. auto b_value = b_value_node->value();
  296. MS_EXCEPTION_IF_NULL(b_value);
  297. auto b_prim = b_value->cast<PrimitivePtr>();
  298. MS_EXCEPTION_IF_NULL(b_prim);
  299. return a_prim->name() == b_prim->name();
  300. } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
  301. auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
  302. if (a_value_node_ptr == nullptr) {
  303. MS_LOG(EXCEPTION) << "cast value node ptr fail";
  304. }
  305. auto a_value_ptr = a_value_node_ptr->value();
  306. if (a_value_ptr == nullptr) {
  307. MS_LOG(EXCEPTION) << "value ptr is nullptr";
  308. }
  309. auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
  310. if (b_value_node_ptr == nullptr) {
  311. MS_LOG(EXCEPTION) << "cast value node ptr fail";
  312. }
  313. auto b_value_ptr = b_value_node_ptr->value();
  314. if (b_value_ptr == nullptr) {
  315. MS_LOG(EXCEPTION) << "value ptr is nullptr";
  316. }
  317. return (*a_value_ptr) == (*b_value_ptr);
  318. }
  319. MS_LOG(DEBUG) << "check AnfNodePtr equal";
  320. }
  321. if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
  322. MS_LOG(DEBUG) << "check GraphPtr equal";
  323. }
  324. return a == b;
  325. }
  326. bool PatternEngine::CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  327. // To matchCNode and Kernel's type
  328. if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
  329. return true;
  330. }
  331. return a.type() == b.type();
  332. }
  333. } // namespace mindspore