You cannot select more than 25 topics. A topic must start with a Chinese character, a letter, or a number; it can include dashes ('-') and can be up to 35 characters long.

optimizer.h 10 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPTIMIZER_H_
  17. #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPTIMIZER_H_
  18. #include <algorithm>
  19. #include <functional>
  20. #include <iterator>
  21. #include <memory>
  22. #include <string>
  23. #include <vector>
  24. #include <map>
  25. #include <utility>
  26. #include <initializer_list>
  27. #include "debug/draw.h"
  28. #include "debug/anf_ir_dump.h"
  29. #include "debug/anf_ir_utils.h"
  30. #include "debug/trace.h"
  31. #include "frontend/optimizer/opt.h"
  32. #include "pipeline/jit/resource.h"
  33. #include "pipeline/jit/action.h"
  34. #include "utils/ms_context.h"
  35. namespace mindspore {
  36. namespace opt {
// Signature of a whole-graph optimization pass; returns true when it changed the graph.
using OptimizeGraphFunc = std::function<bool(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer)>;
  38. class OptPassConfig {
  39. public:
  40. explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {}
  41. explicit OptPassConfig(const std::vector<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false)
  42. : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {}
  43. OptPassConfig(const std::initializer_list<SubstitutionPtr> &list, bool is_once = false, bool global_sensitive = false)
  44. : list_(list), is_once_(is_once), global_sensitive_(global_sensitive) {}
  45. ~OptPassConfig() = default;
  46. const std::vector<SubstitutionPtr> &list() const { return list_; }
  47. const OptimizeGraphFunc &func() const { return func_; }
  48. static OptPassConfig Renormalize() { return OptPassConfig(); }
  49. const bool is_renormalize() const { return is_renormalize_; }
  50. const bool is_once() const { return is_once_; }
  51. const bool global_sensitive() const { return global_sensitive_; }
  52. private:
  53. OptPassConfig() : is_renormalize_(true) {}
  54. OptimizeGraphFunc func_;
  55. std::vector<SubstitutionPtr> list_;
  56. bool is_renormalize_{false};
  57. bool is_once_{false};
  58. bool global_sensitive_{false};
  59. };
  60. class OptPass {
  61. public:
  62. explicit OptPass(const OptimizeGraphFunc &func) : pass_func_(func) {}
  63. ~OptPass() = default;
  64. bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
  65. return pass_func_(func_graph, optimizer);
  66. }
  67. static OptPass Renormalize() { return OptPass(); }
  68. const bool is_renormalize() const { return is_renormalize_; }
  69. private:
  70. OptPass() : is_renormalize_(true) {}
  71. OptimizeGraphFunc pass_func_;
  72. bool is_renormalize_{false};
  73. };
// Ordered list of (pass name, pass config) pairs executed by an Optimizer.
using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
  75. class Optimizer : public std::enable_shared_from_this<Optimizer> {
  76. public:
  77. Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr, bool traverse_nodes_first = true)
  78. : name_(name),
  79. resource_(resource_ptr),
  80. run_only_once_(false),
  81. is_watch_renormalize_(false),
  82. is_enable_(true),
  83. is_untyped_generated_(false),
  84. traverse_nodes_first_(traverse_nodes_first) {}
  85. virtual ~Optimizer() = default;
  86. void Init(const OptPassGroupMap &passes, bool run_only_once) {
  87. run_only_once_ = run_only_once;
  88. is_watch_renormalize_ = false;
  89. is_untyped_generated_ = false;
  90. is_on_debug_ = IS_OUTPUT_ON(mindspore::DEBUG);
  91. for (auto &iter : passes) {
  92. const std::string &name = iter.first;
  93. pass_names_.push_back(name);
  94. const OptPassConfig &config = iter.second;
  95. if (config.is_renormalize()) {
  96. passes_.push_back(OptPass::Renormalize());
  97. continue;
  98. }
  99. if (config.list().size() > 0) {
  100. OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once(), config.global_sensitive());
  101. passes_.push_back(OptPass(func));
  102. continue;
  103. }
  104. passes_.push_back(OptPass(config.func()));
  105. }
  106. if (passes_.size() == 1) {
  107. run_only_once_ = true;
  108. }
  109. }
  110. static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr,
  111. const OptPassGroupMap &passes, bool run_only_once = false,
  112. bool watch_renormalize = false, bool traverse_nodes_first = true) {
  113. OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr, traverse_nodes_first);
  114. optimizer->Init(passes, run_only_once);
  115. if (watch_renormalize) {
  116. optimizer->enable_watch_renormalize();
  117. }
  118. return optimizer;
  119. }
  120. FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) {
  121. if (!is_enable_) {
  122. return func_graph;
  123. }
  124. // Optimizer step counter;
  125. int64_t counter = 1;
  126. bool changes = true;
  127. // If no changes since last renormalization, then no need to do the renormalization again.
  128. // Set the initial value to true, so the renormalization can be executed once if it's the
  129. // only pass.
  130. bool changes_since_last_renorm = true;
  131. while (changes) {
  132. changes = false;
  133. auto run_runc = [&counter, &func_graph, &changes, &changes_since_last_renorm, use_profile, this]() {
  134. for (size_t i = 0; i < passes_.size(); ++i) {
  135. const OptPass &opt = passes_[i];
  136. CurPass_ = {counter, pass_names_[i]};
  137. auto opt_func = [&func_graph, &changes, &opt, &changes_since_last_renorm, this]() {
  138. if (opt.is_renormalize()) {
  139. if (!changes_since_last_renorm) {
  140. return;
  141. }
  142. auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_);
  143. if (resource_ptr != nullptr) {
  144. // StepParallel may replace the AbstractValue of the parameters of func_graph,
  145. // So generate the args_spec from parameters.
  146. abstract::AbstractBasePtrList maybe_new_args_spec;
  147. if (is_watch_renormalize_) {
  148. if (is_untyped_generated_) {
  149. std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
  150. std::back_inserter(maybe_new_args_spec),
  151. [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
  152. func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
  153. clear_is_untyped_generated();
  154. } else {
  155. MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because is_untyped_generated_ is False.";
  156. }
  157. } else {
  158. std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
  159. std::back_inserter(maybe_new_args_spec),
  160. [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
  161. func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
  162. }
  163. }
  164. changes_since_last_renorm = false;
  165. } else if (opt(func_graph, shared_from_this())) {
  166. changes = true;
  167. changes_since_last_renorm = true;
  168. }
  169. };
  170. use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func();
  171. #ifdef ENABLE_DUMP_IR
  172. static const auto enable_dump_pass_ir = GetDumpConfig().enable_dump_pass_ir;
  173. if (enable_dump_pass_ir && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
  174. auto fg_name =
  175. "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
  176. MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
  177. DumpIR(fg_name + ".ir", func_graph);
  178. if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
  179. func_graph->DumpFuncGraph(fg_name);
  180. ExportIR(fg_name + ".dat", func_graph);
  181. }
  182. MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph.";
  183. }
  184. #endif
  185. }
  186. };
  187. use_profile ? (WITH(MsProfile::GetProfile()->Lap(counter)) run_runc) : run_runc();
  188. counter++;
  189. if (run_only_once_) {
  190. break;
  191. }
  192. }
  193. return func_graph;
  194. }
  195. pipeline::ResourceBasePtr resource() const { return resource_; }
  196. FuncGraphManagerPtr manager() const {
  197. if (resource_ != nullptr) {
  198. return resource_->manager();
  199. }
  200. MS_LOG(EXCEPTION) << "No ResourceBase exists.";
  201. }
  202. const std::string name() const { return name_; }
  203. void set_is_untyped_generated() { is_untyped_generated_ = true; }
  204. void clear_is_untyped_generated() { is_untyped_generated_ = false; }
  205. void enable_watch_renormalize() { is_watch_renormalize_ = true; }
  206. void disable_watch_renormalize() { is_watch_renormalize_ = false; }
  207. bool is_watch_renormalize() { return is_watch_renormalize_; }
  208. void set_enable(bool enable) { is_enable_ = enable; }
  209. bool traverse_nodes_first() { return traverse_nodes_first_; }
  210. struct {
  211. int64_t counter;
  212. std::string name;
  213. } CurPass_;
  214. bool is_on_debug_{false};
  215. private:
  216. const std::string name_;
  217. pipeline::ResourceBasePtr resource_;
  218. std::vector<OptPass> passes_;
  219. std::vector<std::string> pass_names_;
  220. bool run_only_once_;
  221. bool is_watch_renormalize_;
  222. bool is_enable_;
  223. bool is_untyped_generated_;
  224. bool traverse_nodes_first_;
  225. };
  226. } // namespace opt
  227. } // namespace mindspore
  228. #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPTIMIZER_H_