You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pyext17.h 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  /**
   * \file imperative/python/src/pyext17.h
   * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
   *
   * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
   *
   * Unless required by applicable law or agreed to in writing,
   * software distributed under the License is distributed on an
   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   */
  #pragma once

  // <Python.h> must be included before any standard header: Python may define
  // preprocessor macros that change the behaviour of the system headers.
  #include <Python.h>
  #include <pybind11/pybind11.h>

  #include <cstddef>      // offsetof
  #include <exception>
  #include <new>          // placement new
  #include <stdexcept>
  #include <type_traits>  // std::is_invocable_v, std::aligned_storage_t, ...
  #include <utility>
  #include <vector>
  18. namespace pyext17 {
  // Compile-time probes for optional CPython calling conventions. Each flag is
  // true exactly when <Python.h> defines the corresponding macro. `inline
  // constexpr` gives every translation unit that includes this header one
  // shared variable instead of a private internal-linkage copy per TU.
  #ifdef METH_FASTCALL
  inline constexpr bool has_fastcall = true;
  #else
  inline constexpr bool has_fastcall = false;
  #endif
  #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  inline constexpr bool has_vectorcall = true;
  #else
  inline constexpr bool has_vectorcall = false;
  #endif
  // Function object answering "is this callable invocable with Args...?".
  // Only the *type* of the argument is inspected; the callable is never
  // invoked. This lets the HAS_MEMBER* macros below turn a SFINAE-friendly
  // generic lambda into a compile-time boolean. The call operator is const so
  // const probe objects work as well as temporaries.
  template <typename... Args>
  struct invocable_with {
      template <typename T>
      constexpr bool operator()(T&&) const {
          return std::is_invocable_v<T, Args...>;
      }
  };
  // True iff `typename T::U` names a type.
  #define HAS_MEMBER_TYPE(T, U) invocable_with<T>{}([](auto&& x) -> typename std::decay_t<decltype(x)>::U {})
  // True iff `&T::m` is well-formed in this context, i.e. T has an accessible,
  // non-overloaded member named m.
  #define HAS_MEMBER(T, m) invocable_with<T>{}([](auto&& x) -> decltype(&std::decay_t<decltype(x)>::m) {})
  38. inline PyObject* cvt_retval(PyObject* rv) {
  39. return rv;
  40. }
  41. #define CVT_RET_PYOBJ(...) \
  42. if constexpr (std::is_same_v<decltype(__VA_ARGS__), void>) { \
  43. __VA_ARGS__; \
  44. Py_RETURN_NONE; \
  45. } else { \
  46. return cvt_retval(__VA_ARGS__); \
  47. }
  // Identity pass-through. CVT_RET_INT routes every non-void result through
  // this function, so such results must be convertible to int.
  inline auto cvt_retint(int status) -> int {
      return status;
  }
  // Return the result of a wrapped call as an int status (the setter
  // convention used below). A void call is evaluated for its side effects and
  // reports 0. Intended for templates: the unused `if constexpr` branch is only
  // skipped when the wrapped expression is dependent.
  #define CVT_RET_INT(...) \
      if constexpr (!std::is_void_v<decltype(__VA_ARGS__)>) { \
          return cvt_retint(__VA_ARGS__); \
      } else { \
          __VA_ARGS__; \
          return 0; \
      }
  // Signals that the Python error indicator has already been set. The
  // PYEXT17_TRANSLATE_EXC* handlers catch it and leave that error untouched.
  // what() is overridden so code that catches std::exception generically
  // (e.g. for logging) gets a meaningful message.
  struct py_err_set : std::exception {
      const char* what() const noexcept override {
          return "pyext17::py_err_set: a Python exception is already set";
      }
  };
  59. // refer to pybind11 for the following exception handling helper
  60. inline void pybind11_translate_exception(std::exception_ptr last_exception) {
  61. auto &registered_exception_translators = pybind11::detail::get_internals().registered_exception_translators;
  62. for (auto& translator : registered_exception_translators) {
  63. try {
  64. translator(last_exception);
  65. } catch (...) {
  66. last_exception = std::current_exception();
  67. continue;
  68. }
  69. return;
  70. }
  71. PyErr_SetString(PyExc_SystemError, "Exception escaped from default exception translator!");
  72. }
  73. inline void pybind11_translate_exception() {
  74. pybind11_translate_exception(std::current_exception());
  75. }
  // With GCC/libstdc++, thread cancellation unwinds the stack with an
  // abi::__forced_unwind exception. It must be rethrown, never swallowed by the
  // catch (...) handler below; pybind11 applies the same rule.
  #if defined(__GNUG__) && !defined(__clang__)
  #define PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND catch (::abi::__forced_unwind&) {throw;}
  #else
  #define PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND
  #endif
  // Handler chain to place directly after a `try { ... }` block. It:
  //  - leaves an already-set Python error in place (py_err_set);
  //  - restores a pybind11 error_already_set;
  //  - hands any other exception to pybind11's registered translators.
  #define PYEXT17_TRANSLATE_EXC \
      catch (::pyext17::py_err_set&) {} \
      catch (::pybind11::error_already_set& e) {e.restore();} \
      PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND \
      catch (...) {::pyext17::pybind11_translate_exception();}
  // Like PYEXT17_TRANSLATE_EXC, but every handler then returns RET, the error
  // value of the surrounding C-API function (e.g. nullptr or -1). Like
  // PYEXT17_TRANSLATE_EXC it expands to a complete handler sequence with no
  // trailing ';', so no stray empty statement or declaration is left behind.
  #define PYEXT17_TRANSLATE_EXC_RET(RET) \
      catch (::pyext17::py_err_set&) {return RET;} \
      catch (::pybind11::error_already_set& e) {e.restore(); return RET;} \
      PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND \
      catch (...) {::pyext17::pybind11_translate_exception(); return RET;}
  91. template <typename T>
  92. struct wrap {
  93. private:
  94. typedef wrap<T> wrap_t;
  95. public:
  96. PyObject_HEAD
  97. std::aligned_storage_t<sizeof(T), alignof(T)> storage;
  98. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  99. PyObject* (*vectorcall_slot)(PyObject*, PyObject*const*, size_t, PyObject*);
  100. #endif
  101. inline T* inst() {
  102. return reinterpret_cast<T*>(&storage);
  103. }
  104. inline static PyObject* pycast(T* ptr) {
  105. return (PyObject*)((char*)ptr - offsetof(wrap_t, storage));
  106. }
  107. private:
  108. // method wrapper
  109. enum struct meth_type {
  110. noarg,
  111. varkw,
  112. fastcall,
  113. singarg
  114. };
  115. template<auto f>
  116. struct detect_meth_type {
  117. static constexpr meth_type value = []() {
  118. using F = decltype(f);
  119. static_assert(std::is_member_function_pointer_v<F>);
  120. if constexpr (std::is_invocable_v<F, T>) {
  121. return meth_type::noarg;
  122. } else if constexpr (std::is_invocable_v<F, T, PyObject*, PyObject*>) {
  123. return meth_type::varkw;
  124. } else if constexpr (std::is_invocable_v<F, T, PyObject*const*, Py_ssize_t>) {
  125. return meth_type::fastcall;
  126. } else if constexpr (std::is_invocable_v<F, T, PyObject*>) {
  127. return meth_type::singarg;
  128. } else {
  129. static_assert(!std::is_same_v<F, F>);
  130. }
  131. }();
  132. };
  133. template<meth_type, auto f>
  134. struct meth {};
  135. template<auto f>
  136. struct meth<meth_type::noarg, f> {
  137. static constexpr int flags = METH_NOARGS;
  138. static PyObject* impl(PyObject* self, PyObject*) {
  139. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  140. try {
  141. CVT_RET_PYOBJ((inst->*f)());
  142. } PYEXT17_TRANSLATE_EXC_RET(nullptr)
  143. }
  144. };
  145. template<auto f>
  146. struct meth<meth_type::varkw, f> {
  147. static constexpr int flags = METH_VARARGS | METH_KEYWORDS;
  148. static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
  149. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  150. try {
  151. CVT_RET_PYOBJ((inst->*f)(args, kwargs));
  152. } PYEXT17_TRANSLATE_EXC_RET(nullptr)
  153. }
  154. };
  155. template<auto f>
  156. struct meth<meth_type::fastcall, f> {
  157. #ifdef METH_FASTCALL
  158. static constexpr int flags = METH_FASTCALL;
  159. static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) {
  160. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  161. try {
  162. CVT_RET_PYOBJ((inst->*f)(args, nargs));
  163. } PYEXT17_TRANSLATE_EXC_RET(nullptr)
  164. }
  165. #else
  166. static constexpr int flags = METH_VARARGS;
  167. static PyObject* impl(PyObject* self, PyObject* args) {
  168. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  169. auto* arr = &PyTuple_GET_ITEM(args, 0);
  170. auto size = PyTuple_GET_SIZE(args);
  171. try {
  172. CVT_RET_PYOBJ((inst->*f)(arr, size));
  173. } PYEXT17_TRANSLATE_EXC_RET(nullptr)
  174. }
  175. #endif
  176. };
  177. template<auto f>
  178. struct meth<meth_type::singarg, f> {
  179. static constexpr int flags = METH_O;
  180. static PyObject* impl(PyObject* self, PyObject* obj) {
  181. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  182. try {
  183. CVT_RET_PYOBJ((inst->*f)(obj));
  184. } PYEXT17_TRANSLATE_EXC_RET(nullptr)
  185. }
  186. };
  187. template<auto f>
  188. static constexpr PyMethodDef make_meth_def(const char* name, const char* doc = nullptr) {
  189. using M = meth<detect_meth_type<f>::value, f>;
  190. return {name, (PyCFunction)M::impl, M::flags, doc};
  191. }
  192. template<auto f>
  193. struct getter {
  194. using F = decltype(f);
  195. static PyObject* impl(PyObject* self, void* closure) {
  196. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  197. try {
  198. if constexpr (std::is_invocable_v<F, PyObject*, void*>) {
  199. CVT_RET_PYOBJ(f(self, closure));
  200. } else if constexpr (std::is_invocable_v<F, T, void*>) {
  201. CVT_RET_PYOBJ((inst->*f)(closure));
  202. } else if constexpr (std::is_invocable_v<F, T>) {
  203. CVT_RET_PYOBJ((inst->*f)());
  204. } else {
  205. static_assert(!std::is_same_v<F, F>);
  206. }
  207. } PYEXT17_TRANSLATE_EXC_RET(nullptr)
  208. }
  209. };
  210. template<auto f>
  211. struct setter {
  212. using F = decltype(f);
  213. template<typename = void>
  214. static int impl_(PyObject* self, PyObject* val, void* closure) {
  215. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  216. try {
  217. if constexpr (std::is_invocable_v<F, PyObject*, PyObject*, void*>) {
  218. CVT_RET_INT(f(self, val, closure));
  219. } else if constexpr (std::is_invocable_v<F, T, PyObject*, void*>) {
  220. CVT_RET_INT((inst->*f)(val, closure));
  221. } else if constexpr (std::is_invocable_v<F, T, PyObject*>) {
  222. CVT_RET_INT((inst->*f)(val));
  223. } else {
  224. static_assert(!std::is_same_v<F, F>);
  225. }
  226. } PYEXT17_TRANSLATE_EXC_RET(-1)
  227. }
  228. static constexpr auto impl = []() {if constexpr (std::is_same_v<F, std::nullptr_t>) return nullptr;
  229. else return impl_<>;}();
  230. };
  231. template<auto get, auto set = nullptr>
  232. static constexpr PyGetSetDef make_getset_def(const char* name, const char* doc = nullptr, void* closure = nullptr) {
  233. return {const_cast<char *>(name), getter<get>::impl, setter<set>::impl, const_cast<char *>(doc), closure};
  234. }
  235. // polyfills
  236. struct tp_vectorcall {
  237. static constexpr bool valid = HAS_MEMBER(T, tp_vectorcall);
  238. static constexpr bool haskw = [](){if constexpr (valid)
  239. if constexpr (std::is_invocable_v<decltype(&T::tp_vectorcall), T, PyObject*const*, size_t, PyObject*>)
  240. return true;
  241. return false;}();
  242. template<typename = void>
  243. static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargsf, PyObject *kwnames) {
  244. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  245. if constexpr (haskw) {
  246. CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf, kwnames));
  247. } else {
  248. if (kwnames && PyTuple_GET_SIZE(kwnames)) {
  249. PyErr_SetString(PyExc_TypeError, "expect no keyword argument");
  250. return nullptr;
  251. }
  252. CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf));
  253. }
  254. }
  255. static constexpr Py_ssize_t offset = []() {if constexpr (valid) return offsetof(wrap_t, vectorcall_slot);
  256. else return 0;}();
  257. };
  258. struct tp_call {
  259. static constexpr bool provided = HAS_MEMBER(T, tp_call);
  260. static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}(
  261. [](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {});
  262. static constexpr bool valid = provided || tp_vectorcall::valid;
  263. template<typename = void>
  264. static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
  265. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  266. CVT_RET_PYOBJ(inst->tp_call(args, kwargs));
  267. }
  268. static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call;
  269. else if constexpr (provided) return impl<>;
  270. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  271. else if constexpr (valid) return PyVectorcall_Call;
  272. #endif
  273. else return nullptr;}();
  274. };
  275. struct tp_new {
  276. static constexpr bool provided = HAS_MEMBER(T, tp_new);
  277. static constexpr bool varkw = std::is_constructible_v<T, PyObject*, PyObject*>;
  278. static constexpr bool noarg = std::is_default_constructible_v<T>;
  279. template<typename = void>
  280. static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
  281. struct FreeGuard {
  282. PyObject* self;
  283. PyTypeObject* type;
  284. ~FreeGuard() {if (self) type->tp_free(self);}
  285. };
  286. auto* self = type->tp_alloc(type, 0);
  287. FreeGuard free_guard{self, type};
  288. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  289. if constexpr (has_vectorcall && tp_vectorcall::valid) {
  290. reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
  291. }
  292. try {
  293. if constexpr (varkw) {
  294. new(inst) T(args, kwargs);
  295. } else {
  296. new(inst) T();
  297. }
  298. } PYEXT17_TRANSLATE_EXC_RET(nullptr)
  299. free_guard.self = nullptr;
  300. return self;
  301. }
  302. static constexpr newfunc value = []() {if constexpr (provided) return T::tp_new;
  303. else if constexpr (varkw || noarg) return impl<>;
  304. else return nullptr;}();
  305. };
  306. struct tp_dealloc {
  307. static constexpr bool provided = HAS_MEMBER(T, tp_dealloc);
  308. template<typename = void>
  309. static void impl(PyObject* self) {
  310. reinterpret_cast<wrap_t*>(self)->inst()->~T();
  311. Py_TYPE(self)->tp_free(self);
  312. }
  313. static constexpr destructor value = []() {if constexpr (provided) return T::tp_dealloc;
  314. else return impl<>;}();
  315. };
  316. public:
  317. class TypeBuilder {
  318. std::vector<PyMethodDef> m_methods;
  319. std::vector<PyGetSetDef> m_getsets;
  320. PyTypeObject m_type;
  321. bool m_finalized = false;
  322. bool m_ready = false;
  323. void check_finalized() {
  324. if (m_finalized) {
  325. throw std::runtime_error("type is already finalized");
  326. }
  327. }
  328. static const char* to_c_str(const char* s) {return s;}
  329. template <size_t N, typename... Ts>
  330. static const char* to_c_str(const pybind11::detail::descr<N, Ts...>& desc) {
  331. return desc.text;
  332. }
  333. public:
  334. TypeBuilder(const TypeBuilder&) = delete;
  335. TypeBuilder& operator=(const TypeBuilder&) = delete;
  336. TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} {
  337. constexpr auto has_tp_name = HAS_MEMBER(T, tp_name);
  338. if constexpr (has_tp_name) {
  339. m_type.tp_name = to_c_str(T::tp_name);
  340. }
  341. m_type.tp_dealloc = tp_dealloc::value;
  342. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  343. m_type.tp_vectorcall_offset = tp_vectorcall::offset;
  344. #endif
  345. m_type.tp_call = tp_call::value;
  346. m_type.tp_basicsize = sizeof(wrap_t);
  347. m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  348. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  349. if constexpr (tp_vectorcall::valid) {
  350. m_type.tp_flags |= _Py_TPFLAGS_HAVE_VECTORCALL;
  351. }
  352. #endif
  353. m_type.tp_new = tp_new::value;
  354. }
  355. PyTypeObject* operator->() {
  356. return &m_type;
  357. }
  358. bool ready() const {
  359. return m_ready;
  360. }
  361. bool isinstance(PyObject* op) {
  362. return PyObject_TypeCheck(op, &m_type);
  363. }
  364. bool isexact(PyObject* op) {
  365. return Py_TYPE(op) == &m_type;
  366. }
  367. bool same_pytype(PyTypeObject *pt) {
  368. return pt == &m_type;
  369. }
  370. PyObject* finalize() {
  371. if (!m_finalized) {
  372. m_finalized = true;
  373. if (m_methods.size()) {
  374. m_methods.push_back({0});
  375. if (m_type.tp_methods) {
  376. PyErr_SetString(PyExc_SystemError, "tp_method is already set");
  377. return nullptr;
  378. }
  379. m_type.tp_methods = &m_methods[0];
  380. }
  381. if (m_getsets.size()) {
  382. m_getsets.push_back({0});
  383. if (m_type.tp_getset) {
  384. PyErr_SetString(PyExc_SystemError, "tp_getset is already set");
  385. return nullptr;
  386. }
  387. m_type.tp_getset = &m_getsets[0];
  388. }
  389. if (PyType_Ready(&m_type)) {
  390. return nullptr;
  391. }
  392. m_ready = true;
  393. }
  394. return (PyObject*)&m_type;
  395. }
  396. template<auto f>
  397. TypeBuilder& def(const char* name, const char* doc = nullptr) {
  398. check_finalized();
  399. m_methods.push_back(make_meth_def<f>(name, doc));
  400. return *this;
  401. }
  402. template<auto get, auto set = nullptr>
  403. TypeBuilder& def_getset(const char* name, const char* doc = nullptr, void* closure = nullptr) {
  404. check_finalized();
  405. m_getsets.push_back(make_getset_def<get, set>(name, doc, closure));
  406. return *this;
  407. }
  408. };
  409. static TypeBuilder& type() {
  410. static TypeBuilder type_helper;
  411. return type_helper;
  412. }
  413. template<typename... Args>
  414. static PyObject* cnew(Args&&... args) {
  415. auto* pytype = type().operator->();
  416. return cnew_with_type(pytype, std::forward<Args>(args)...);
  417. }
  418. template<typename... Args>
  419. static PyObject* cnew_with_type(PyTypeObject* pytype, Args&&... args) {
  420. auto* self = pytype->tp_alloc(pytype, 0);
  421. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  422. if constexpr (has_vectorcall && tp_vectorcall::valid) {
  423. reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
  424. }
  425. new(inst) T(std::forward<Args>(args)...);
  426. return self;
  427. }
  428. struct caster {
  429. static constexpr auto name = T::tp_name;
  430. T* value;
  431. bool load(pybind11::handle src, bool convert) {
  432. if (wrap_t::type().isinstance(src.ptr())) {
  433. value = reinterpret_cast<wrap_t*>(src.ptr())->inst();
  434. return true;
  435. }
  436. return false;
  437. }
  438. template <typename U> using cast_op_type = pybind11::detail::cast_op_type<U>;
  439. operator T*() { return value; }
  440. operator T&() { return *value; }
  441. };
  442. };
  443. } // namespace pyext17
  // Drop the helper macros that are only meant for use inside this header, so
  // they do not leak into including files. The PYEXT17_* exception-translation
  // macros remain defined (presumably for use by including code).
  #undef CVT_RET_INT
  #undef CVT_RET_PYOBJ
  #undef HAS_MEMBER
  #undef HAS_MEMBER_TYPE

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。