You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

python_adapter.h 3.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. /**
  2. * Copyright 2019-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_ADAPTER_H_
  17. #define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_ADAPTER_H_
  18. #include <map>
  19. #include <memory>
  20. #include <string>
  21. #include "pybind11/embed.h"
  22. #include "pybind11/pybind11.h"
  23. #include "pybind11/stl.h"
  24. #include "pybind11/numpy.h"
  25. #include "utils/log_adapter.h"
  26. #include "ir/tensor.h"
  27. #include "base/base_ref.h"
  28. #include "include/common/visible.h"
  29. namespace py = pybind11;
  30. namespace mindspore {
  31. // A utility to call python interface
  32. namespace python_adapter {
  33. COMMON_EXPORT py::module GetPyModule(const std::string &module);
  34. COMMON_EXPORT py::object GetPyObjAttr(const py::object &obj, const std::string &attr);
  35. template <class... T>
  36. py::object CallPyObjMethod(const py::object &obj, const std::string &method, T... args) {
  37. if (!method.empty() && !py::isinstance<py::none>(obj)) {
  38. return obj.attr(method.c_str())(args...);
  39. }
  40. return py::none();
  41. }
  42. // call python function of module
  43. template <class... T>
  44. py::object CallPyModFn(const py::module &mod, const std::string &function, T... args) {
  45. if (!function.empty() && !py::isinstance<py::none>(mod)) {
  46. return mod.attr(function.c_str())(args...);
  47. }
  48. return py::none();
  49. }
  50. // turn off the signature when ut use parser to construct a graph.
  51. COMMON_EXPORT void set_use_signature_in_resolve(bool use_signature) noexcept;
  52. COMMON_EXPORT bool UseSignatureInResolve();
  53. COMMON_EXPORT std::shared_ptr<py::scoped_interpreter> set_python_scoped();
  54. COMMON_EXPORT void ResetPythonScope();
  55. COMMON_EXPORT bool IsPythonEnv();
  56. COMMON_EXPORT void SetPythonPath(const std::string &path);
  57. COMMON_EXPORT void set_python_env_flag(bool python_env) noexcept;
  58. COMMON_EXPORT py::object GetPyFn(const std::string &module, const std::string &name);
  59. // Call the python function
  60. template <class... T>
  61. py::object CallPyFn(const std::string &module, const std::string &name, T... args) {
  62. (void)set_python_scoped();
  63. if (!module.empty() && !name.empty()) {
  64. py::module mod = py::module::import(module.c_str());
  65. py::object fn = mod.attr(name.c_str())(args...);
  66. return fn;
  67. }
  68. return py::none();
  69. }
  70. class COMMON_EXPORT PyAdapterCallback {
  71. HANDLER_DEFINE(ValuePtr, PyDataToValue, py::object);
  72. HANDLER_DEFINE(BaseRef, RunPrimitivePyHookFunction, PrimitivePtr, VectorRef);
  73. HANDLER_DEFINE(py::array, TensorToNumpy, tensor::Tensor);
  74. };
  75. } // namespace python_adapter
  76. } // namespace mindspore
  77. #endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_ADAPTER_H_