You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

optimizer.cc 8.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pre_activate/common/optimizer.h"
  17. #include <functional>
  18. #include <memory>
  19. #include <string>
  20. #include <vector>
  21. #include <algorithm>
  22. #include <utility>
  23. #include <initializer_list>
  24. #include "pre_activate/common/pass_manager.h"
  25. #include "session/anf_runtime_algorithm.h"
  26. #include "ir/manager.h"
  27. namespace mindspore {
  28. namespace opt {
  29. namespace {
  30. AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
  31. bool multigraph);
  32. ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
  33. if (utils::isa<int>(sexp)) {
  34. return NewValueNode(utils::cast<int>(sexp));
  35. }
  36. if (utils::isa<float>(sexp)) {
  37. return NewValueNode(utils::cast<float>(sexp));
  38. }
  39. if (utils::isa<bool>(sexp)) {
  40. return NewValueNode(utils::cast<bool>(sexp));
  41. }
  42. if (utils::isa<ValuePtr>(sexp)) {
  43. return NewValueNode(utils::cast<ValuePtr>(sexp));
  44. }
  45. return nullptr;
  46. }
  47. CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
  48. if (utils::isa<FuncGraphPtr>(graph)) {
  49. return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
  50. }
  51. if (utils::isa<VarPtr>(graph)) {
  52. return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
  53. }
  54. return nullptr;
  55. }
  56. VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  57. if (utils::isa<VarPtr>(graph)) {
  58. MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
  59. return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  60. }
  61. if (utils::isa<FuncGraphPtr>(graph)) {
  62. MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
  63. return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
  64. }
  65. MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  66. return nullptr;
  67. }
  68. AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
  69. bool multigraph = false) {
  70. MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  71. MS_EXCEPTION_IF_NULL(primitive_vars);
  72. if (utils::isa<VectorRef>(sexp)) {
  73. return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  74. }
  75. if (utils::isa<VarPtr>(sexp)) {
  76. auto var_ptr = utils::cast<VarPtr>(sexp);
  77. MS_EXCEPTION_IF_NULL(var_ptr);
  78. if (var_ptr->primitive()) {
  79. (*primitive_vars)[var_ptr->primitive()] = var_ptr;
  80. return NewValueNode(var_ptr->primitive());
  81. }
  82. return CreateVarNodeWithSexp(sexp, graph);
  83. }
  84. if (utils::isa<AnfNodePtr>(sexp)) {
  85. return utils::cast<AnfNodePtr>(sexp);
  86. }
  87. auto value_node = CreateValueNodeWithSexp(sexp);
  88. if (value_node == nullptr) {
  89. MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
  90. }
  91. return value_node;
  92. }
  93. AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
  94. bool multigraph) {
  95. MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  96. std::vector<AnfNodePtr> input_nodes;
  97. const auto &tuple = utils::cast<VectorRef>(sexp);
  98. if (multigraph && utils::isa<VarPtr>(graph)) {
  99. for (auto &x : tuple) {
  100. AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
  101. input_nodes.push_back(node);
  102. }
  103. VarPtr var_ptr = utils::cast<VarPtr>(graph);
  104. return std::make_shared<CNode>(input_nodes, var_ptr);
  105. }
  106. for (auto &x : tuple) {
  107. AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
  108. input_nodes.push_back(node);
  109. }
  110. return CreateCNodeWithGraph(input_nodes, graph);
  111. }
  112. } // namespace
  113. static bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  114. if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
  115. auto a_node = utils::cast<AnfNodePtr>(a);
  116. auto b_node = utils::cast<AnfNodePtr>(b);
  117. if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
  118. auto a_value_node = a_node->cast<ValueNodePtr>();
  119. auto a_value = a_value_node->value();
  120. auto a_prim = a_value->cast<PrimitivePtr>();
  121. auto b_value_node = b_node->cast<ValueNodePtr>();
  122. auto b_value = b_value_node->value();
  123. auto b_prim = b_value->cast<PrimitivePtr>();
  124. return a_prim->name() == b_prim->name();
  125. } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
  126. auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
  127. if (a_value_node_ptr == nullptr) {
  128. MS_LOG(EXCEPTION) << "cast value node ptr fail";
  129. }
  130. auto a_value_ptr = a_value_node_ptr->value();
  131. if (a_value_ptr == nullptr) {
  132. MS_LOG(EXCEPTION) << "value ptr is nullptr";
  133. }
  134. auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
  135. if (b_value_node_ptr == nullptr) {
  136. MS_LOG(EXCEPTION) << "cast value node ptr fail";
  137. }
  138. auto b_value_ptr = b_value_node_ptr->value();
  139. if (b_value_ptr == nullptr) {
  140. MS_LOG(EXCEPTION) << "value ptr is nullptr";
  141. }
  142. return (*a_value_ptr) == (*b_value_ptr);
  143. }
  144. MS_LOG(DEBUG) << "check AnfNodePtr equal";
  145. }
  146. if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
  147. MS_LOG(DEBUG) << "check GraphPtr equal";
  148. }
  149. return a == b;
  150. }
  151. static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  152. // To matchCNode and Kernel's type
  153. if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
  154. return true;
  155. }
  156. return a.type() == b.type();
  157. }
  158. PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph)
  159. : NodePass(name),
  160. multigraph_(multigraph),
  161. pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
  162. std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
  163. std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
  164. primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
  165. const BaseRef PatternProcessPass::DefinePattern() const {
  166. VarPtr X = std::make_shared<Var>();
  167. return BaseRef({X});
  168. }
  169. void PatternProcessPass::Build() {
  170. VarPtr fg = std::make_shared<Var>("RootG");
  171. BaseRef pattern = std::move(DefinePattern());
  172. pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_);
  173. }
  174. AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  175. if (pattern_ == nullptr) {
  176. Build();
  177. }
  178. auto empty_equiv = std::make_shared<Equiv>();
  179. MS_EXCEPTION_IF_NULL(primitive_vars_);
  180. EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv);
  181. if (equiv != nullptr && !equiv->empty()) {
  182. return Process(func_graph, node, equiv);
  183. }
  184. return nullptr;
  185. }
  186. void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
  187. if (pass_manager != nullptr) {
  188. pass_managers_.push_back(pass_manager);
  189. }
  190. }
  191. FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) {
  192. MS_EXCEPTION_IF_NULL(func_graph);
  193. run_only_once_ = (pass_managers_.size() == 1) ? true : run_only_once;
  194. auto manager = func_graph->manager();
  195. if (manager == nullptr) {
  196. manager = Manage(func_graph, false);
  197. func_graph->set_manager(manager);
  198. }
  199. bool changed = true;
  200. while (changed) {
  201. changed = false;
  202. for (size_t i = 0; i < pass_managers_.size(); ++i) {
  203. const PassManagerPtr &pm = pass_managers_[i];
  204. if (pm != nullptr && pm->Run(func_graph)) {
  205. changed = true;
  206. }
  207. }
  208. if (run_only_once_) {
  209. break;
  210. }
  211. }
  212. std::vector<FuncGraphPtr> func_graphs;
  213. func_graphs.push_back(func_graph);
  214. manager->KeepRoots(func_graphs);
  215. (void)TopoSort(func_graph->get_return());
  216. return func_graph;
  217. }
  218. } // namespace opt
  219. } // namespace mindspore