You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

python_adapter.h 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. /**
  2. * Copyright 2019-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_ADAPTER_H_
  17. #define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_ADAPTER_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "pybind11/embed.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "utils/log_adapter.h"
#include "include/common/visible.h"
namespace py = pybind11;
  27. namespace mindspore {
// Utilities for calling Python interfaces from C++.
  29. namespace python_adapter {
// Returns a Python module identified by the name `module`. Defined out of line (not visible here);
// NOTE(review): presumably imports the module by name — confirm against the implementation.
COMMON_EXPORT py::module GetPyModule(const std::string &module);
// Returns attribute `attr` of `obj`. Defined out of line; what happens when the attribute is
// missing (exception vs. None) is not visible here — TODO confirm.
COMMON_EXPORT py::object GetPyObjAttr(const py::object &obj, const std::string &attr);
  32. template <class... T>
  33. py::object CallPyObjMethod(const py::object &obj, const std::string &method, T... args) {
  34. if (!method.empty() && !py::isinstance<py::none>(obj)) {
  35. return obj.attr(method.c_str())(args...);
  36. }
  37. return py::none();
  38. }
  39. // call python function of module
  40. template <class... T>
  41. py::object CallPyModFn(const py::module &mod, const std::string &function, T... args) {
  42. if (!function.empty() && !py::isinstance<py::none>(mod)) {
  43. return mod.attr(function.c_str())(args...);
  44. }
  45. return py::none();
  46. }
// Enables or disables the use of signatures during resolve. Unit tests turn this off when they use
// the parser to construct a graph.
COMMON_EXPORT void set_use_signature_in_resolve(bool use_signature) noexcept;
// Returns the signature-in-resolve setting; presumably the value last passed to
// set_use_signature_in_resolve() — confirm in the implementation.
COMMON_EXPORT bool UseSignatureInResolve();
// Returns a shared handle to a py::scoped_interpreter. NOTE(review): presumably creates the embedded
// interpreter on first use and returns the cached instance afterwards — confirm in the implementation.
COMMON_EXPORT std::shared_ptr<py::scoped_interpreter> set_python_scoped();
// NOTE(review): presumably releases the interpreter scope obtained via set_python_scoped() — confirm.
COMMON_EXPORT void ResetPythonScope();
// Reports whether code is running in a Python environment; presumably returns the flag set by
// set_python_env_flag() — confirm.
COMMON_EXPORT bool IsPythonEnv();
// NOTE(review): presumably adds `path` to the Python module search path — confirm.
COMMON_EXPORT void SetPythonPath(const std::string &path);
// Sets the Python-environment flag (presumably the one reported by IsPythonEnv()).
COMMON_EXPORT void set_python_env_flag(bool python_env) noexcept;
// NOTE(review): presumably returns attribute `name` of module `module` without calling it — confirm.
COMMON_EXPORT py::object GetPyFn(const std::string &module, const std::string &name);
  56. // Call the python function
  57. template <class... T>
  58. py::object CallPyFn(const std::string &module, const std::string &name, T... args) {
  59. (void)set_python_scoped();
  60. if (!module.empty() && !name.empty()) {
  61. py::module mod = py::module::import(module.c_str());
  62. py::object fn = mod.attr(name.c_str())(args...);
  63. return fn;
  64. }
  65. return py::none();
  66. }
  67. } // namespace python_adapter
  68. } // namespace mindspore
  69. #endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_ADAPTER_H_