You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

py_pass_manager.cc 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/py_pass_manager.h"
  17. #include <functional>
  18. #include <utility>
  19. #include "ir/manager.h"
  20. #include "frontend/optimizer/pass_group.h"
  21. namespace mindspore {
  22. namespace opt {
  23. namespace python_pass {
  24. PyPassManagerPtr PyPassManager::global_instance = nullptr;
  25. std::unordered_map<Phase, PassGroupPtr> PyPassManager::phase_to_group_;
  26. PassGroupPtr PyPassManager::GetPassGroup(Phase phase) {
  27. auto pm = phase_to_group_.find(phase);
  28. if (pm == phase_to_group_.end()) {
  29. return nullptr;
  30. }
  31. return pm->second;
  32. }
  33. PyPassManagerPtr PyPassManager::GetInstance() {
  34. if (global_instance == nullptr) {
  35. global_instance = std::shared_ptr<PyPassManager>(new (std::nothrow) PyPassManager());
  36. }
  37. return global_instance;
  38. }
  39. PyPassManager::PyPassManager() {
  40. phase_to_group_[Phase::PREAD] = std::make_shared<PassGroup>("Pre_AD_PassGroup");
  41. phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>("After_OPT_PassGroup");
  42. res_ = std::make_shared<MatchResult>();
  43. }
  44. void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
  45. bool requires_grad, bool run_only_once) {
  46. PassGroupPtr cur_pg;
  47. if (requires_grad) {
  48. cur_pg = GetPassGroup(Phase::PREAD);
  49. } else {
  50. cur_pg = GetPassGroup(Phase::OPT);
  51. }
  52. MS_EXCEPTION_IF_NULL(cur_pg);
  53. cur_pg->SetRunOnlyOnce(run_only_once);
  54. MS_EXCEPTION_IF_NULL(pattern);
  55. MS_EXCEPTION_IF_NULL(target);
  56. MS_EXCEPTION_IF_NULL(cur_pg);
  57. PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once);
  58. cur_pg->AddPass(new_pass);
  59. }
  60. void PyPassManager::Unregiste(const std::string &pass_name) {
  61. auto opt_pm = GetPassGroup(Phase::OPT);
  62. if (!opt_pm->DeletePass(pass_name)) {
  63. MS_LOG(WARNING) << "Opt has no such pass : " + pass_name + "\n";
  64. }
  65. auto pre_ad_pm = GetPassGroup(Phase::PREAD);
  66. if (!pre_ad_pm->DeletePass(pass_name)) {
  67. MS_LOG(WARNING) << "Pre_AD has no such pass : " + pass_name + "\n";
  68. }
  69. }
  70. void PyPassManager::GenNewParameter(const PatternPtr &parameter) {
  71. MS_EXCEPTION_IF_NULL(parameter);
  72. // NOTE: Add NewParameter at early stage will cause CSE problems
  73. auto cur_pg = GetPassGroup(Phase::OPT);
  74. MS_EXCEPTION_IF_NULL(cur_pg);
  75. cur_pg->SetRunOnlyOnce(true);
  76. auto new_para_pattern = parameter->cast<NewParameterPtr>();
  77. MS_EXCEPTION_IF_NULL(new_para_pattern);
  78. auto pass_name = new_para_pattern->para_name();
  79. new_para_pattern->set_last(true);
  80. auto new_pass = std::make_shared<PythonPass>(pass_name, nullptr, parameter, true);
  81. cur_pg->AddPass(new_pass);
  82. }
  83. void PyPassManager::ClearRes() {
  84. MS_LOG(INFO) << "Clear PyPassManager resources!";
  85. global_instance = nullptr;
  86. phase_to_group_.clear();
  87. }
  88. REGISTER_PYBIND_DEFINE(
  89. PyPassManager_, ([](const py::module *m) {
  90. (void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("pre_ad", Phase::PREAD).value("opt", Phase::OPT);
  91. (void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
  92. .def(py::init([]() { return PyPassManager::GetInstance(); }))
  93. .def("registe", &PyPassManager::Registe, "Registe python pass")
  94. .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass")
  95. .def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter")
  96. .def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph")
  97. .def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph");
  98. }));
  99. } // namespace python_pass
  100. } // namespace opt
  101. } // namespace mindspore