You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.h 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. #pragma once
  2. #include "megbrain/common.h"
  3. #include "megbrain/imperative/op_def.h"
  4. #include "megbrain/utils/persistent_cache.h"
  5. #include <Python.h>
  6. #include <iterator>
  7. #include <string>
  8. #if __cplusplus > 201703L
  9. #include <ranges>
  10. #endif
  11. #include <pybind11/functional.h>
  12. #include <pybind11/numpy.h>
  13. #include <pybind11/pybind11.h>
  14. #include <pybind11/stl.h>
  15. #include "./numpy_dtypes.h"
  16. pybind11::module submodule(
  17. pybind11::module parent, const char* name, const char* doc = nullptr);
  18. pybind11::module rel_import(pybind11::str name, pybind11::module m, int level);
  // range_value_t<T>: the element type of range T.
  // C++20 builds use the standard trait; older builds fall back to "the type
  // obtained by dereferencing T's begin() iterator, with reference and cv
  // qualifiers stripped".
  #if __cplusplus > 201703L
  using std::ranges::range_value_t;
  #else
  template <typename T>
  using range_value_t = typename std::remove_cv<typename std::remove_reference<
          decltype(*std::declval<T>().begin())>::type>::type;
  #endif
  26. template <typename T>
  27. auto to_list(const T& x) {
  28. using elem_t = range_value_t<T>;
  29. std::vector<elem_t> ret(x.begin(), x.end());
  30. return pybind11::cast(ret);
  31. }
  32. template <typename T>
  33. auto to_tuple(
  34. const T& x, pybind11::return_value_policy policy =
  35. pybind11::return_value_policy::automatic) {
  36. auto ret = pybind11::tuple(x.size());
  37. for (size_t i = 0; i < x.size(); ++i) {
  38. ret[i] = pybind11::cast(x[i], policy);
  39. }
  40. return ret;
  41. }
  42. template <typename T>
  43. auto to_tuple(
  44. T begin, T end,
  45. pybind11::return_value_policy policy =
  46. pybind11::return_value_policy::automatic) {
  47. auto ret = pybind11::tuple(end - begin);
  48. for (size_t i = 0; begin < end; ++begin, ++i) {
  49. ret[i] = pybind11::cast(*begin, policy);
  50. }
  51. return ret;
  52. }
  53. class PyTaskDipatcher {
  54. struct Queue : mgb::AsyncQueueSC<std::function<void(void)>, Queue> {
  55. using Task = std::function<void(void)>;
  56. // set max_spin=0 to prevent Queue fetch task in busy wait manner.
  57. // this won't affect throughput when python interpreter is sending enough task,
  58. // but will significantly save CPU time when waiting for task, e.g. wait for
  59. // data input
  60. Queue() : mgb::AsyncQueueSC<std::function<void(void)>, Queue>(0) {}
  61. void process_one_task(Task& f) {
  62. if (!Py_IsInitialized())
  63. return;
  64. pybind11::gil_scoped_acquire _;
  65. f();
  66. }
  67. void on_async_queue_worker_thread_start() override {
  68. mgb::sys::set_thread_name("py_task_worker");
  69. }
  70. };
  71. Queue queue;
  72. bool finalized = false;
  73. public:
  74. template <typename T>
  75. void add_task(T&& task) {
  76. // CPython never dlclose an extension so
  77. // finalized means the interpreter has been shutdown
  78. if (!finalized) {
  79. queue.add_task(std::forward<T>(task));
  80. }
  81. }
  82. void wait_all_task_finish() { queue.wait_all_task_finish(); }
  83. ~PyTaskDipatcher() {
  84. finalized = true;
  85. queue.wait_all_task_finish();
  86. }
  87. };
  88. extern PyTaskDipatcher py_task_q;
  89. class GILManager {
  90. PyGILState_STATE gstate;
  91. public:
  92. GILManager() : gstate(PyGILState_Ensure()) {}
  93. ~GILManager() { PyGILState_Release(gstate); }
  94. };
  95. #define PYTHON_GIL GILManager __gil_manager
  96. //! wraps a shared_ptr and decr PyObject ref when destructed
  97. class PyObjRefKeeper {
  98. std::shared_ptr<PyObject> m_ptr;
  99. public:
  100. static void deleter(PyObject* p) {
  101. if (p) {
  102. py_task_q.add_task([p]() { Py_DECREF(p); });
  103. }
  104. }
  105. PyObjRefKeeper() = default;
  106. PyObjRefKeeper(PyObject* p) : m_ptr{p, deleter} {}
  107. PyObject* get() const { return m_ptr.get(); }
  108. //! create a shared_ptr as an alias of the underlying ptr
  109. template <typename T>
  110. std::shared_ptr<T> make_shared(T* ptr) const {
  111. return {m_ptr, ptr};
  112. }
  113. };
  114. //! exception to be thrown when python callback fails
  115. class PyExceptionForward : public std::exception {
  116. PyObject *m_type, *m_value, *m_traceback;
  117. std::string m_msg;
  118. PyExceptionForward(
  119. PyObject* type, PyObject* value, PyObject* traceback,
  120. const std::string& msg)
  121. : m_type{type}, m_value{value}, m_traceback{traceback}, m_msg{msg} {}
  122. public:
  123. PyExceptionForward(const PyExceptionForward&) = delete;
  124. PyExceptionForward& operator=(const PyExceptionForward&) = delete;
  125. ~PyExceptionForward();
  126. PyExceptionForward(PyExceptionForward&& rhs)
  127. : m_type{rhs.m_type},
  128. m_value{rhs.m_value},
  129. m_traceback{rhs.m_traceback},
  130. m_msg{std::move(rhs.m_msg)} {
  131. rhs.m_type = rhs.m_value = rhs.m_traceback = nullptr;
  132. }
  133. //! throw PyExceptionForward from current python error state
  134. static void throw_() __attribute__((noreturn));
  135. //! restore python error
  136. void restore();
  137. const char* what() const noexcept override { return m_msg.c_str(); }
  138. };
  139. //! numpy utils
  140. namespace npy {
  141. //! convert tensor shape to raw vector
  142. static inline std::vector<size_t> shape2vec(const mgb::TensorShape& shape) {
  143. return {shape.shape, shape.shape + shape.ndim};
  144. }
  145. //! change numpy dtype to megbrain supported dtype
  146. PyObject* to_mgb_supported_dtype(PyObject* dtype);
  147. //! convert raw vector to tensor shape
  148. mgb::TensorShape vec2shape(const std::vector<size_t>& vec);
  149. struct PyArrayDescrDeleter {
  150. void operator()(PyArray_Descr* obj) { Py_XDECREF(obj); }
  151. };
  152. //! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new
  153. //! reference to the descriptor.
  154. std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(
  155. mgb::DType dtype);
  156. mgb::DType dtype_np2mgb_descr(PyArray_Descr* descr);
  157. //! convert megbrain dtype to numpy dtype object; return new reference
  158. PyObject* dtype_mgb2np(mgb::DType dtype);
  159. //! convert numpy dtype object or string to megbrain dtype
  160. mgb::DType dtype_np2mgb(PyObject* obj);
  161. //! buffer sharing type
  162. enum class ShareType {
  163. MUST_SHARE, //!< must be shared
  164. MUST_UNSHARE, //!< must not be shared
  165. TRY_SHARE //!< share if possible
  166. };
  167. //! get ndarray from HostTensorND
  168. PyObject* ndarray_from_tensor(const mgb::HostTensorND& val, ShareType share_type);
  169. //! specify how to convert numpy array to tensor
  170. struct Meth {
  171. bool must_borrow_ = false;
  172. mgb::HostTensorND* dest_tensor_ = nullptr;
  173. mgb::CompNode dest_cn_;
  174. //! make a Meth that allows borrowing numpy array memory
  175. static Meth borrow(mgb::CompNode dest_cn = mgb::CompNode::default_cpu()) {
  176. return {false, nullptr, dest_cn};
  177. }
  178. //! make a Meth that requires the numpy array to be borrowed
  179. static Meth must_borrow(mgb::CompNode dest_cn = mgb::CompNode::default_cpu()) {
  180. return {true, nullptr, dest_cn};
  181. }
  182. //! make a Meth that requires copying the value into another
  183. //! tensor
  184. static Meth copy_into(mgb::HostTensorND* tensor) {
  185. return {false, tensor, tensor->comp_node()};
  186. }
  187. };
  188. /*!
  189. * \brief convert an object to megbrain tensor
  190. * \param meth specifies how the conversion should take place
  191. * \param dtype desired dtype; it can be set as invalid to allow arbitrary
  192. * dtype
  193. */
  194. mgb::HostTensorND np2tensor(PyObject* obj, const Meth& meth, mgb::DType dtype);
  195. } // namespace npy
  // This fallback mirrors the macro of the same name in
  // pybind11/detail/common.h. It is only defined when the pybind11 headers in
  // use do not provide it. On GCC-compatible compilers the `pybind11`
  // namespace is given hidden visibility. Modules built against different
  // pybind11 versions need that to coexist robustly in one process.
  #ifndef PYBIND11_NAMESPACE
  #if defined(__GNUG__)
  #define PYBIND11_NAMESPACE pybind11 __attribute__((visibility("hidden")))
  #else
  #define PYBIND11_NAMESPACE pybind11
  #endif
  #endif
  207. namespace PYBIND11_NAMESPACE {
  208. namespace detail {
  209. template <typename T, unsigned N>
  210. struct type_caster<megdnn::SmallVector<T, N>>
  211. : list_caster<megdnn::SmallVector<T, N>, T> {};
  212. template <>
  213. struct type_caster<mgb::DType> {
  214. PYBIND11_TYPE_CASTER(mgb::DType, _("DType"));
  215. public:
  216. bool load(handle src, bool convert) {
  217. auto obj = reinterpret_borrow<object>(src);
  218. if (!convert && !isinstance<dtype>(obj)) {
  219. return false;
  220. }
  221. if (obj.is_none()) {
  222. return true;
  223. }
  224. try {
  225. obj = pybind11::dtype::from_args(obj);
  226. } catch (pybind11::error_already_set&) {
  227. return false;
  228. }
  229. try {
  230. value = npy::dtype_np2mgb(obj.ptr());
  231. } catch (...) {
  232. return false;
  233. }
  234. return true;
  235. }
  236. static handle cast(
  237. mgb::DType dt, return_value_policy /* policy */, handle /* parent */) {
  238. // ignore policy and parent because we always return a pure python object
  239. return npy::dtype_mgb2np(std::move(dt));
  240. }
  241. };
  242. template <>
  243. struct type_caster<mgb::TensorShape> {
  244. PYBIND11_TYPE_CASTER(mgb::TensorShape, _("TensorShape"));
  245. public:
  246. bool load(handle src, bool convert) {
  247. auto obj = reinterpret_borrow<object>(src);
  248. if (!convert && !isinstance<tuple>(obj)) {
  249. return false;
  250. }
  251. if (obj.is_none()) {
  252. return true;
  253. }
  254. value.ndim = len(obj);
  255. mgb_assert(value.ndim <= mgb::TensorShape::MAX_NDIM);
  256. size_t i = 0;
  257. for (auto v : obj) {
  258. mgb_assert(i < value.ndim);
  259. value.shape[i] = reinterpret_borrow<object>(v).cast<size_t>();
  260. ++i;
  261. }
  262. return true;
  263. }
  264. static handle cast(
  265. mgb::TensorShape shape, return_value_policy /* policy */,
  266. handle /* parent */) {
  267. // ignore policy and parent because we always return a pure python object
  268. return to_tuple(shape.shape, shape.shape + shape.ndim).release();
  269. }
  270. };
  271. // hack to make custom object implicitly convertible from None
  272. template <typename T>
  273. struct from_none_caster : public type_caster_base<T> {
  274. using base = type_caster_base<T>;
  275. bool load(handle src, bool convert) {
  276. if (!convert || !src.is_none()) {
  277. return base::load(src, convert);
  278. }
  279. // adapted from pybind11::implicitly_convertible
  280. auto temp = reinterpret_steal<object>(
  281. PyObject_Call((PyObject*)this->typeinfo->type, tuple().ptr(), nullptr));
  282. if (!temp) {
  283. PyErr_Clear();
  284. return false;
  285. }
  286. // adapted from pybind11::detail::type_caster_generic
  287. if (base::load(temp, false)) {
  288. loader_life_support::add_patient(temp);
  289. return true;
  290. }
  291. return false;
  292. }
  293. };
  294. template <>
  295. struct type_caster<mgb::CompNode> : public from_none_caster<mgb::CompNode> {};
  296. template <>
  297. struct type_caster<mgb::PersistentCache::Blob> {
  298. PYBIND11_TYPE_CASTER(mgb::PersistentCache::Blob, _("Blob"));
  299. public:
  300. bool load(handle src, bool convert) {
  301. if (!isinstance<bytes>(src)) {
  302. return false;
  303. }
  304. value.ptr = PYBIND11_BYTES_AS_STRING(src.ptr());
  305. value.size = PYBIND11_BYTES_SIZE(src.ptr());
  306. return true;
  307. }
  308. static handle cast(
  309. mgb::PersistentCache::Blob blob, return_value_policy /* policy */,
  310. handle /* parent */) {
  311. return bytes((const char*)blob.ptr, blob.size).release();
  312. }
  313. };
  314. template <typename T>
  315. struct type_caster<mgb::Maybe<T>> {
  316. using value_conv = make_caster<T>;
  317. PYBIND11_TYPE_CASTER(mgb::Maybe<T>, _("Optional[") + value_conv::name + _("]"));
  318. public:
  319. bool load(handle src, bool convert) {
  320. if (!src) {
  321. return false;
  322. }
  323. if (src.is_none()) {
  324. return true;
  325. }
  326. value_conv inner_caster;
  327. if (!inner_caster.load(src, convert)) {
  328. return false;
  329. }
  330. value.emplace(cast_op<T&&>(std::move(inner_caster)));
  331. return true;
  332. }
  333. static handle cast(mgb::Maybe<T> src, return_value_policy policy, handle parent) {
  334. if (!src.valid()) {
  335. return none().inc_ref();
  336. }
  337. return pybind11::cast(src.val(), policy, parent);
  338. }
  339. };
  340. template <>
  341. struct type_caster<mgb::imperative::OpDef> {
  342. protected:
  343. std::shared_ptr<mgb::imperative::OpDef> value;
  344. public:
  345. static constexpr auto name = _("OpDef");
  346. operator mgb::imperative::OpDef&() { return *value; }
  347. operator const mgb::imperative::OpDef&() { return *value; }
  348. operator std::shared_ptr<mgb::imperative::OpDef>&() { return value; }
  349. operator std::shared_ptr<mgb::imperative::OpDef>&&() && { return std::move(value); }
  350. template <typename T>
  351. using cast_op_type = T;
  352. bool load(handle src, bool convert);
  353. static handle cast(
  354. const mgb::imperative::OpDef& op, return_value_policy /* policy */,
  355. handle /* parent */);
  356. static handle cast(
  357. std::shared_ptr<mgb::imperative::OpDef> op, return_value_policy policy,
  358. handle parent) {
  359. return cast(*op, policy, parent);
  360. }
  361. };
  362. template <>
  363. struct type_caster<std::shared_ptr<mgb::imperative::OpDef>>
  364. : public type_caster<mgb::imperative::OpDef> {
  365. template <typename T>
  366. using cast_op_type = pybind11::detail::movable_cast_op_type<T>;
  367. };
  368. } // namespace detail
  369. } // namespace PYBIND11_NAMESPACE
  370. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}