You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

python_helper.h 6.6 kB

Line numbers 1–229
  /**
   * \file python_module/src/cpp/python_helper.h
   *
   * This file is part of MegBrain, a deep learning framework developed by Megvii.
   *
   * \brief helper utilities for python integration
   *
   * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
   *
   */
  #pragma once

  // Python.h must come before any standard header (Python/C API requirement)
  #include <Python.h>

  #include <exception>
  #include <memory>
  #include <string>
  #include <vector>

  #include "megbrain/graph.h"
  15. class GILManager {
  16. PyGILState_STATE gstate;
  17. public:
  18. GILManager():
  19. gstate(PyGILState_Ensure())
  20. {
  21. }
  22. ~GILManager() {
  23. PyGILState_Release(gstate);
  24. }
  25. };
  26. #define PYTHON_GIL GILManager __gil_manager
  27. //! wraps a shared_ptr and decr PyObject ref when destructed
  28. class PyObjRefKeeper {
  29. std::shared_ptr<PyObject> m_ptr;
  30. public:
  31. static void deleter(PyObject* p) {
  32. if (p) {
  33. PYTHON_GIL;
  34. Py_DECREF(p);
  35. }
  36. }
  37. PyObjRefKeeper() = default;
  38. PyObjRefKeeper(PyObject* p) : m_ptr{p, deleter} {}
  39. PyObject* get() const { return m_ptr.get(); }
  40. //! create a shared_ptr as an alias of the underlying ptr
  41. template <typename T>
  42. std::shared_ptr<T> make_shared(T* ptr) const {
  43. return {m_ptr, ptr};
  44. }
  45. };
  /*!
   * \brief pluggable hook that produces a string via a registered extractor
   *
   * A concrete subclass is installed with reg(); run() forwards to it.
   * The registered pointer is not owned (it is never deleted here).
   * Judging by the name, it presumably renders the current Python stack —
   * TODO confirm against the registering code.
   */
  class PyStackExtracter {
      static PyStackExtracter* ins;

  public:
      virtual ~PyStackExtracter() = default;

      //! produce the extracted text (implementation-defined)
      virtual std::string extract() = 0;

      //! install \p p as the extractor used by run(); nullptr unregisters
      static void reg(PyStackExtracter* p) { ins = p; }

      //! forward to the registered extractor; returns an empty string when
      //! none is registered instead of dereferencing a null pointer
      static std::string run() {
          if (!ins)
              return {};
          return ins->extract();
      }
  };
  58. //! exception to be thrown when python callback fails
  59. class PyExceptionForward : public std::exception {
  60. PyObject *m_type, *m_value, *m_traceback;
  61. std::string m_msg;
  62. PyExceptionForward(PyObject* type, PyObject* value, PyObject* traceback,
  63. const std::string& msg)
  64. : m_type{type},
  65. m_value{value},
  66. m_traceback{traceback},
  67. m_msg{msg} {}
  68. public:
  69. PyExceptionForward(const PyExceptionForward&) = delete;
  70. PyExceptionForward& operator=(const PyExceptionForward&) = delete;
  71. ~PyExceptionForward();
  72. PyExceptionForward(PyExceptionForward&& rhs)
  73. : m_type{rhs.m_type},
  74. m_value{rhs.m_value},
  75. m_traceback{rhs.m_traceback},
  76. m_msg{std::move(rhs.m_msg)} {
  77. rhs.m_type = rhs.m_value = rhs.m_traceback = nullptr;
  78. }
  79. //! throw PyExceptionForward from current python error state
  80. static void throw_() __attribute__((noreturn));
  81. //! restore python error
  82. void restore();
  83. const char* what() const noexcept override { return m_msg.c_str(); }
  84. };
  85. /*!
  86. * \brief make python exception
  87. */
  88. class PyMGBExceptionMaker {
  89. static PyObject *py_exc_class;
  90. friend std::string blame(mgb::cg::OperatorNodeBase* opr);
  91. public:
  92. static void setup_py_exception(std::exception &exc);
  93. static void _reg_exception_class(PyObject *cls) {
  94. py_exc_class = cls;
  95. }
  96. };
  97. //! associate a python object with an operator
  98. class OprPyTracker final : public mgb::NonCopyableObj {
  99. class TrackerStorage;
  100. OprPyTracker() = delete;
  101. public:
  102. /*!
  103. * \brief set current tracker; all operators created later would share
  104. * this tracker
  105. *
  106. * Note that a py reference would be kept
  107. */
  108. static void begin_set_tracker(mgb::cg::ComputingGraph& graph,
  109. PyObject* obj);
  110. static void end_set_tracker(mgb::cg::ComputingGraph& graph);
  111. struct TrackerResult {
  112. mgb::cg::OperatorNodeBase
  113. //! operator that directly causes the exception
  114. *exc_opr = nullptr,
  115. //! operator constructed by user (non-optimized exc_opr)
  116. *unopt_opr = nullptr,
  117. //! the grad source if opr is constructed by taking grad
  118. *opr_grad_src = nullptr;
  119. PyObject *tracker = nullptr, *tracker_grad_src = nullptr;
  120. //! format as python tuple
  121. PyObject* as_tuple(const char* leading_msg = nullptr) const;
  122. };
  123. //! get tracker from exception
  124. static TrackerResult get_tracker(mgb::MegBrainError& exc);
  125. //! get tracker from operator
  126. static TrackerResult get_tracker(mgb::cg::OperatorNodeBase* opr);
  127. };
  128. std::string blame(mgb::cg::OperatorNodeBase* opr);
  129. //! numpy utils
  130. namespace npy {
  131. //! convert tensor shape to raw vector
  132. static inline std::vector<size_t> shape2vec(const mgb::TensorShape &shape) {
  133. return {shape.shape, shape.shape + shape.ndim};
  134. }
  135. //! change numpy dtype to megbrain supported dtype
  136. PyObject* to_mgb_supported_dtype(PyObject *dtype);
  137. //! convert raw vector to tensor shape
  138. mgb::TensorShape vec2shape(const std::vector<size_t> &vec);
  139. //! convert megbrain dtype to numpy dtype object; return new reference
  140. PyObject* dtype_mgb2np(mgb::DType dtype);
  141. //! convert numpy dtype object or string to megbrain dtype
  142. mgb::DType dtype_np2mgb(PyObject *obj);
  143. //! buffer sharing type
  144. enum class ShareType {
  145. MUST_SHARE, //!< must be shared
  146. MUST_UNSHARE, //!< must not be shared
  147. TRY_SHARE //!< share if possible
  148. };
  149. //! get ndarray from HostTensorND
  150. PyObject* ndarray_from_tensor(const mgb::HostTensorND &val,
  151. ShareType share_type);
  152. //! specify how to convert numpy array to tensor
  153. struct Meth {
  154. bool must_borrow_ = false;
  155. mgb::HostTensorND *dest_tensor_ = nullptr;
  156. mgb::CompNode dest_cn_;
  157. //! make a Meth that allows borrowing numpy array memory
  158. static Meth borrow(
  159. mgb::CompNode dest_cn = mgb::CompNode::default_cpu()) {
  160. return {false, nullptr, dest_cn};
  161. }
  162. //! make a Meth that requires the numpy array to be borrowed
  163. static Meth must_borrow(
  164. mgb::CompNode dest_cn = mgb::CompNode::default_cpu()) {
  165. return {true, nullptr, dest_cn};
  166. }
  167. //! make a Meth that requires copying the value into another
  168. //! tensor
  169. static Meth copy_into(mgb::HostTensorND *tensor) {
  170. return {false, tensor, tensor->comp_node()};
  171. }
  172. };
  173. /*!
  174. * \brief convert an object to megbrain tensor
  175. * \param meth specifies how the conversion should take place
  176. * \param dtype desired dtype; it can be set as invalid to allow arbitrary
  177. * dtype
  178. */
  179. mgb::HostTensorND np2tensor(PyObject *obj, const Meth &meth,
  180. mgb::DType dtype);
  181. }
  // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。