You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

helper.h 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. #pragma once
  2. #include "megbrain/graph.h"
  3. #include <Python.h>
  4. #include <string>
  5. #include <iterator>
  6. #if __cplusplus > 201703L
  7. #include <ranges>
  8. #endif
  9. #include <pybind11/pybind11.h>
  10. #include <pybind11/stl.h>
  11. #include <pybind11/numpy.h>
  12. #include <pybind11/functional.h>
  13. pybind11::module submodule(pybind11::module parent, const char* name, const char* doc = nullptr);
  14. pybind11::module rel_import(pybind11::str name, pybind11::module m, int level);
  //! range_value_t<R>: element type of a range R with cv-qualifiers and
  //! references stripped. Newer language modes reuse the standard alias; older
  //! ones derive it from the type obtained by dereferencing R::begin().
  #if __cplusplus > 201703L
  using std::ranges::range_value_t;
  #else
  template <typename Range>
  using range_value_t = typename std::remove_cv<typename std::remove_reference<
          decltype(*std::declval<Range>().begin())>::type>::type;
  #endif
  21. template<typename T>
  22. auto to_list(const T& x) {
  23. using elem_t = range_value_t<T>;
  24. std::vector<elem_t> ret(x.begin(), x.end());
  25. return pybind11::cast(ret);
  26. }
  27. template<typename T>
  28. auto to_tuple(const T& x, pybind11::return_value_policy policy = pybind11::return_value_policy::automatic) {
  29. auto ret = pybind11::tuple(x.size());
  30. for (size_t i = 0; i < x.size(); ++i) {
  31. ret[i] = pybind11::cast(x[i], policy);
  32. }
  33. return ret;
  34. }
  35. template<typename T>
  36. auto to_tuple(T begin, T end, pybind11::return_value_policy policy = pybind11::return_value_policy::automatic) {
  37. auto ret = pybind11::tuple(end - begin);
  38. for (size_t i = 0; begin < end; ++begin, ++i) {
  39. ret[i] = pybind11::cast(*begin, policy);
  40. }
  41. return ret;
  42. }
  43. class PyTaskDipatcher {
  44. struct Queue : mgb::AsyncQueueSC<std::function<void(void)>, Queue> {
  45. using Task = std::function<void(void)>;
  46. void process_one_task(Task& f) {
  47. if (!Py_IsInitialized()) return;
  48. pybind11::gil_scoped_acquire _;
  49. f();
  50. }
  51. };
  52. Queue queue;
  53. bool finalized = false;
  54. public:
  55. template<typename T>
  56. void add_task(T&& task) {
  57. // CPython never dlclose an extension so
  58. // finalized means the interpreter has been shutdown
  59. if (!finalized) {
  60. queue.add_task(std::forward<T>(task));
  61. }
  62. }
  63. void wait_all_task_finish() {
  64. queue.wait_all_task_finish();
  65. }
  66. ~PyTaskDipatcher() {
  67. finalized = true;
  68. queue.wait_all_task_finish();
  69. }
  70. };
  71. extern PyTaskDipatcher py_task_q;
  72. class GILManager {
  73. PyGILState_STATE gstate;
  74. public:
  75. GILManager():
  76. gstate(PyGILState_Ensure())
  77. {
  78. }
  79. ~GILManager() {
  80. PyGILState_Release(gstate);
  81. }
  82. };
  83. #define PYTHON_GIL GILManager __gil_manager
  84. //! wraps a shared_ptr and decr PyObject ref when destructed
  85. class PyObjRefKeeper {
  86. std::shared_ptr<PyObject> m_ptr;
  87. public:
  88. static void deleter(PyObject* p) {
  89. if (p) {
  90. py_task_q.add_task([p](){Py_DECREF(p);});
  91. }
  92. }
  93. PyObjRefKeeper() = default;
  94. PyObjRefKeeper(PyObject* p) : m_ptr{p, deleter} {}
  95. PyObject* get() const { return m_ptr.get(); }
  96. //! create a shared_ptr as an alias of the underlying ptr
  97. template <typename T>
  98. std::shared_ptr<T> make_shared(T* ptr) const {
  99. return {m_ptr, ptr};
  100. }
  101. };
  102. //! exception to be thrown when python callback fails
  103. class PyExceptionForward : public std::exception {
  104. PyObject *m_type, *m_value, *m_traceback;
  105. std::string m_msg;
  106. PyExceptionForward(PyObject* type, PyObject* value, PyObject* traceback,
  107. const std::string& msg)
  108. : m_type{type},
  109. m_value{value},
  110. m_traceback{traceback},
  111. m_msg{msg} {}
  112. public:
  113. PyExceptionForward(const PyExceptionForward&) = delete;
  114. PyExceptionForward& operator=(const PyExceptionForward&) = delete;
  115. ~PyExceptionForward();
  116. PyExceptionForward(PyExceptionForward&& rhs)
  117. : m_type{rhs.m_type},
  118. m_value{rhs.m_value},
  119. m_traceback{rhs.m_traceback},
  120. m_msg{std::move(rhs.m_msg)} {
  121. rhs.m_type = rhs.m_value = rhs.m_traceback = nullptr;
  122. }
  123. //! throw PyExceptionForward from current python error state
  124. static void throw_() __attribute__((noreturn));
  125. //! restore python error
  126. void restore();
  127. const char* what() const noexcept override { return m_msg.c_str(); }
  128. };
  129. //! numpy utils
  130. namespace npy {
  131. //! convert tensor shape to raw vector
  132. static inline std::vector<size_t> shape2vec(const mgb::TensorShape &shape) {
  133. return {shape.shape, shape.shape + shape.ndim};
  134. }
  135. //! change numpy dtype to megbrain supported dtype
  136. PyObject* to_mgb_supported_dtype(PyObject *dtype);
  137. //! convert raw vector to tensor shape
  138. mgb::TensorShape vec2shape(const std::vector<size_t> &vec);
  139. //! convert megbrain dtype to numpy dtype object; return new reference
  140. PyObject* dtype_mgb2np(mgb::DType dtype);
  141. //! convert numpy dtype object or string to megbrain dtype
  142. mgb::DType dtype_np2mgb(PyObject *obj);
  143. //! buffer sharing type
  144. enum class ShareType {
  145. MUST_SHARE, //!< must be shared
  146. MUST_UNSHARE, //!< must not be shared
  147. TRY_SHARE //!< share if possible
  148. };
  149. //! get ndarray from HostTensorND
  150. PyObject* ndarray_from_tensor(const mgb::HostTensorND &val,
  151. ShareType share_type);
  152. //! specify how to convert numpy array to tensor
  153. struct Meth {
  154. bool must_borrow_ = false;
  155. mgb::HostTensorND *dest_tensor_ = nullptr;
  156. mgb::CompNode dest_cn_;
  157. //! make a Meth that allows borrowing numpy array memory
  158. static Meth borrow(
  159. mgb::CompNode dest_cn = mgb::CompNode::default_cpu()) {
  160. return {false, nullptr, dest_cn};
  161. }
  162. //! make a Meth that requires the numpy array to be borrowed
  163. static Meth must_borrow(
  164. mgb::CompNode dest_cn = mgb::CompNode::default_cpu()) {
  165. return {true, nullptr, dest_cn};
  166. }
  167. //! make a Meth that requires copying the value into another
  168. //! tensor
  169. static Meth copy_into(mgb::HostTensorND *tensor) {
  170. return {false, tensor, tensor->comp_node()};
  171. }
  172. };
  173. /*!
  174. * \brief convert an object to megbrain tensor
  175. * \param meth specifies how the conversion should take place
  176. * \param dtype desired dtype; it can be set as invalid to allow arbitrary
  177. * dtype
  178. */
  179. mgb::HostTensorND np2tensor(PyObject *obj, const Meth &meth,
  180. mgb::DType dtype);
  181. }
  // Note: the macro below is copied from pybind11/detail/common.h.
  // pybind11 needs hidden visibility on its own code to robustly support some
  // features and to load modules built against different pybind11 versions,
  // so the attribute is attached to the main `pybind11` namespace. The macro
  // is only defined here when nothing has defined it already.
  #ifndef PYBIND11_NAMESPACE
  #if defined(__GNUG__)
  #define PYBIND11_NAMESPACE pybind11 __attribute__((visibility("hidden")))
  #else
  #define PYBIND11_NAMESPACE pybind11
  #endif
  #endif
  193. namespace PYBIND11_NAMESPACE {
  194. namespace detail {
  195. template<typename T, unsigned N> struct type_caster<megdnn::SmallVector<T, N>>
  196. : list_caster<megdnn::SmallVector<T, N>, T> {};
  197. template <> struct type_caster<mgb::DType> {
  198. PYBIND11_TYPE_CASTER(mgb::DType, _("DType"));
  199. public:
  200. bool load(handle src, bool convert) {
  201. auto obj = reinterpret_borrow<object>(src);
  202. if (!convert && !isinstance<dtype>(obj)) {
  203. return false;
  204. }
  205. if (obj.is_none()) {
  206. return true;
  207. }
  208. try {
  209. obj = pybind11::dtype::from_args(obj);
  210. } catch (pybind11::error_already_set&) {
  211. return false;
  212. }
  213. try {
  214. value = npy::dtype_np2mgb(obj.ptr());
  215. } catch (...) {
  216. return false;
  217. }
  218. return true;
  219. }
  220. static handle cast(mgb::DType dt, return_value_policy /* policy */, handle /* parent */) {
  221. // ignore policy and parent because we always return a pure python object
  222. return npy::dtype_mgb2np(std::move(dt));
  223. }
  224. };
  225. template <> struct type_caster<mgb::TensorShape> {
  226. PYBIND11_TYPE_CASTER(mgb::TensorShape, _("TensorShape"));
  227. public:
  228. bool load(handle src, bool convert) {
  229. auto obj = reinterpret_steal<object>(src);
  230. if (!isinstance<tuple>(obj)) {
  231. return false;
  232. }
  233. value.ndim = len(obj);
  234. mgb_assert(value.ndim <= mgb::TensorShape::MAX_NDIM);
  235. size_t i = 0;
  236. for (auto v : obj) {
  237. mgb_assert(i < value.ndim);
  238. value.shape[i] = reinterpret_borrow<object>(v).cast<size_t>();
  239. ++i;
  240. }
  241. return true;
  242. }
  243. static handle cast(mgb::TensorShape shape, return_value_policy /* policy */, handle /* parent */) {
  244. // ignore policy and parent because we always return a pure python object
  245. return to_tuple(shape.shape, shape.shape + shape.ndim).release();
  246. }
  247. };
  248. // hack to make custom object implicitly convertible from None
  249. template <typename T> struct from_none_caster : public type_caster_base<T> {
  250. using base = type_caster_base<T>;
  251. bool load(handle src, bool convert) {
  252. if (!convert || !src.is_none()) {
  253. return base::load(src, convert);
  254. }
  255. // adapted from pybind11::implicitly_convertible
  256. auto temp = reinterpret_steal<object>(PyObject_Call(
  257. (PyObject*) this->typeinfo->type, tuple().ptr(), nullptr));
  258. if (!temp) {
  259. PyErr_Clear();
  260. return false;
  261. }
  262. // adapted from pybind11::detail::type_caster_generic
  263. if (base::load(temp, false)) {
  264. loader_life_support::add_patient(temp);
  265. return true;
  266. }
  267. return false;
  268. }
  269. };
  270. template<> struct type_caster<mgb::CompNode> : public from_none_caster<mgb::CompNode> {};
  271. } // detail
  272. } // PYBIND11_NAMESPACE
  273. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台