You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opt.cc 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/opt.h"
  17. #include <deque>
  18. #include <memory>
  19. #include <algorithm>
  20. #include "utils/hash_map.h"
  21. #include "ir/anf.h"
  22. #include "ir/manager.h"
  23. #include "frontend/optimizer/optimizer.h"
  24. #include "utils/log_adapter.h"
  25. namespace mindspore {
  26. /* namespace to support opt */
  27. namespace opt {
  28. SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
  29. const RenormAction &renorm_action, bool has_priority_pattern) {
  30. auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
  31. return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
  32. }
  33. SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
  34. const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action,
  35. bool has_priority_pattern) {
  36. auto fn = [prims](const AnfNodePtr &node) -> bool {
  37. if (!node->isa<CNode>()) {
  38. return false;
  39. }
  40. auto cnode = node->cast<CNodePtr>();
  41. auto inp0 = cnode->input(0);
  42. auto prim0 = GetValueNode<PrimitivePtr>(inp0);
  43. if (prim0 == nullptr) {
  44. return false;
  45. }
  46. auto hash = prim0->Hash();
  47. auto const &name = prim0->name();
  48. for (auto &prim : prims) {
  49. if (hash == prim->Hash() && name == prim->name()) {
  50. return true;
  51. }
  52. }
  53. return false;
  54. };
  55. return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
  56. }
  57. SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
  58. const PredicateFuncType &predicate, const RenormAction &renorm_action,
  59. bool has_priority_pattern) {
  60. return std::make_shared<Substitution>(transform, name, predicate, renorm_action, has_priority_pattern);
  61. }
  62. AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
  63. #ifdef ENABLE_PROFILE
  64. double t = GetTime();
  65. #endif
  66. AnfNodePtr result = (*transform_)(optimizer, node);
  67. #ifdef ENABLE_PROFILE
  68. if (optimizer != nullptr) {
  69. auto time = GetTime();
  70. MsProfile::StatTime("substitution." + name_, time - t);
  71. if (result != nullptr) {
  72. MsProfile::StatTime("match." + name_, time - t);
  73. }
  74. }
  75. #endif
  76. if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) {
  77. if ((renorm_action_ == FORCE_RENORM) || (result->abstract() == nullptr)) {
  78. optimizer->set_is_untyped_generated();
  79. }
  80. }
  81. return result;
  82. }
  83. static bool isTraversable(const AnfNodePtr &node) {
  84. if (node == nullptr) {
  85. return false;
  86. }
  87. if (node->isa<CNode>() || node->isa<Parameter>()) {
  88. return true;
  89. }
  90. if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) {
  91. return true;
  92. }
  93. return false;
  94. }
  95. static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
  96. const SubstitutionPtr &substitution) {
  97. auto manager = optimizer->manager();
  98. bool is_match = substitution->predicate_(node);
  99. if (is_match) {
  100. TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info()));
  101. ScopeGuard scope_guard(node->scope());
  102. auto res = (*substitution)(optimizer, node);
  103. if (res != nullptr && res != node) {
  104. #ifdef ENABLE_PROFILE
  105. double t = GetTime();
  106. #endif
  107. MS_LOG(DEBUG) << "Replace " << node->DebugString() << " with " << res->DebugString() << ", by "
  108. << substitution->name_;
  109. (void)manager->Replace(node, res);
  110. #ifdef ENABLE_PROFILE
  111. MsProfile::StatTime("replace." + substitution->name_, GetTime() - t);
  112. #endif
  113. return res;
  114. }
  115. }
  116. return nullptr;
  117. }
  118. static void UpdateTransformingListForSubstitutions(const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, bool change) {
  119. if (IsValueNode<FuncGraph>(node)) {
  120. (*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
  121. }
  122. if (change) {
  123. (*todo).emplace_back(node);
  124. } else {
  125. if (node->isa<CNode>()) {
  126. auto &inputs = node->cast<CNodePtr>()->inputs();
  127. (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
  128. }
  129. }
  130. }
  131. static void UpdateTransformingListForIR(const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, bool change,
  132. const SubstitutionPtr &substitution) {
  133. if (IsValueNode<FuncGraph>(node)) {
  134. (*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
  135. }
  136. // If there is a priority pattern in substitution, don't transform the new node,
  137. // otherwise some nodes may match the wrong patterns.
  138. if (change && substitution != nullptr && !substitution->has_priority_pattern_) {
  139. (*todo).emplace_back(node);
  140. } else {
  141. if (node->isa<CNode>()) {
  142. auto &inputs = node->cast<CNodePtr>()->inputs();
  143. (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
  144. }
  145. }
  146. }
  147. static void UpdateTransformingListWithUserNodes(const OptimizerPtr &optimizer, const AnfNodePtr &node,
  148. std::deque<AnfNodePtr> *todo, bool change, size_t seen) {
  149. if (!change) {
  150. return;
  151. }
  152. auto manager = optimizer->manager();
  153. auto &node_users = manager->node_users();
  154. auto users_iterator = node_users.find(node);
  155. if (users_iterator == node_users.end()) {
  156. return;
  157. }
  158. auto users = users_iterator->second;
  159. for (auto &use : users) {
  160. auto use_node = use.first;
  161. if (use_node == nullptr) {
  162. continue;
  163. }
  164. (*todo).emplace_back(use_node);
  165. if (use_node->seen_ == seen) {
  166. use_node->seen_--;
  167. }
  168. }
  169. }
  170. bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
  171. #ifdef ENABLE_PROFILE
  172. double start = GetTime();
  173. #endif
  174. FuncGraphManagerPtr manager = optimizer->manager();
  175. auto seen = NewSeenGeneration();
  176. std::deque<AnfNodePtr> todo;
  177. todo.emplace_back(func_graph->output());
  178. bool changes = false;
  179. auto &all_nodes = manager->all_nodes();
  180. while (!todo.empty()) {
  181. AnfNodePtr node = todo.front();
  182. todo.pop_front();
  183. if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
  184. continue;
  185. }
  186. node->seen_ = seen;
  187. bool change = false;
  188. for (auto &substitution : list_) {
  189. auto res = DoTransform(optimizer, node, substitution);
  190. if (res != nullptr) {
  191. change = true;
  192. changes = true;
  193. node = res;
  194. break;
  195. }
  196. }
  197. UpdateTransformingListForSubstitutions(node, &todo, change);
  198. UpdateTransformingListWithUserNodes(optimizer, node, &todo, change, seen);
  199. }
  200. #ifdef ENABLE_PROFILE
  201. MsProfile::StatTime("opt.transforms." + optimizer->name(), GetTime() - start);
  202. #endif
  203. return changes;
  204. }
  205. bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph,
  206. const SubstitutionPtr &substitution) const {
  207. #ifdef ENABLE_PROFILE
  208. double start = GetTime();
  209. #endif
  210. FuncGraphManagerPtr manager = optimizer->manager();
  211. auto seen = NewSeenGeneration();
  212. std::deque<AnfNodePtr> todo;
  213. todo.emplace_back(func_graph->output());
  214. bool changes = false;
  215. auto &all_nodes = manager->all_nodes();
  216. while (!todo.empty()) {
  217. AnfNodePtr node = todo.front();
  218. todo.pop_front();
  219. if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
  220. continue;
  221. }
  222. node->seen_ = seen;
  223. bool change = false;
  224. auto res = DoTransform(optimizer, node, substitution);
  225. if (res != nullptr) {
  226. change = true;
  227. changes = true;
  228. node = res;
  229. }
  230. UpdateTransformingListForIR(node, &todo, change, substitution);
  231. UpdateTransformingListWithUserNodes(optimizer, node, &todo, change, seen);
  232. }
  233. #ifdef ENABLE_PROFILE
  234. MsProfile::StatTime("opt.transform." + optimizer->name(), GetTime() - start);
  235. #endif
  236. return changes;
  237. }
  238. void SubstitutionList::DisplayStatusOfSubstitution(const mindspore::HashMap<std::string, std::vector<bool>> &status,
  239. const OptimizerPtr &optimizer, size_t space) const {
  240. constexpr int pad_width = 4;
  241. std::stringstream ss;
  242. ss << std::endl
  243. << "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name
  244. << std::endl;
  245. for (size_t i = 0; i < list_.size(); i++) {
  246. auto name = list_[i]->name_;
  247. ss << std::left << std::setw(SizeToInt(space) + pad_width) << name << "\t";
  248. for (auto change : status.at(name + std::to_string(i))) {
  249. ss << change << " ";
  250. }
  251. ss << std::endl;
  252. }
  253. MS_LOG(DEBUG) << ss.str();
  254. }
  255. bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
  256. // Add for substitution status counting
  257. size_t space = 0;
  258. mindspore::HashMap<std::string, std::vector<bool>> status;
  259. if (optimizer->is_on_debug_) {
  260. for (size_t i = 0; i < list_.size(); i++) {
  261. status[list_[i]->name_ + std::to_string(i)] = {};
  262. }
  263. }
  264. bool changes = false;
  265. bool loop = true;
  266. while (loop) {
  267. loop = false;
  268. for (size_t i = 0; i < list_.size(); i++) {
  269. const auto &substitution = list_[i];
  270. bool change = ApplySubstitutionToIR(optimizer, func_graph, substitution);
  271. changes = changes || change;
  272. loop = loop || change;
  273. #ifdef ENABLE_DUMP_IR
  274. static const auto enable_dump_pass_ir = (common::GetEnv("ENV_DUMP_PASS_IR") == "1");
  275. if (enable_dump_pass_ir && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
  276. auto fg_name = optimizer->name() + "_r" + std::to_string(optimizer->CurPass_.counter) + "_" +
  277. optimizer->CurPass_.name + "_" + substitution->name_;
  278. DumpIR(fg_name + ".ir", func_graph);
  279. if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
  280. func_graph->DumpFuncGraph(fg_name);
  281. ExportIR(fg_name + ".dat", func_graph);
  282. }
  283. }
  284. #endif
  285. // Record the status of each substitution
  286. if (optimizer->is_on_debug_) {
  287. status[substitution->name_ + std::to_string(i)].push_back(change);
  288. space = std::max(substitution->name_.size(), space);
  289. }
  290. }
  291. if (is_once_) {
  292. break;
  293. }
  294. }
  295. // Display the status of each substitution
  296. if (optimizer->is_on_debug_) {
  297. DisplayStatusOfSubstitution(status, optimizer, space);
  298. }
  299. return changes;
  300. }
  301. bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
  302. MS_EXCEPTION_IF_NULL(optimizer);
  303. MS_EXCEPTION_IF_NULL(func_graph);
  304. FuncGraphManagerPtr manager = optimizer->manager();
  305. manager->AddFuncGraph(func_graph);
  306. bool changes = false;
  307. static const auto traverse_mode =
  308. (common::GetEnv("ENV_TRAVERSE_SUBSTITUTIONS_MODE") != "1" ? kOptTraverseFromIRToSubstitutions
  309. : kOptTraverseFromSubstitutionsToIR);
  310. if (traverse_mode == kOptTraverseFromIRToSubstitutions &&
  311. MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
  312. optimizer->traverse_nodes_first() && !is_once_ && !global_sensitive_) {
  313. MS_LOG(DEBUG) << "IR >> SUB, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_"
  314. << optimizer->CurPass_.name;
  315. changes = ApplyIRToSubstitutions(optimizer, func_graph);
  316. } else {
  317. MS_LOG(DEBUG) << "SUB >> IR, " << optimizer->name() << "(r" << optimizer->CurPass_.counter << ")_"
  318. << optimizer->CurPass_.name;
  319. changes = ApplySubstitutionsToIR(optimizer, func_graph);
  320. }
  321. return changes;
  322. }
  323. } // namespace opt
  324. } // namespace mindspore