You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

py_pass.cc 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "optimizer/py_pass.h"
  17. #include <unordered_set>
  18. #include <deque>
  19. #include <algorithm>
  20. #include <utility>
  21. #include <vector>
  22. #include "ir/func_graph.h"
  23. #include "ir/manager.h"
  24. #include "pipeline/parse/parse_base.h"
  25. #include "pipeline/resource.h"
  26. namespace mindspore {
  27. namespace opt {
  28. namespace python_pass {
  29. namespace internal {
  30. std::string GetNodeRepr(AnfNodePtr node) {
  31. if (node != nullptr) {
  32. if (node->isa<CNode>()) {
  33. std::string repr = "(";
  34. auto const &inputs = node->cast<CNodePtr>()->inputs();
  35. for (auto &input : inputs) {
  36. repr += " ";
  37. repr += GetNodeRepr(input);
  38. repr += " ";
  39. }
  40. repr += ")";
  41. return repr;
  42. }
  43. if (node->isa<ValueNode>()) {
  44. return GetValueNode(node)->ToString();
  45. }
  46. return node->ToString();
  47. }
  48. return "";
  49. }
  50. void ResolveFuncGraph_(const FuncGraphPtr &fg) {
  51. auto manager = Manage(fg, false);
  52. parse::python_adapter::set_use_signature_in_resolve(false);
  53. parse::ResolveAll(manager);
  54. }
  55. bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) {
  56. if (node == nullptr) {
  57. return false;
  58. }
  59. MS_EXCEPTION_IF_NULL(pattern);
  60. if (pattern->isa<ValueNode>()) {
  61. if (!node->isa<ValueNode>()) {
  62. return false;
  63. }
  64. if (GetNodeRepr(pattern) == GetNodeRepr(node)) {
  65. // add to equiv_ptr
  66. equiv_ptr->insert(std::make_pair(GetValueNode(pattern)->ToString(), node));
  67. return true;
  68. }
  69. return false;
  70. } else if (pattern->isa<Parameter>()) {
  71. MS_LOG(DEBUG) << pattern->ToString() + "\n";
  72. // add to equiv_ptr
  73. equiv_ptr->insert(std::make_pair(pattern->ToString(), node));
  74. return true;
  75. } else if (pattern->isa<CNode>()) {
  76. // match every single sub ANode
  77. if (!node->isa<CNode>()) {
  78. return false;
  79. }
  80. auto pattern_inputs = pattern->cast<CNodePtr>()->inputs();
  81. auto node_inputs = node->cast<CNodePtr>()->inputs();
  82. if (pattern_inputs.size() != node_inputs.size()) {
  83. return false;
  84. }
  85. for (auto p_item = pattern_inputs.begin(), node_item = node_inputs.begin(); p_item != pattern_inputs.end();
  86. p_item++, node_item++) {
  87. auto res = Match(*p_item, *node_item, equiv_ptr);
  88. if (!res) {
  89. return false;
  90. }
  91. }
  92. return true;
  93. }
  94. MS_LOG(EXCEPTION) << "Unexpected condition, (" + pattern->ToString() + " , " + node->ToString() + ")\n";
  95. }
  96. AnfNodePtr BuildTarget(const FuncGraphPtr &func_graph, const AnfNodePtr cur_raw_dst_node_,
  97. const NodeEquivPtr &equiv_ptr) {
  98. if (cur_raw_dst_node_->isa<Parameter>()) {
  99. auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->ToString());
  100. if (sub_pair != equiv_ptr->end()) {
  101. return sub_pair->second;
  102. }
  103. MS_LOG(EXCEPTION) << "cur_raw_dst_node_ : " + internal::GetNodeRepr(cur_raw_dst_node_) + "\n";
  104. } else if (cur_raw_dst_node_->isa<ValueNode>()) {
  105. // check primitive ValueNode
  106. auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->cast<ValueNodePtr>()->value()->ToString());
  107. if (sub_pair != equiv_ptr->end()) {
  108. return sub_pair->second;
  109. }
  110. return cur_raw_dst_node_;
  111. } else if (cur_raw_dst_node_->isa<CNode>()) {
  112. std::vector<AnfNodePtr> new_inputs;
  113. auto inputs = cur_raw_dst_node_->cast<CNodePtr>()->inputs();
  114. for (auto sub_node = inputs.begin(); sub_node != inputs.end(); sub_node++) {
  115. auto subed = internal::BuildTarget(func_graph, *sub_node, equiv_ptr);
  116. new_inputs.push_back(subed);
  117. }
  118. return func_graph->NewCNode(new_inputs);
  119. }
  120. MS_LOG(EXCEPTION) << "Unexpected node type, got : " + internal::GetNodeRepr(cur_raw_dst_node_);
  121. }
  122. bool isTraversable(const AnfNodePtr &node) {
  123. if (node == nullptr) {
  124. return false;
  125. }
  126. if (node->isa<CNode>() || node->isa<Parameter>()) {
  127. return true;
  128. }
  129. if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) {
  130. return true;
  131. }
  132. return false;
  133. }
  134. } // namespace internal
  135. void PythonPass::Build(const py::function &src, const py::function &dst) {
  136. // 1. get FuncGraph from py::function
  137. auto src_fg_ = parse::ParsePythonCode(src);
  138. auto dst_fg_ = parse::ParsePythonCode(dst);
  139. if (src_fg_ == nullptr || dst_fg_ == nullptr) {
  140. MS_LOG(EXCEPTION) << "Failed to parse python code.\n";
  141. }
  142. // 2. Resolve
  143. internal::ResolveFuncGraph_(src_fg_);
  144. internal::ResolveFuncGraph_(dst_fg_);
  145. // 3. from FuncGraphPtr to ValueNode
  146. src_node_ = src_fg_->output();
  147. dst_node_ = dst_fg_->output();
  148. }
  149. PythonPass::PythonPass(const std::string &name, const py::function &src, const py::function &dst, bool run_only_once,
  150. bool multigraph)
  151. : name_(name), run_only_once_(run_only_once), multigraph_(multigraph) {
  152. Build(src, dst);
  153. }
  154. AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  155. auto equiv_ptr = std::make_shared<NodeEquiv>();
  156. bool is_a_match = internal::Match(src_node_, node, equiv_ptr);
  157. if (is_a_match) {
  158. auto new_node = internal::BuildTarget(func_graph, dst_node_, equiv_ptr);
  159. MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n";
  160. return new_node;
  161. }
  162. return nullptr;
  163. }
  164. bool PythonPass::Run(const FuncGraphPtr &func_graph) {
  165. MS_EXCEPTION_IF_NULL(func_graph);
  166. FuncGraphManagerPtr manager = func_graph->manager();
  167. MS_EXCEPTION_IF_NULL(manager);
  168. manager->AddFuncGraph(func_graph);
  169. auto seen = NewSeenGeneration();
  170. // 1024 is for the initial capacity of deque
  171. std::deque<AnfNodePtr> todo(1024);
  172. todo.push_back(func_graph->output());
  173. bool changes = false;
  174. auto &all_nodes = manager->all_nodes();
  175. while (!todo.empty()) {
  176. AnfNodePtr node = todo.front();
  177. todo.pop_front();
  178. // check whether this node has been matched.
  179. if (node == nullptr || node->seen_ == seen || !internal::isTraversable(node) || !all_nodes.contains(node)) {
  180. continue;
  181. }
  182. node->seen_ = seen;
  183. // select nodes that this transform can be applied.
  184. AnfNodePtr new_node = Run(func_graph, node);
  185. bool change = (new_node != nullptr);
  186. if (new_node != nullptr && new_node != node) {
  187. (void)manager->Replace(node, new_node);
  188. } else if (new_node == nullptr) {
  189. new_node = node;
  190. }
  191. if (run_only_once_) {
  192. return change;
  193. }
  194. // find success, and add them to todo list
  195. if (IsValueNode<FuncGraph>(node)) {
  196. todo.push_back(GetValueNode<FuncGraphPtr>(node)->output());
  197. }
  198. if (node->isa<CNode>()) {
  199. auto &inputs = node->cast<CNodePtr>()->inputs();
  200. (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
  201. }
  202. auto &node_users = manager->node_users();
  203. if (change && node_users.find(node) != node_users.end()) {
  204. for (auto &use : node_users[node]) {
  205. auto use_node = use.first;
  206. if (use_node == nullptr) {
  207. continue;
  208. }
  209. todo.push_back(use_node);
  210. if (use_node->seen_ == seen) {
  211. use_node->seen_--;
  212. }
  213. }
  214. }
  215. }
  216. return changes;
  217. }
  218. } // namespace python_pass
  219. } // namespace opt
  220. } // namespace mindspore