You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

py_pass_manager.h 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_PASS_MANAGER_H_
  17. #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_PASS_MANAGER_H_
  #include <memory>
  #include <string>
  #include <unordered_map>
  #include <utility>
  #include <vector>
  #include "ir/anf.h"
  #include "ir/func_graph.h"
  #include "pybind_api/ir/primitive_py.h"
  #include "ir/graph_utils.h"
  #include "utils/ms_utils.h"
  #include "pipeline/jit/resource.h"
  #include "frontend/optimizer/pattern.h"
  #include "frontend/optimizer/py_pass.h"
  #include "frontend/optimizer/pass_group.h"
  31. namespace mindspore {
  32. namespace opt {
  33. namespace python_pass {
  34. class PyPassManager;
  35. using PyPassManagerPtr = std::shared_ptr<PyPassManager>;
  36. enum Phase { PREAD, OPT };
  37. class PyPassManager {
  38. protected:
  39. PyPassManager();
  40. static PyPassManagerPtr global_instance;
  41. public:
  42. // Singletons should not be cloneable and assignable
  43. PyPassManager(const PyPassManager &other) = delete;
  44. void operator=(const PyPassManager &) = delete;
  45. // Access the only global instance
  46. static PyPassManagerPtr GetInstance();
  47. virtual ~PyPassManager() = default;
  48. void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad,
  49. bool run_only_once);
  50. void Unregiste(const std::string &pass_name);
  51. void GenNewParameter(const PatternPtr &parameter);
  52. PassGroupPtr GetPassGroup(Phase phase);
  53. MatchResultPtr GetMatchResult() { return res_; }
  54. void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; }
  55. bool ShouldRenorm() { return should_renorm_; }
  56. void SetReOpt(bool should_reopt) { should_reopt_ = should_reopt; }
  57. bool ShouldReOpt() { return should_reopt_; }
  58. void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; }
  59. pipeline::ResourcePtr GetResource() { return resource_; }
  60. void ClearRes();
  61. void ClearPipelineRes() {
  62. resource_ = nullptr;
  63. Pattern::reset_gid();
  64. }
  65. private:
  66. bool should_renorm_ = true;
  67. bool should_reopt_ = true;
  68. MatchResultPtr res_;
  69. pipeline::ResourcePtr resource_;
  70. static std::unordered_map<Phase, PassGroupPtr> phase_to_group_;
  71. };
  72. } // namespace python_pass
  73. } // namespace opt
  74. } // namespace mindspore
  75. #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_PASS_MANAGER_H_