You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pyext17.h 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. /**
  2. * \file imperative/python/src/pyext17.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
#include <cstddef>
#include <new>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>
#include <Python.h>
  16. namespace pyext17 {
  17. #ifdef METH_FASTCALL
  18. constexpr bool has_fastcall = true;
  19. #else
  20. constexpr bool has_fastcall = false;
  21. #endif
  22. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  23. constexpr bool has_vectorcall = true;
  24. #else
  25. constexpr bool has_vectorcall = false;
  26. #endif
  27. template<typename... Args>
  28. struct invocable_with {
  29. template<typename T>
  30. constexpr bool operator()(T&& lmb) {
  31. return std::is_invocable_v<T, Args...>;
  32. }
  33. };
  34. #define HAS_MEMBER_TYPE(T, U) invocable_with<T>{}([](auto&& x) -> typename std::decay_t<decltype(x)>::U {})
  35. #define HAS_MEMBER(T, m) invocable_with<T>{}([](auto&& x) -> decltype(&std::decay_t<decltype(x)>::m) {})
  36. inline PyObject* cvt_retval(PyObject* rv) {
  37. return rv;
  38. }
  39. #define CVT_RET_PYOBJ(...) \
  40. if constexpr (std::is_same_v<decltype(__VA_ARGS__), void>) { \
  41. __VA_ARGS__; \
  42. Py_RETURN_NONE; \
  43. } else { \
  44. return cvt_retval(__VA_ARGS__); \
  45. }
  46. template <typename T>
  47. struct wrap {
  48. private:
  49. typedef wrap<T> wrap_t;
  50. public:
  51. PyObject_HEAD
  52. std::aligned_storage_t<sizeof(T), alignof(T)> storage;
  53. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  54. PyObject* (*vectorcall_slot)(PyObject*, PyObject*const*, size_t, PyObject*);
  55. #endif
  56. inline T* inst() {
  57. return reinterpret_cast<T*>(&storage);
  58. }
  59. inline static PyObject* pycast(T* ptr) {
  60. return (PyObject*)((char*)ptr - offsetof(wrap_t, storage));
  61. }
  62. private:
  63. // method wrapper
  64. enum struct meth_type {
  65. noarg,
  66. varkw,
  67. fastcall,
  68. singarg
  69. };
  70. template<auto f>
  71. struct detect_meth_type {
  72. static constexpr meth_type value = []() {
  73. using F = decltype(f);
  74. static_assert(std::is_member_function_pointer_v<F>);
  75. if constexpr (std::is_invocable_v<F, T>) {
  76. return meth_type::noarg;
  77. } else if constexpr (std::is_invocable_v<F, T, PyObject*, PyObject*>) {
  78. return meth_type::varkw;
  79. } else if constexpr (std::is_invocable_v<F, T, PyObject*const*, Py_ssize_t>) {
  80. return meth_type::fastcall;
  81. } else if constexpr (std::is_invocable_v<F, T, PyObject*>) {
  82. return meth_type::singarg;
  83. } else {
  84. static_assert(!std::is_same_v<F, F>);
  85. }
  86. }();
  87. };
  88. template<meth_type, auto f>
  89. struct meth {};
  90. template<auto f>
  91. struct meth<meth_type::noarg, f> {
  92. static constexpr int flags = METH_NOARGS;
  93. static PyObject* impl(PyObject* self, PyObject*) {
  94. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  95. CVT_RET_PYOBJ((inst->*f)());
  96. }
  97. };
  98. template<auto f>
  99. struct meth<meth_type::varkw, f> {
  100. static constexpr int flags = METH_VARARGS | METH_KEYWORDS;
  101. static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
  102. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  103. CVT_RET_PYOBJ((inst->*f)(args, kwargs));
  104. }
  105. };
  106. template<auto f>
  107. struct meth<meth_type::fastcall, f> {
  108. #ifdef METH_FASTCALL
  109. static constexpr int flags = METH_FASTCALL;
  110. static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) {
  111. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  112. CVT_RET_PYOBJ((inst->*f)(args, nargs));
  113. }
  114. #else
  115. static constexpr int flags = METH_VARARGS;
  116. static PyObject* impl(PyObject* self, PyObject* args) {
  117. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  118. auto* arr = &PyTuple_GET_ITEM(args, 0);
  119. auto size = PyTuple_GET_SIZE(args);
  120. CVT_RET_PYOBJ((inst->*f)(arr, size));
  121. }
  122. #endif
  123. };
  124. template<auto f>
  125. struct meth<meth_type::singarg, f> {
  126. static constexpr int flags = METH_O;
  127. static PyObject* impl(PyObject* self, PyObject* obj) {
  128. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  129. CVT_RET_PYOBJ((inst->*f)(obj));
  130. }
  131. };
  132. template<auto f>
  133. static constexpr PyMethodDef make_meth_def(const char* name, const char* doc = nullptr) {
  134. using M = meth<detect_meth_type<f>::value, f>;
  135. return {name, (PyCFunction)M::impl, M::flags, doc};
  136. }
  137. // polyfills
  138. struct tp_vectorcall {
  139. static constexpr bool valid = HAS_MEMBER(T, tp_vectorcall);
  140. static constexpr bool haskw = [](){if constexpr (valid)
  141. if constexpr (std::is_invocable_v<decltype(&T::tp_vectorcall), T, PyObject*const*, size_t, PyObject*>)
  142. return true;
  143. return false;}();
  144. template<typename = void>
  145. static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargsf, PyObject *kwnames) {
  146. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  147. if constexpr (haskw) {
  148. CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf, kwnames));
  149. } else {
  150. if (kwnames && PyTuple_GET_SIZE(kwnames)) {
  151. PyErr_SetString(PyExc_TypeError, "expect no keyword argument");
  152. return nullptr;
  153. }
  154. CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf));
  155. }
  156. }
  157. static constexpr Py_ssize_t offset = []() {if constexpr (valid) return offsetof(wrap_t, vectorcall_slot);
  158. else return 0;}();
  159. };
  160. struct tp_call {
  161. static constexpr bool provided = HAS_MEMBER(T, tp_call);
  162. static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}(
  163. [](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {});
  164. static constexpr bool valid = provided || tp_vectorcall::valid;
  165. template<typename = void>
  166. static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
  167. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  168. CVT_RET_PYOBJ(inst->tp_call(args, kwargs));
  169. }
  170. static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call;
  171. else if constexpr (provided) return impl<>;
  172. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  173. else if constexpr (valid) return PyVectorcall_Call;
  174. #endif
  175. else return nullptr;}();
  176. };
  177. struct tp_new {
  178. static constexpr bool provided = HAS_MEMBER(T, tp_new);
  179. static constexpr bool varkw = std::is_constructible_v<T, PyObject*, PyObject*>;
  180. static constexpr bool noarg = std::is_default_constructible_v<T>;
  181. template<typename = void>
  182. static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
  183. auto* self = type->tp_alloc(type, 0);
  184. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  185. if constexpr (has_vectorcall && tp_vectorcall::valid) {
  186. reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
  187. }
  188. if constexpr (varkw) {
  189. new(inst) T(args, kwargs);
  190. } else {
  191. new(inst) T();
  192. }
  193. return self;
  194. }
  195. static constexpr newfunc value = []() {if constexpr (provided) return T::tp_new;
  196. else if constexpr (varkw || noarg) return impl<>;
  197. else return nullptr;}();
  198. };
  199. struct tp_dealloc {
  200. static constexpr bool provided = HAS_MEMBER(T, tp_dealloc);
  201. template<typename = void>
  202. static void impl(PyObject* self) {
  203. reinterpret_cast<wrap_t*>(self)->inst()->~T();
  204. Py_TYPE(self)->tp_free(self);
  205. }
  206. static constexpr destructor value = []() {if constexpr (provided) return T::tp_dealloc;
  207. else return impl<>;}();
  208. };
  209. public:
  210. class TypeBuilder {
  211. std::vector<PyMethodDef> m_methods;
  212. PyTypeObject m_type;
  213. bool m_finalized = false;
  214. bool m_ready = false;
  215. void check_finalized() {
  216. if (m_finalized) {
  217. throw std::runtime_error("type is already finalized");
  218. }
  219. }
  220. public:
  221. TypeBuilder(const TypeBuilder&) = delete;
  222. TypeBuilder& operator=(const TypeBuilder&) = delete;
  223. TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} {
  224. constexpr auto has_tp_name = HAS_MEMBER(T, tp_name);
  225. if constexpr (has_tp_name) {
  226. m_type.tp_name = T::tp_name;
  227. }
  228. m_type.tp_dealloc = tp_dealloc::value;
  229. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  230. m_type.tp_vectorcall_offset = tp_vectorcall::offset;
  231. #endif
  232. m_type.tp_call = tp_call::value;
  233. m_type.tp_basicsize = sizeof(wrap_t);
  234. m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  235. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  236. if constexpr (tp_vectorcall::valid) {
  237. m_type.tp_flags |= _Py_TPFLAGS_HAVE_VECTORCALL;
  238. }
  239. #endif
  240. m_type.tp_new = tp_new::value;
  241. }
  242. PyTypeObject* operator->() {
  243. return &m_type;
  244. }
  245. bool ready() const {
  246. return m_ready;
  247. }
  248. PyObject* finalize() {
  249. if (!m_finalized) {
  250. if (m_methods.size()) {
  251. m_methods.push_back({0});
  252. if (m_type.tp_methods) {
  253. PyErr_SetString(PyExc_SystemError, "tp_method is already set");
  254. return nullptr;
  255. }
  256. m_type.tp_methods = &m_methods[0];
  257. }
  258. if (PyType_Ready(&m_type)) {
  259. return nullptr;
  260. }
  261. m_ready = true;
  262. }
  263. return (PyObject*)&m_type;
  264. }
  265. template<auto f>
  266. TypeBuilder& def(const char* name, const char* doc = nullptr) {
  267. check_finalized();
  268. m_methods.push_back(make_meth_def<f>(name, doc));
  269. return *this;
  270. }
  271. };
  272. static TypeBuilder& type() {
  273. static TypeBuilder type_helper;
  274. return type_helper;
  275. }
  276. };
  277. } // namespace pyext17
  278. #undef HAS_MEMBER_TYPE
  279. #undef HAS_MEMBER
  280. #undef CVT_RET_PYOBJ

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台