You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

py_pass_manager.cc 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/py_pass_manager.h"
  17. #include <functional>
  18. #include <algorithm>
  19. #include <utility>
  20. #include <initializer_list>
  21. #include "ir/manager.h"
  22. #include "frontend/optimizer/pass_group.h"
  23. namespace mindspore {
  24. namespace opt {
  25. namespace python_pass {
  26. PyPassManagerPtr PyPassManager::global_instance = nullptr;
  27. std::unordered_map<Phase, PassGroupPtr> PyPassManager::phase_to_group_;
  28. PassGroupPtr PyPassManager::GetPassGroup(Phase phase) {
  29. auto pm = phase_to_group_.find(phase);
  30. if (pm == phase_to_group_.end()) {
  31. return nullptr;
  32. }
  33. return pm->second;
  34. }
  35. PyPassManagerPtr PyPassManager::GetInstance() {
  36. if (global_instance == nullptr) {
  37. global_instance = std::shared_ptr<PyPassManager>(new (std::nothrow) PyPassManager());
  38. }
  39. return global_instance;
  40. }
  41. PyPassManager::PyPassManager() {
  42. phase_to_group_[Phase::RESOLVE] = std::make_shared<PassGroup>();
  43. phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>();
  44. }
  45. void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
  46. Phase phase, bool run_only_once, bool multigraph) {
  47. auto cur_pm = GetPassGroup(phase);
  48. MS_EXCEPTION_IF_NULL(cur_pm);
  49. PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once, multigraph);
  50. cur_pm->AddPass(new_pass);
  51. }
  52. void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
  53. auto cur_pm = GetPassGroup(phase);
  54. MS_EXCEPTION_IF_NULL(cur_pm);
  55. if (!cur_pm->DeletePass(pass_name)) {
  56. MS_LOG(WARNING) << "No such pass : " + pass_name + "\n";
  57. }
  58. }
  59. void PyPassManager::ClearRes() {
  60. MS_LOG(INFO) << "Clear PyPassManager resources!";
  61. global_instance = nullptr;
  62. phase_to_group_.clear();
  63. }
  64. REGISTER_PYBIND_DEFINE(
  65. PyPassManager_, ([](const py::module *m) {
  66. (void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT);
  67. (void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
  68. .def(py::init([]() { return PyPassManager::GetInstance(); }))
  69. .def("registe", &PyPassManager::Registe, "Registe python pass")
  70. .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass");
  71. }));
  72. } // namespace python_pass
  73. } // namespace opt
  74. } // namespace mindspore