You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

python_adapter.h 2.6 kB

(line-number gutter: lines 1–78)
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PYTHON_ADAPTER_H_
  17. #define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PYTHON_ADAPTER_H_
  #include <map>
  #include <memory>
  #include <string>
  #include <utility>

  #include "pybind11/embed.h"
  #include "pybind11/pybind11.h"
  #include "pybind11/stl.h"

  #include "pipeline/jit/parse/parse_base.h"
  #include "utils/log_adapter.h"
  26. namespace mindspore {
  27. namespace parse {
  28. // A utility to call python interface
  29. namespace python_adapter {
  30. py::module GetPyModule(const std::string &module);
  31. py::object GetPyObjAttr(const py::object &obj, const std::string &attr);
  32. template <class... T>
  33. py::object CallPyObjMethod(const py::object &obj, const std::string &method, T... args) {
  34. if (!method.empty() && !py::isinstance<py::none>(obj)) {
  35. return obj.attr(method.c_str())(args...);
  36. }
  37. return py::none();
  38. }
  39. // call python function of module
  40. template <class... T>
  41. py::object CallPyModFn(const py::module &mod, const std::string &function, T... args) {
  42. if (!function.empty() && !py::isinstance<py::none>(mod)) {
  43. return mod.attr(function.c_str())(args...);
  44. }
  45. return py::none();
  46. }
  47. // turn off the signature when ut use parser to construct a graph.
  48. void set_use_signature_in_resolve(bool use_signature) noexcept;
  49. bool UseSignatureInResolve();
  50. std::shared_ptr<py::scoped_interpreter> set_python_scoped();
  51. void ResetPythonScope();
  52. bool IsPythonEnv();
  53. void SetPythonPath(const std::string &path);
  54. void set_python_env_flag(bool python_env) noexcept;
  55. py::object GetPyFn(const std::string &module, const std::string &name);
  56. // Call the python function
  57. template <class... T>
  58. py::object CallPyFn(const std::string &module, const std::string &name, T... args) {
  59. (void)set_python_scoped();
  60. if (!module.empty() && !name.empty()) {
  61. py::module mod = py::module::import(module.c_str());
  62. py::object fn = mod.attr(name.c_str())(args...);
  63. return fn;
  64. }
  65. return py::none();
  66. }
  67. } // namespace python_adapter
  68. } // namespace parse
  69. } // namespace mindspore
  70. #endif // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PYTHON_ADAPTER_H_