You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

optimizer.h 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_
  17. #define MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_
  #include <algorithm>
  #include <functional>
  #include <initializer_list>
  #include <map>
  #include <memory>
  #include <string>
  #include <utility>
  #include <vector>
  #ifdef DEBUG
  #include "debug/draw.h"
  #include "debug/anf_ir_dump.h"
  #endif
  #include "optimizer/opt.h"
  #include "pipeline/resource.h"
  #include "pipeline/action.h"
  #include "debug/trace.h"
  33. namespace mindspore {
  34. namespace opt {
  35. using OptimizeGraphFunc = std::function<bool(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer)>;
  36. class OptPassConfig {
  37. public:
  38. explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {}
  39. explicit OptPassConfig(const std::vector<SubstitutionPtr> &list, bool is_once = false)
  40. : list_(list), is_once_(is_once) {}
  41. OptPassConfig(const std::initializer_list<SubstitutionPtr> &list, bool is_once = false)
  42. : list_(list), is_once_(is_once) {}
  43. ~OptPassConfig() = default;
  44. const std::vector<SubstitutionPtr> &list() const { return list_; }
  45. const OptimizeGraphFunc &func() const { return func_; }
  46. static OptPassConfig Renormalize() { return OptPassConfig(); }
  47. const bool is_renormalize() const { return is_renormalize_; }
  48. const bool is_once() const { return is_once_; }
  49. private:
  50. OptPassConfig() : is_renormalize_(true) {}
  51. OptimizeGraphFunc func_;
  52. std::vector<SubstitutionPtr> list_;
  53. bool is_renormalize_{false};
  54. bool is_once_{false};
  55. };
  56. class OptPass {
  57. public:
  58. explicit OptPass(const OptimizeGraphFunc &func) : pass_func_(func) {}
  59. ~OptPass() = default;
  60. bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
  61. return pass_func_(func_graph, optimizer);
  62. }
  63. static OptPass Renormalize() { return OptPass(); }
  64. const bool is_renormalize() const { return is_renormalize_; }
  65. private:
  66. OptPass() : is_renormalize_(true) {}
  67. OptimizeGraphFunc pass_func_;
  68. bool is_renormalize_{false};
  69. };
  70. using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
  71. class Optimizer : public std::enable_shared_from_this<Optimizer> {
  72. public:
  73. Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr)
  74. : name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false) {}
  75. virtual ~Optimizer() = default;
  76. void Init(const OptPassGroupMap &passes, bool run_only_once) {
  77. run_only_once_ = run_only_once;
  78. is_watch_renormalize_ = false;
  79. for (auto &iter : passes) {
  80. const std::string &name = iter.first;
  81. pass_names_.push_back(name);
  82. const OptPassConfig &config = iter.second;
  83. if (config.is_renormalize()) {
  84. passes_.push_back(OptPass::Renormalize());
  85. continue;
  86. }
  87. if (config.list().size() > 0) {
  88. OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once());
  89. passes_.push_back(OptPass(func));
  90. continue;
  91. }
  92. passes_.push_back(OptPass(config.func()));
  93. }
  94. if (passes_.size() == 1) {
  95. run_only_once_ = true;
  96. }
  97. }
  98. static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr,
  99. const OptPassGroupMap &passes, bool run_only_once = false,
  100. bool watch_renormalize = false) {
  101. OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr);
  102. optimizer->Init(passes, run_only_once);
  103. if (watch_renormalize) {
  104. optimizer->enable_watch_renormalize();
  105. }
  106. return optimizer;
  107. }
  108. FuncGraphPtr step(FuncGraphPtr func_graph, const abstract::AbstractBasePtrList &args_spec, bool use_profile = true) {
  109. // Optimizer step counter;
  110. int counter = 1;
  111. bool changes = true;
  112. while (changes) {
  113. changes = false;
  114. auto run_runc = [&counter, &func_graph, &args_spec, &changes, use_profile, this]() {
  115. for (size_t i = 0; i < passes_.size(); ++i) {
  116. const OptPass &opt = passes_[i];
  117. auto opt_func = [&func_graph, &args_spec, &changes, &opt, this]() {
  118. if (opt.is_renormalize()) {
  119. auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_);
  120. if (resource_ptr != nullptr) {
  121. if (is_watch_renormalize_) {
  122. if (untyped_nodes_.size() > 0) {
  123. func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
  124. clear_untyped_nodes();
  125. } else {
  126. MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty.";
  127. }
  128. } else {
  129. func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
  130. }
  131. }
  132. } else if (opt(func_graph, shared_from_this())) {
  133. changes = true;
  134. }
  135. };
  136. use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func();
  137. #ifdef DEBUG
  138. MS_LOG(DEBUG) << "" << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
  139. auto fg_name = name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
  140. func_graph->DumpFuncGraph(fg_name);
  141. DumpIR(fg_name + ".ir", func_graph);
  142. MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph.";
  143. #endif
  144. }
  145. };
  146. use_profile ? (WITH(MsProfile::GetProfile()->Lap(counter++)) run_runc) : run_runc();
  147. if (run_only_once_) {
  148. break;
  149. }
  150. }
  151. auto keep_root = [&func_graph, this]() {
  152. std::vector<FuncGraphPtr> func_graphs;
  153. func_graphs.push_back(func_graph);
  154. resource_->manager()->KeepRoots(func_graphs);
  155. };
  156. use_profile ? WITH(MsProfile::GetProfile()->Step("keep_roots")) keep_root : keep_root();
  157. return func_graph;
  158. }
  159. pipeline::ResourceBasePtr resource() const { return resource_; }
  160. FuncGraphManagerPtr manager() const {
  161. if (resource_ != nullptr) {
  162. return resource_->manager();
  163. }
  164. MS_LOG(EXCEPTION) << "No ResourceBase exists.";
  165. }
  166. const std::string name() const { return name_; }
  167. void add_node_to_renormalize(AnfNodePtr anode) {
  168. if (std::find(untyped_nodes_.begin(), untyped_nodes_.end(), anode) == untyped_nodes_.end()) {
  169. untyped_nodes_.push_back(anode);
  170. }
  171. }
  172. void clear_untyped_nodes() { untyped_nodes_.clear(); }
  173. void enable_watch_renormalize() { is_watch_renormalize_ = true; }
  174. void disable_watch_renormalize() { is_watch_renormalize_ = false; }
  175. bool is_watch_renormalize() { return is_watch_renormalize_; }
  176. private:
  177. const std::string name_;
  178. pipeline::ResourceBasePtr resource_;
  179. std::vector<OptPass> passes_;
  180. std::vector<std::string> pass_names_;
  181. bool run_only_once_;
  182. std::vector<AnfNodePtr> untyped_nodes_;
  183. bool is_watch_renormalize_;
  184. };
  185. } // namespace opt
  186. } // namespace mindspore
  187. #endif // MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_