You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pyext17.h 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579
  1. /**
  2. * \file imperative/python/src/pyext17.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include <Python.h>
  13. #include <pybind11/pybind11.h>
  14. #include <exception>
  15. #include <stdexcept>
  16. #include <utility>
  17. #include <vector>
  18. namespace pyext17 {
  // Compile-time feature switches derived from the CPython headers in use:
  //   has_fastcall   - true when the METH_FASTCALL macro is defined;
  //   has_vectorcall - true when the _Py_TPFLAGS_HAVE_VECTORCALL macro is defined.
  constexpr bool has_fastcall =
  #ifdef METH_FASTCALL
          true;
  #else
          false;
  #endif
  constexpr bool has_vectorcall =
  #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
          true;
  #else
          false;
  #endif
  // Callable probe. `invocable_with<Args...>{}(callable)` evaluates to true when
  // an object of the callable's deduced type can be invoked with arguments of
  // types Args.... The callable is never actually called; only its type is
  // inspected.
  template <typename... Args>
  struct invocable_with {
      template <typename Probe>
      constexpr bool operator()(Probe&&) {
          return std::is_invocable<Probe, Args...>::value;
      }
  };
  // HAS_MEMBER_TYPE(T, U): true when `std::decay_t<T>::U` names a type.
  // HAS_MEMBER(T, m):      true when `&std::decay_t<T>::m` is a well-formed
  //                        expression.
  // Both work by handing invocable_with a generic lambda whose trailing return
  // type is only well-formed when the member exists, so a missing member makes
  // the lambda non-invocable instead of causing a hard error.
  #define HAS_MEMBER_TYPE(T, U) \
      invocable_with<T>{}([](auto&& x) -> typename std::decay_t<decltype(x)>::U {})
  #define HAS_MEMBER(T, m) \
      invocable_with<T>{}([](auto&& x) -> decltype(&std::decay_t<decltype(x)>::m) {})
  40. inline PyObject* cvt_retval(PyObject* rv) {
  41. return rv;
  42. }
  43. #define CVT_RET_PYOBJ(...) \
  44. if constexpr (std::is_same_v<decltype(__VA_ARGS__), void>) { \
  45. __VA_ARGS__; \
  46. Py_RETURN_NONE; \
  47. } else { \
  48. return cvt_retval(__VA_ARGS__); \
  49. }
  // Pass-through conversion used by CVT_RET_INT for expressions that already
  // produce an int status code.
  inline int cvt_retint(int ret) {
      return ret;
  }
  // Evaluates the given expression and returns from the enclosing function:
  //   - a void expression is evaluated once, then 0 is returned;
  //   - any other expression is returned through cvt_retint().
  // Intended for use inside templates, so the untaken branch is discarded.
  #define CVT_RET_INT(...) \
      if constexpr (std::is_void_v<decltype(__VA_ARGS__)>) { \
          __VA_ARGS__; \
          return 0; \
      } else { \
          return cvt_retint(__VA_ARGS__); \
      }
  // Tag exception. The PYEXT17_TRANSLATE_EXC* handlers catch it without setting
  // any new Python error, so it is presumably thrown after a Python error has
  // already been set by the thrower.
  struct py_err_set : public std::exception {};
  61. // refer to pybind11 for the following exception handling helper
  62. inline void pybind11_translate_exception(std::exception_ptr last_exception) {
  63. auto& registered_exception_translators =
  64. pybind11::detail::get_internals().registered_exception_translators;
  65. for (auto& translator : registered_exception_translators) {
  66. try {
  67. translator(last_exception);
  68. } catch (...) {
  69. last_exception = std::current_exception();
  70. continue;
  71. }
  72. return;
  73. }
  74. PyErr_SetString(
  75. PyExc_SystemError, "Exception escaped from default exception translator!");
  76. }
  77. inline void pybind11_translate_exception() {
  78. pybind11_translate_exception(std::current_exception());
  79. }
  // On GCC (but not clang) a dedicated handler re-throws abi::__forced_unwind so
  // that the trailing catch (...) of the macros below never swallows it. On
  // other compilers the macro expands to nothing.
  #if defined(__GNUG__) && !defined(__clang__)
  #define PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND \
      catch (::abi::__forced_unwind&) { \
          throw; \
      }
  #else
  #define PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND
  #endif
  // Handler chain to place directly after a `try { ... }` block. It converts a
  // C++ exception into a pending Python error and lets execution continue after
  // the handlers:
  //   - py_err_set:                  nothing is done;
  //   - pybind11::error_already_set: the stored Python error is restored;
  //   - anything else:               passed to pybind11_translate_exception().
  #define PYEXT17_TRANSLATE_EXC \
      catch (::pyext17::py_err_set&) { \
      } \
      catch (::pybind11::error_already_set & e) { \
          e.restore(); \
      } \
      PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND \
      catch (...) { \
          ::pyext17::pybind11_translate_exception(); \
      }
  // Same as PYEXT17_TRANSLATE_EXC, but every handler additionally returns RET
  // from the enclosing function (e.g. nullptr or -1 as the CPython failure
  // value). The expansion ends at the closing brace of the last handler, with
  // no trailing semicolon, so it composes like an ordinary handler sequence and
  // matches PYEXT17_TRANSLATE_EXC; call sites may still add their own `;`.
  #define PYEXT17_TRANSLATE_EXC_RET(RET) \
      catch (::pyext17::py_err_set&) { \
          return RET; \
      } \
      catch (::pybind11::error_already_set & e) { \
          e.restore(); \
          return RET; \
      } \
      PYEXT17_TRANSLATE_EXC_CATCH_FORCED_UNWIND \
      catch (...) { \
          ::pyext17::pybind11_translate_exception(); \
          return RET; \
      }
  111. template <typename T>
  112. struct wrap {
  113. private:
  114. typedef wrap<T> wrap_t;
  115. public:
  116. PyObject_HEAD std::aligned_storage_t<sizeof(T), alignof(T)> storage;
  117. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  118. PyObject* (*vectorcall_slot)(PyObject*, PyObject* const*, size_t, PyObject*);
  119. #endif
  120. inline T* inst() { return reinterpret_cast<T*>(&storage); }
  121. inline static PyObject* pycast(T* ptr) {
  122. return (PyObject*)((char*)ptr - offsetof(wrap_t, storage));
  123. }
  124. private:
  125. // method wrapper
  126. enum struct meth_type { noarg, varkw, fastcall, singarg };
  127. template <auto f>
  128. struct detect_meth_type {
  129. static constexpr meth_type value = []() {
  130. using F = decltype(f);
  131. static_assert(std::is_member_function_pointer_v<F>);
  132. if constexpr (std::is_invocable_v<F, T>) {
  133. return meth_type::noarg;
  134. } else if constexpr (std::is_invocable_v<F, T, PyObject*, PyObject*>) {
  135. return meth_type::varkw;
  136. } else if constexpr (std::is_invocable_v<
  137. F, T, PyObject* const*, Py_ssize_t>) {
  138. return meth_type::fastcall;
  139. } else if constexpr (std::is_invocable_v<F, T, PyObject*>) {
  140. return meth_type::singarg;
  141. } else {
  142. static_assert(!std::is_same_v<F, F>);
  143. }
  144. }();
  145. };
  146. template <meth_type, auto f>
  147. struct meth {};
  148. template <auto f>
  149. struct meth<meth_type::noarg, f> {
  150. static constexpr int flags = METH_NOARGS;
  151. static PyObject* impl(PyObject* self, PyObject*) {
  152. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  153. try {
  154. CVT_RET_PYOBJ((inst->*f)());
  155. }
  156. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  157. }
  158. };
  159. template <auto f>
  160. struct meth<meth_type::varkw, f> {
  161. static constexpr int flags = METH_VARARGS | METH_KEYWORDS;
  162. static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
  163. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  164. try {
  165. CVT_RET_PYOBJ((inst->*f)(args, kwargs));
  166. }
  167. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  168. }
  169. };
  170. template <auto f>
  171. struct meth<meth_type::fastcall, f> {
  172. #ifdef METH_FASTCALL
  173. static constexpr int flags = METH_FASTCALL;
  174. static PyObject* impl(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
  175. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  176. try {
  177. CVT_RET_PYOBJ((inst->*f)(args, nargs));
  178. }
  179. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  180. }
  181. #else
  182. static constexpr int flags = METH_VARARGS;
  183. static PyObject* impl(PyObject* self, PyObject* args) {
  184. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  185. auto* arr = &PyTuple_GET_ITEM(args, 0);
  186. auto size = PyTuple_GET_SIZE(args);
  187. try {
  188. CVT_RET_PYOBJ((inst->*f)(arr, size));
  189. }
  190. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  191. }
  192. #endif
  193. };
  194. template <auto f>
  195. struct meth<meth_type::singarg, f> {
  196. static constexpr int flags = METH_O;
  197. static PyObject* impl(PyObject* self, PyObject* obj) {
  198. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  199. try {
  200. CVT_RET_PYOBJ((inst->*f)(obj));
  201. }
  202. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  203. }
  204. };
  205. template <auto f>
  206. static constexpr PyMethodDef make_meth_def(
  207. const char* name, const char* doc = nullptr) {
  208. using M = meth<detect_meth_type<f>::value, f>;
  209. return {name, (PyCFunction)M::impl, M::flags, doc};
  210. }
  211. template <auto f>
  212. struct getter {
  213. using F = decltype(f);
  214. static PyObject* impl(PyObject* self, void* closure) {
  215. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  216. try {
  217. if constexpr (std::is_invocable_v<F, PyObject*, void*>) {
  218. CVT_RET_PYOBJ(f(self, closure));
  219. } else if constexpr (std::is_invocable_v<F, T, void*>) {
  220. CVT_RET_PYOBJ((inst->*f)(closure));
  221. } else if constexpr (std::is_invocable_v<F, T>) {
  222. CVT_RET_PYOBJ((inst->*f)());
  223. } else {
  224. static_assert(!std::is_same_v<F, F>);
  225. }
  226. }
  227. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  228. }
  229. };
  230. template <auto f>
  231. struct setter {
  232. using F = decltype(f);
  233. template <typename = void>
  234. static int impl_(PyObject* self, PyObject* val, void* closure) {
  235. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  236. try {
  237. if constexpr (std::is_invocable_v<F, PyObject*, PyObject*, void*>) {
  238. CVT_RET_INT(f(self, val, closure));
  239. } else if constexpr (std::is_invocable_v<F, T, PyObject*, void*>) {
  240. CVT_RET_INT((inst->*f)(val, closure));
  241. } else if constexpr (std::is_invocable_v<F, T, PyObject*>) {
  242. CVT_RET_INT((inst->*f)(val));
  243. } else {
  244. static_assert(!std::is_same_v<F, F>);
  245. }
  246. }
  247. PYEXT17_TRANSLATE_EXC_RET(-1)
  248. }
  249. static constexpr auto impl = []() {
  250. if constexpr (std::is_same_v<F, std::nullptr_t>)
  251. return nullptr;
  252. else
  253. return impl_<>;
  254. }();
  255. };
  256. template <auto get, auto set = nullptr>
  257. static constexpr PyGetSetDef make_getset_def(
  258. const char* name, const char* doc = nullptr, void* closure = nullptr) {
  259. return {const_cast<char*>(name), getter<get>::impl, setter<set>::impl,
  260. const_cast<char*>(doc), closure};
  261. }
  262. // polyfills
  263. struct tp_vectorcall {
  264. static constexpr bool valid = HAS_MEMBER(T, tp_vectorcall);
  265. static constexpr bool haskw = []() {
  266. if constexpr (valid)
  267. if constexpr (std::is_invocable_v<
  268. decltype(&T::tp_vectorcall), T, PyObject* const*,
  269. size_t, PyObject*>)
  270. return true;
  271. return false;
  272. }();
  273. template <typename = void>
  274. static PyObject* impl(
  275. PyObject* self, PyObject* const* args, size_t nargsf,
  276. PyObject* kwnames) {
  277. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  278. if constexpr (haskw) {
  279. CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf, kwnames));
  280. } else {
  281. if (kwnames && PyTuple_GET_SIZE(kwnames)) {
  282. PyErr_SetString(PyExc_TypeError, "expect no keyword argument");
  283. return nullptr;
  284. }
  285. CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf));
  286. }
  287. }
  288. static constexpr Py_ssize_t offset = []() {
  289. if constexpr (valid)
  290. return offsetof(wrap_t, vectorcall_slot);
  291. else
  292. return 0;
  293. }();
  294. };
  295. struct tp_call {
  296. static constexpr bool provided = HAS_MEMBER(T, tp_call);
  297. static constexpr bool static_form =
  298. invocable_with<T, PyObject*, PyObject*, PyObject*>{}(
  299. [](auto&& t, auto... args)
  300. -> decltype(std::decay_t<decltype(t)>::tp_call(
  301. args...)) {});
  302. static constexpr bool valid = provided || tp_vectorcall::valid;
  303. template <typename = void>
  304. static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
  305. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  306. CVT_RET_PYOBJ(inst->tp_call(args, kwargs));
  307. }
  308. static constexpr ternaryfunc value = []() {
  309. if constexpr (static_form)
  310. return T::tp_call;
  311. else if constexpr (provided)
  312. return impl<>;
  313. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  314. else if constexpr (valid)
  315. return PyVectorcall_Call;
  316. #endif
  317. else
  318. return nullptr;
  319. }();
  320. };
  321. struct tp_new {
  322. static constexpr bool provided = HAS_MEMBER(T, tp_new);
  323. static constexpr bool varkw = std::is_constructible_v<T, PyObject*, PyObject*>;
  324. static constexpr bool noarg = std::is_default_constructible_v<T>;
  325. template <typename = void>
  326. static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
  327. struct FreeGuard {
  328. PyObject* self;
  329. PyTypeObject* type;
  330. ~FreeGuard() {
  331. if (self)
  332. type->tp_free(self);
  333. }
  334. };
  335. auto* self = type->tp_alloc(type, 0);
  336. FreeGuard free_guard{self, type};
  337. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  338. if constexpr (has_vectorcall && tp_vectorcall::valid) {
  339. reinterpret_cast<wrap_t*>(self)->vectorcall_slot =
  340. &tp_vectorcall::template impl<>;
  341. }
  342. try {
  343. if constexpr (varkw) {
  344. new (inst) T(args, kwargs);
  345. } else {
  346. new (inst) T();
  347. }
  348. }
  349. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  350. free_guard.self = nullptr;
  351. return self;
  352. }
  353. static constexpr newfunc value = []() {
  354. if constexpr (provided)
  355. return T::tp_new;
  356. else if constexpr (varkw || noarg)
  357. return impl<>;
  358. else
  359. return nullptr;
  360. }();
  361. };
  362. struct tp_dealloc {
  363. static constexpr bool provided = HAS_MEMBER(T, tp_dealloc);
  364. template <typename = void>
  365. static void impl(PyObject* self) {
  366. reinterpret_cast<wrap_t*>(self)->inst()->~T();
  367. Py_TYPE(self)->tp_free(self);
  368. }
  369. static constexpr destructor value = []() {
  370. if constexpr (provided)
  371. return T::tp_dealloc;
  372. else
  373. return impl<>;
  374. }();
  375. };
  376. public:
  377. class TypeBuilder {
  378. std::vector<PyMethodDef> m_methods;
  379. std::vector<PyGetSetDef> m_getsets;
  380. PyTypeObject m_type;
  381. bool m_finalized = false;
  382. bool m_ready = false;
  383. void check_finalized() {
  384. if (m_finalized) {
  385. throw std::runtime_error("type is already finalized");
  386. }
  387. }
  388. static const char* to_c_str(const char* s) { return s; }
  389. template <size_t N, typename... Ts>
  390. static const char* to_c_str(const pybind11::detail::descr<N, Ts...>& desc) {
  391. return desc.text;
  392. }
  393. public:
  394. TypeBuilder(const TypeBuilder&) = delete;
  395. TypeBuilder& operator=(const TypeBuilder&) = delete;
  396. TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} {
  397. constexpr auto has_tp_name = HAS_MEMBER(T, tp_name);
  398. if constexpr (has_tp_name) {
  399. m_type.tp_name = to_c_str(T::tp_name);
  400. }
  401. m_type.tp_dealloc = tp_dealloc::value;
  402. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  403. m_type.tp_vectorcall_offset = tp_vectorcall::offset;
  404. #endif
  405. m_type.tp_call = tp_call::value;
  406. m_type.tp_basicsize = sizeof(wrap_t);
  407. m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  408. #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
  409. if constexpr (tp_vectorcall::valid) {
  410. m_type.tp_flags |= _Py_TPFLAGS_HAVE_VECTORCALL;
  411. }
  412. #endif
  413. m_type.tp_new = tp_new::value;
  414. }
  415. PyTypeObject* operator->() { return &m_type; }
  416. bool ready() const { return m_ready; }
  417. bool isinstance(PyObject* op) { return PyObject_TypeCheck(op, &m_type); }
  418. bool isexact(PyObject* op) { return Py_TYPE(op) == &m_type; }
  419. bool same_pytype(PyTypeObject* pt) { return pt == &m_type; }
  420. PyObject* finalize() {
  421. if (!m_finalized) {
  422. m_finalized = true;
  423. if (m_methods.size()) {
  424. m_methods.push_back({0});
  425. if (m_type.tp_methods) {
  426. PyErr_SetString(PyExc_SystemError, "tp_method is already set");
  427. return nullptr;
  428. }
  429. m_type.tp_methods = &m_methods[0];
  430. }
  431. if (m_getsets.size()) {
  432. m_getsets.push_back({0});
  433. if (m_type.tp_getset) {
  434. PyErr_SetString(PyExc_SystemError, "tp_getset is already set");
  435. return nullptr;
  436. }
  437. m_type.tp_getset = &m_getsets[0];
  438. }
  439. if (PyType_Ready(&m_type)) {
  440. return nullptr;
  441. }
  442. m_ready = true;
  443. }
  444. return (PyObject*)&m_type;
  445. }
  446. template <auto f>
  447. TypeBuilder& def(const char* name, const char* doc = nullptr) {
  448. check_finalized();
  449. m_methods.push_back(make_meth_def<f>(name, doc));
  450. return *this;
  451. }
  452. template <auto get, auto set = nullptr>
  453. TypeBuilder& def_getset(
  454. const char* name, const char* doc = nullptr, void* closure = nullptr) {
  455. check_finalized();
  456. m_getsets.push_back(make_getset_def<get, set>(name, doc, closure));
  457. return *this;
  458. }
  459. };
  460. static TypeBuilder& type() {
  461. static TypeBuilder type_helper;
  462. return type_helper;
  463. }
  464. template <typename... Args>
  465. static PyObject* cnew(Args&&... args) {
  466. auto* pytype = type().operator->();
  467. return cnew_with_type(pytype, std::forward<Args>(args)...);
  468. }
  469. template <typename... Args>
  470. static PyObject* cnew_with_type(PyTypeObject* pytype, Args&&... args) {
  471. auto* self = pytype->tp_alloc(pytype, 0);
  472. auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
  473. if constexpr (has_vectorcall && tp_vectorcall::valid) {
  474. reinterpret_cast<wrap_t*>(self)->vectorcall_slot =
  475. &tp_vectorcall::template impl<>;
  476. }
  477. new (inst) T(std::forward<Args>(args)...);
  478. return self;
  479. }
  480. struct caster {
  481. static constexpr auto name = T::tp_name;
  482. T* value;
  483. bool load(pybind11::handle src, bool convert) {
  484. if (wrap_t::type().isinstance(src.ptr())) {
  485. value = reinterpret_cast<wrap_t*>(src.ptr())->inst();
  486. return true;
  487. }
  488. return false;
  489. }
  490. template <typename U>
  491. using cast_op_type = pybind11::detail::cast_op_type<U>;
  492. operator T*() { return value; }
  493. operator T&() { return *value; }
  494. };
  495. };
  496. } // namespace pyext17
  497. #undef HAS_MEMBER_TYPE
  498. #undef HAS_MEMBER
  499. #undef CVT_RET_PYOBJ
  500. #undef CVT_RET_INT

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台