You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.cpp 27 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722
  1. #include "./helper.h"
  2. #include <pybind11/eval.h>
  3. #include "megbrain/graph/cg.h"
  4. #include "megbrain/graph/event.h"
  5. #include "megbrain/graph/exc_extra_info.h"
  6. #include "megbrain/tensor.h"
  7. #include "megbrain/utils/mempool.h"
  8. #include "./numpy_dtypes.h"
  9. namespace py = pybind11;
  10. PyTaskDipatcher py_task_q = {};
  11. py::module submodule(py::module parent, const char* name, const char* doc) {
  12. auto m = parent.def_submodule(name, doc);
  13. m.attr("__package__") = parent.attr("__name__");
  14. m.attr("__builtins__") = py::module::import("builtins");
  15. return m;
  16. }
  17. py::module rel_import(py::str name, py::module m, int level) {
  18. py::object import = py::module::import("builtins").attr("__import__");
  19. return import(name, m.attr("__dict__"), py::arg("level") = level);
  20. }
  21. /*
  22. * demangle typeid, see
  23. * http://stackoverflow.com/questions/281818/unmangling-the-result-of-stdtype-infoname
  24. */
  25. #ifdef __GNUG__
  26. #include <cxxabi.h>
  27. #include <cstdlib>
  28. #include <memory>
  29. namespace {
  30. std::string demangle_typeid(const char* name) {
  31. int status = -4; // some arbitrary value to eliminate the compiler warning
  32. // enable c++11 by passing the flag -std=c++11 to g++
  33. std::unique_ptr<char, void (*)(void*)> res{
  34. abi::__cxa_demangle(name, nullptr, nullptr, &status), std::free};
  35. return (status == 0) ? res.get() : name;
  36. }
  37. } // namespace
  38. #else
  39. namespace {
  40. // does nothing if not g++
  41. std::string demangle_typeid(const char* name) {
  42. return name;
  43. }
  44. } // namespace
  45. #endif
  46. using namespace mgb;
  47. using namespace cg;
  48. namespace {
  49. std::string repr_pyobj(PyObject* obj) {
  50. if (!obj)
  51. return "<null PyObject>";
  52. PYTHON_GIL;
  53. auto str = PyObject_Repr(obj);
  54. if (!str)
  55. return ssprintf("<PyObject at %p (repr failed)>", obj);
  56. std::string ret{PyUnicode_AsUTF8(str)};
  57. Py_DECREF(str);
  58. return ret;
  59. }
  60. template <typename T>
  61. std::string typeid_name(const T& t) {
  62. return demangle_typeid(typeid(t).name());
  63. }
  64. } // anonymous namespace
  65. /* ============== PyExceptionForward ============== */
  66. PyExceptionForward::~PyExceptionForward() {
  67. PYTHON_GIL;
  68. PyObjRefKeeper::deleter(m_type);
  69. PyObjRefKeeper::deleter(m_value);
  70. PyObjRefKeeper::deleter(m_traceback);
  71. }
  72. void PyExceptionForward::restore() {
  73. PyErr_Restore(m_type, m_value, m_traceback);
  74. m_type = m_value = m_traceback = nullptr;
  75. }
  76. void PyExceptionForward::throw_() {
  77. PyObject *etype, *obj, *trace;
  78. PyErr_Fetch(&etype, &obj, &trace);
  79. PyErr_NormalizeException(&etype, &obj, &trace);
  80. std::string msg{"python exception"};
  81. bool succ = false;
  82. if (etype && obj && trace) {
  83. auto run = [&]() {
  84. #define DEF(name, expr) \
  85. PyObjRefKeeper name{expr}; \
  86. if (!name.get()) \
  87. return
  88. DEF(mod, PyImport_ImportModule("traceback"));
  89. DEF(result,
  90. PyObject_CallMethod(
  91. mod.get(), "format_exception", "(OOO)", etype, obj, trace));
  92. if (!PyList_Check(result.get()))
  93. return;
  94. auto size = PyList_Size(result.get());
  95. msg.append(":\n");
  96. for (Py_ssize_t i = 0; i < size; ++i) {
  97. msg.append(" ");
  98. msg.append(PyUnicode_AsUTF8(PyList_GetItem(result.get(), i)));
  99. }
  100. msg.pop_back(); // remove last \n
  101. succ = true;
  102. #undef DEF
  103. };
  104. run();
  105. }
  106. if (!succ) {
  107. PyObject* obj_str_py;
  108. if (obj && (obj_str_py = PyObject_Repr(obj))) {
  109. msg.append(" with message ");
  110. msg.append(PyUnicode_AsUTF8(obj_str_py));
  111. Py_DECREF(obj_str_py);
  112. } else {
  113. msg.append(" with unknown message");
  114. }
  115. }
  116. // throwing exception may cause abort due to unknown reasons; so we first
  117. // log the message
  118. mgb_log_error("caught exception from python callback: %s", msg.c_str());
  119. fflush(stdout);
  120. fflush(stderr);
  121. throw PyExceptionForward{etype, obj, trace, msg};
  122. }
  123. /* ============== namespace npy ============== */
  124. namespace npy {
  125. int to_mgb_supported_dtype_raw(int dtype) {
  126. if (dtype == NPY_INT64)
  127. return NPY_INT32;
  128. if (dtype == NPY_FLOAT64)
  129. return NPY_FLOAT32;
  130. return dtype;
  131. }
  132. #define FOREACH_NPY_DTYPE_PAIR(cb) \
  133. cb(Uint8, NPY_UINT8) cb(Int8, NPY_INT8) cb(Uint16, NPY_UINT16) \
  134. cb(Int16, NPY_INT16) cb(Int32, NPY_INT32) cb(Float16, NPY_FLOAT16) \
  135. cb(Float32, NPY_FLOAT32) cb(Bool, NPY_BOOL)
  136. #define FOREACH_NPY_MGB_DTYPE_PAIR(cb) \
  137. FOREACH_NPY_DTYPE_PAIR(cb) \
  138. FOREACH_MGB_DTYPE_PAIR(cb)
  139. //! convert megbrain dtype to numpy dtype
  140. int dtype_mgb2np_raw(DType dtype) {
  141. mgb_assert(dtype.valid(), "attempt to convert from invalid dtype");
  142. switch (dtype.enumv()) {
  143. #define cb(_m, _n) \
  144. case DTypeEnum::_m: \
  145. return _n;
  146. FOREACH_NPY_MGB_DTYPE_PAIR(cb)
  147. #undef cb
  148. default:
  149. break;
  150. }
  151. throw ConversionError(
  152. ssprintf("can not convert dtype %s to numpy dtype", dtype.name()));
  153. }
  154. //! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new
  155. //! reference to the descriptor.
  156. std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(DType dtype) {
  157. PYTHON_GIL;
  158. mgb_assert(dtype.valid(), "attempt to convert from invalid dtype");
  159. auto build_mgb_dtype_dict =
  160. [](const char* name,
  161. const std::vector<std::pair<const char*, PyObject*>>& data) {
  162. PyObject* metadata = PyDict_New();
  163. PyObject* mgb_dtype_metadata = PyDict_New();
  164. PyDict_SetItemString(
  165. mgb_dtype_metadata, "name", PyUnicode_FromString(name));
  166. for (const auto& d : data) {
  167. PyDict_SetItemString(mgb_dtype_metadata, d.first, d.second);
  168. }
  169. PyDict_SetItemString(metadata, "mgb_dtype", mgb_dtype_metadata);
  170. return metadata;
  171. };
  172. if (dtype.has_param()) {
  173. PyArray_Descr* type_descr;
  174. switch (dtype.enumv()) {
  175. case DTypeEnum::QuantizedS1: {
  176. auto& param = dtype.param<dtype::QuantizedS1>();
  177. type_descr = PyArray_DescrNewFromType(NPY_INT8);
  178. type_descr->metadata = build_mgb_dtype_dict(
  179. DTypeTrait<dtype::QuantizedS1>::name,
  180. {{"scale", PyFloat_FromDouble(param.scale)}});
  181. break;
  182. }
  183. case DTypeEnum::Quantized4Asymm: {
  184. auto& param = dtype.param<dtype::Quantized4Asymm>();
  185. type_descr = PyArray_DescrNewFromType(NPY_UINT8);
  186. type_descr->metadata = build_mgb_dtype_dict(
  187. DTypeTrait<dtype::Quantized4Asymm>::name,
  188. {{"scale", PyFloat_FromDouble(param.scale)},
  189. {"zero_point", PyLong_FromLong(param.zero_point)}});
  190. break;
  191. }
  192. case DTypeEnum::QuantizedS4: {
  193. auto& param = dtype.param<dtype::QuantizedS4>();
  194. type_descr = PyArray_DescrNewFromType(NPY_INT8);
  195. type_descr->metadata = build_mgb_dtype_dict(
  196. DTypeTrait<dtype::QuantizedS4>::name,
  197. {{"scale", PyFloat_FromDouble(param.scale)}});
  198. break;
  199. }
  200. case DTypeEnum::Quantized8Asymm: {
  201. auto& param = dtype.param<dtype::Quantized8Asymm>();
  202. type_descr = PyArray_DescrNewFromType(NPY_UINT8);
  203. type_descr->metadata = build_mgb_dtype_dict(
  204. DTypeTrait<dtype::Quantized8Asymm>::name,
  205. {{"scale", PyFloat_FromDouble(param.scale)},
  206. {"zero_point", PyLong_FromLong(param.zero_point)}});
  207. break;
  208. }
  209. case DTypeEnum::QuantizedS8: {
  210. auto& param = dtype.param<dtype::QuantizedS8>();
  211. type_descr = PyArray_DescrNewFromType(NPY_INT8);
  212. type_descr->metadata = build_mgb_dtype_dict(
  213. DTypeTrait<dtype::QuantizedS8>::name,
  214. {{"scale", PyFloat_FromDouble(param.scale)}});
  215. break;
  216. }
  217. case DTypeEnum::QuantizedS32: {
  218. auto& param = dtype.param<dtype::QuantizedS32>();
  219. type_descr = PyArray_DescrNewFromType(NPY_INT32);
  220. type_descr->metadata = build_mgb_dtype_dict(
  221. DTypeTrait<dtype::QuantizedS32>::name,
  222. {{"scale", PyFloat_FromDouble(param.scale)}});
  223. break;
  224. }
  225. default:
  226. mgb_throw(
  227. ConversionError, "unhandled parameterized DType %s",
  228. dtype.name());
  229. }
  230. return std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter>(type_descr);
  231. }
  232. PyArray_Descr* basic_descr = PyArray_DescrFromType(dtype_mgb2np_raw(dtype));
  233. mgb_assert(
  234. basic_descr != nullptr,
  235. "failed to convert expected dtype to numpy type descriptor");
  236. return std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter>(basic_descr);
  237. }
  238. DType dtype_np2mgb_raw(int npt) {
  239. switch (npt) {
  240. #define cb(_m, _n) \
  241. case _n: \
  242. return dtype::_m();
  243. FOREACH_NPY_DTYPE_PAIR(cb)
  244. #undef cb
  245. }
  246. #define cb(_m, _n) \
  247. if (_n == npt) \
  248. return dtype::_m();
  249. FOREACH_MGB_DTYPE_PAIR(cb)
  250. #undef cb
  251. PYTHON_GIL;
  252. std::string msg;
  253. auto py_obj = PyArray_TypeObjectFromType(npt);
  254. if (!py_obj) {
  255. msg = ssprintf("unknown numpy dtype enum %d", npt);
  256. } else {
  257. msg = ssprintf("unsupported numpy dtype %s", repr_pyobj(py_obj).c_str());
  258. }
  259. Py_DECREF(py_obj);
  260. throw ConversionError(msg);
  261. }
  262. DType dtype_np2mgb_descr(PyArray_Descr* descr) {
  263. PYTHON_GIL;
  264. auto handle_parameterized_dtype = [](PyObject* metadata) -> DType {
  265. mgb_assert(
  266. PyDict_Check(metadata),
  267. "Invalid parameterized DType metadata: should be a dict");
  268. PyObject* dtype_name_py = PyDict_GetItemString(metadata, "name");
  269. mgb_assert(
  270. PyUnicode_Check(dtype_name_py),
  271. "Invalid parameterized DType metadata: name should be a str");
  272. std::string dtype_name(PyUnicode_AsUTF8(dtype_name_py));
  273. if (dtype_name == "Quantized8Asymm") {
  274. PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
  275. PyObject* zero_point_py = PyDict_GetItemString(metadata, "zero_point");
  276. mgb_assert(
  277. scale_py && zero_point_py,
  278. "Invalid Quantized8Asymm metadata: missing scale or "
  279. "zero_point.");
  280. mgb_assert(
  281. PyFloat_Check(scale_py),
  282. "Invalid Quantized8Asymm metadata: scale should be float");
  283. mgb_assert(
  284. PyLong_Check(zero_point_py),
  285. "Invalid Quantized8Asymm metadata: zero_point should be "
  286. "integer");
  287. auto zero_point = PyLong_AS_LONG(zero_point_py);
  288. mgb_assert(
  289. zero_point >= 0 && zero_point < 256,
  290. "Invalid Quantized8Asymm metadata: zero_point should be "
  291. "in [0, 256)");
  292. return dtype::Quantized8Asymm(
  293. static_cast<float>(PyFloat_AS_DOUBLE(scale_py)),
  294. static_cast<uint8_t>(zero_point));
  295. }
  296. if (dtype_name == "Quantized4Asymm") {
  297. PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
  298. PyObject* zero_point_py = PyDict_GetItemString(metadata, "zero_point");
  299. mgb_assert(
  300. scale_py && zero_point_py,
  301. "Invalid Quantized4Asymm metadata: missing scale or "
  302. "zero_point.");
  303. mgb_assert(
  304. PyFloat_Check(scale_py),
  305. "Invalid Quantized4Asymm metadata: scale should be float");
  306. mgb_assert(
  307. PyLong_Check(zero_point_py),
  308. "Invalid Quantized4Asymm metadata: zero_point should be "
  309. "integer");
  310. auto zero_point = PyLong_AS_LONG(zero_point_py);
  311. mgb_assert(
  312. zero_point >= 0 && zero_point < 15,
  313. "Invalid Quantized4Asymm metadata: zero_point should be "
  314. "in [0, 15)");
  315. return dtype::Quantized4Asymm(
  316. static_cast<float>(PyFloat_AS_DOUBLE(scale_py)),
  317. static_cast<uint8_t>(zero_point));
  318. }
  319. if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" ||
  320. dtype_name == "QuantizedS4" || dtype_name == "QuantizedS1") {
  321. PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
  322. mgb_assert(scale_py, "Invalid metadata: missing scale");
  323. mgb_assert(
  324. PyFloat_Check(scale_py), "Invalid metadata: scale should be float");
  325. float scale = static_cast<float>(PyFloat_AS_DOUBLE(scale_py));
  326. if (dtype_name == "QuantizedS32") {
  327. return dtype::QuantizedS32(scale);
  328. } else if (dtype_name == "QuantizedS8") {
  329. return dtype::QuantizedS8(scale);
  330. } else if (dtype_name == "QuantizedS4") {
  331. return dtype::QuantizedS4(scale);
  332. } else if (dtype_name == "QuantizedS1") {
  333. return dtype::QuantizedS1(scale);
  334. }
  335. }
  336. throw ConversionError(
  337. ssprintf("Unknown parameterized DType: %s", dtype_name.c_str())
  338. .c_str());
  339. };
  340. PyObject* dtype_metadata;
  341. if (descr->metadata && PyDict_Check(descr->metadata) &&
  342. (dtype_metadata = PyDict_GetItemString(descr->metadata, "mgb_dtype"))) {
  343. return handle_parameterized_dtype(dtype_metadata);
  344. }
  345. return dtype_np2mgb_raw(descr->type_num);
  346. }
  347. HostTensorND lowbit_ndarray_to_host_tensor(
  348. CompNode comp_node, TensorLayout& layout, PyArrayObject* input) {
  349. auto src_ptr = reinterpret_cast<dt_byte*>(PyArray_DATA(input));
  350. if (!layout.ndim) {
  351. // numpy scalar
  352. mgb_assert(src_ptr, "can not convert from null numpy array");
  353. layout.init_contiguous_stride({1});
  354. } else {
  355. mgb_assert(
  356. layout.ndim && layout.ndim <= TensorShape::MAX_NDIM,
  357. "unsupported ndim %zu", layout.ndim);
  358. TensorLayout ly;
  359. ly.ndim = layout.ndim;
  360. for (size_t i = 0; i < layout.ndim; ++i) {
  361. ly.shape[i] = layout.shape[i] = PyArray_SHAPE(input)[i];
  362. ly.stride[i] = PyArray_STRIDE(input, i);
  363. mgb_assert(layout.shape[i], "zero shape not supported");
  364. }
  365. mgb_assert(ly.is_physical_contiguous());
  366. layout.init_contiguous_stride();
  367. }
  368. HostTensorND ret{comp_node, layout};
  369. if (layout.format.is_lowbit_aligned()) {
  370. mgb_assert(layout.is_contiguous());
  371. lowbit_memcpy_byte2aligned(ret.raw_ptr(), src_ptr, layout);
  372. } else {
  373. lowbit_memcpy_byte2compact(
  374. layout.dtype, ret.raw_ptr(), src_ptr, layout.total_nr_elems());
  375. }
  376. return ret;
  377. }
  378. /*!
  379. * \brief convert a python object to tensor and try to borrow memory if the
  380. * original object is a contiguous numpy array
  381. * \param dtype see np2tensor
  382. * \return the megbrain tensor, and whether memory is borrowed
  383. */
  384. std::pair<HostTensorND, bool> np2tensor_try_borrow(
  385. PyObject* obj, const npy::Meth& meth, DType dtype) {
  386. auto dest_cn = meth.dest_cn_;
  387. mgb_assert(dest_cn.valid());
  388. PYTHON_GIL;
  389. PyArray_Descr* expected_descr = nullptr;
  390. if (dtype.valid()) {
  391. // The reference to expected_descr will be stealed later.
  392. expected_descr = dtype_mgb2np_descr(dtype).release();
  393. }
  394. // make result from PyArrayObject; its reference may be stolen
  395. auto make_from_arr = [&](PyArrayObject* input, bool allow_borrow) {
  396. TensorLayout layout{{}, dtype_np2mgb_descr(PyArray_DESCR(input))};
  397. if (dtype.valid())
  398. mgb_assert(dtype == layout.dtype);
  399. layout.ndim = PyArray_NDIM(input);
  400. if (layout.dtype.is_low_bit()) {
  401. auto ret = lowbit_ndarray_to_host_tensor(dest_cn, layout, input);
  402. if (meth.dest_tensor_) {
  403. meth.dest_tensor_->copy_from(ret);
  404. ret = *meth.dest_tensor_;
  405. }
  406. return std::make_pair(ret, false);
  407. }
  408. auto data = reinterpret_cast<dt_byte*>(PyArray_DATA(input));
  409. if (!layout.ndim) {
  410. // numpy scalar
  411. mgb_assert(data, "can not convert from null numpy array");
  412. layout.init_contiguous_stride({1});
  413. } else {
  414. mgb_assert(
  415. layout.ndim && layout.ndim <= TensorShape::MAX_NDIM,
  416. "unsupported ndim %zu", layout.ndim);
  417. auto dsize = layout.dtype.size();
  418. bool is_empty = false;
  419. for (size_t i = 0; i < layout.ndim; ++i) {
  420. layout.shape[i] = PyArray_SHAPE(input)[i];
  421. layout.stride[i] = PyArray_STRIDE(input, i);
  422. if (!layout.shape[i]) {
  423. is_empty = true;
  424. }
  425. mgb_assert(
  426. layout.stride[i] % dsize == 0, "bad stride %zd",
  427. layout.stride[i]);
  428. layout.stride[i] /= dsize;
  429. }
  430. mgb_assert(is_empty || layout.is_contiguous());
  431. }
  432. if (!meth.dest_tensor_ && allow_borrow) {
  433. Py_INCREF(input);
  434. PyObjRefKeeper ref_obj_cvt{reinterpret_cast<PyObject*>(input)};
  435. HostTensorStorage storage;
  436. auto input_ptr = ref_obj_cvt.make_shared(data);
  437. storage.reset(dest_cn, layout.span().high_byte, input_ptr);
  438. HostTensorND ret;
  439. ret.reset(storage, layout);
  440. return std::make_pair(ret, true);
  441. } else {
  442. auto storage = HostTensorStorage(dest_cn);
  443. storage.ensure_size(layout.span().dist_byte());
  444. memcpy(storage.ptr(), data, layout.span().dist_byte());
  445. HostTensorND ret{dest_cn, layout.dtype};
  446. if (meth.dest_tensor_) {
  447. meth.dest_tensor_->reset(storage, layout);
  448. return std::make_pair(*meth.dest_tensor_, false);
  449. } else {
  450. HostTensorND ret;
  451. ret.reset(storage, layout);
  452. return std::make_pair(ret, false);
  453. }
  454. }
  455. };
  456. PyArrayObject* obj_as_arr = nullptr;
  457. do {
  458. // check contiguous and dtype, and borrow mem if ok
  459. if (!PyArray_Check(obj))
  460. break;
  461. obj_as_arr = reinterpret_cast<PyArrayObject*>(obj);
  462. int typenum = PyArray_DTYPE(obj_as_arr)->type_num;
  463. // We have to check dtype.valid() and typenum first to avoid
  464. // accidentally trigger ConversionError on incompatible dtypes which can
  465. // be automatically converted into comptaible ones (e.g. float64).
  466. if (dtype.valid() && (expected_descr->type_num != typenum ||
  467. dtype_np2mgb_descr(PyArray_DTYPE(obj_as_arr)) != dtype))
  468. break;
  469. if (typenum != to_mgb_supported_dtype_raw(typenum)) {
  470. mgb_assert(!dtype.valid() && expected_descr == nullptr);
  471. expected_descr = PyArray_DescrFromType(to_mgb_supported_dtype_raw(typenum));
  472. break;
  473. }
  474. if (PyArray_ISCARRAY_RO(obj_as_arr)) {
  475. return make_from_arr(obj_as_arr, true);
  476. }
  477. } while (0);
  478. constexpr auto NP_FLAGS = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_FORCECAST;
  479. PyObject* obj_cvt;
  480. if (obj_as_arr) {
  481. obj_cvt = PyArray_FromArray(obj_as_arr, expected_descr, NP_FLAGS);
  482. } else {
  483. obj_cvt = PyArray_FromAny(obj, expected_descr, 0, 0, NP_FLAGS, nullptr);
  484. }
  485. if (obj_cvt) {
  486. // convert to mgb supported dtype
  487. auto arr = reinterpret_cast<PyArrayObject*>(obj_cvt);
  488. int dt0 = PyArray_TYPE(arr), dt1 = to_mgb_supported_dtype_raw(dt0);
  489. if (dt0 != dt1) {
  490. mgb_assert(expected_descr == nullptr);
  491. expected_descr = PyArray_DescrFromType(dt1);
  492. mgb_assert(expected_descr);
  493. auto obj_cvt_new =
  494. PyArray_FromAny(obj_cvt, expected_descr, 0, 0, NP_FLAGS, nullptr);
  495. Py_DECREF(obj_cvt);
  496. obj_cvt = obj_cvt_new;
  497. }
  498. }
  499. if (!obj_cvt) {
  500. if (PyErr_Occurred()) {
  501. PyExceptionForward::throw_();
  502. }
  503. throw ConversionError(ssprintf(
  504. "can not convert to numpy array from %s", repr_pyobj(obj).c_str()));
  505. }
  506. auto ret = make_from_arr(reinterpret_cast<PyArrayObject*>(obj_cvt), false);
  507. Py_DECREF(obj_cvt);
  508. return ret;
  509. }
  510. //! hold a reference to HostTensorND
  511. class HostTensorNDRefHolder final : public NonCopyableObj {
  512. HostTensorND m_val;
  513. static MemPool<HostTensorNDRefHolder> sm_mem_pool;
  514. friend class MemPool<HostTensorNDRefHolder>;
  515. HostTensorNDRefHolder(const HostTensorND& v) : m_val{v} {}
  516. public:
  517. static HostTensorNDRefHolder* alloc(const HostTensorND& v) {
  518. return sm_mem_pool.alloc(v);
  519. }
  520. static void free(HostTensorNDRefHolder* p) { return sm_mem_pool.free(p); }
  521. };
  522. MemPool<HostTensorNDRefHolder> HostTensorNDRefHolder::sm_mem_pool;
  523. void ndarray_shared_from_tensor_py_capsule_dtor(PyObject* cap) {
  524. auto ptr = PyCapsule_GetPointer(cap, "HostTensorND");
  525. mgb_assert(ptr, "not a PyCapsule: %s", repr_pyobj(cap).c_str());
  526. HostTensorNDRefHolder::free(static_cast<HostTensorNDRefHolder*>(ptr));
  527. }
  528. PyObject* ndarray_from_tensor(const HostTensorND& val, ShareType share_type) {
  529. if (!val.layout().is_contiguous() && !val.shape().is_empty()) {
  530. mgb_assert(share_type != ShareType::MUST_SHARE);
  531. HostTensorND contig;
  532. contig.copy_from(val);
  533. return ndarray_from_tensor(contig, ShareType::TRY_SHARE);
  534. }
  535. PYTHON_GIL;
  536. npy_intp dims[TensorLayout::MAX_NDIM];
  537. for (size_t i = 0; i < val.layout().ndim; ++i)
  538. dims[i] = val.shape()[i];
  539. PyObject* ret = nullptr;
  540. auto alloc_new_ret = [&]() {
  541. mgb_assert(!ret);
  542. ret = PyArray_NewFromDescr(
  543. &PyArray_Type, dtype_mgb2np_descr(val.dtype()).release(),
  544. val.layout().ndim, dims, nullptr, nullptr, 0, nullptr);
  545. mgb_assert(ret, "failed to allocate array");
  546. mgb_assert(PyArray_Check(ret));
  547. return PyArray_DATA(reinterpret_cast<PyArrayObject*>(ret));
  548. };
  549. if (val.dtype().is_low_bit()) {
  550. mgb_assert(
  551. share_type != ShareType::MUST_SHARE,
  552. "can not share memory for lowbit dtype");
  553. const auto& layout = val.layout();
  554. if (layout.format.is_lowbit_aligned()) {
  555. lowbit_memcpy_aligned2byte(alloc_new_ret(), val.raw_ptr(), val.layout());
  556. } else {
  557. lowbit_memcpy_compact2byte(
  558. val.dtype(), alloc_new_ret(), val.raw_ptr(),
  559. val.layout().total_nr_elems());
  560. }
  561. } else if (share_type == ShareType::MUST_UNSHARE) {
  562. memcpy(alloc_new_ret(), val.raw_ptr(), val.layout().span().dist_byte());
  563. } else {
  564. // share data
  565. ret = PyArray_NewFromDescr(
  566. &PyArray_Type, dtype_mgb2np_descr(val.dtype()).release(),
  567. val.layout().ndim, dims, nullptr, const_cast<dt_byte*>(val.raw_ptr()),
  568. 0, nullptr);
  569. mgb_assert(ret, "failed to alloc ndarray");
  570. auto capsule = PyCapsule_New(
  571. HostTensorNDRefHolder::alloc(val), "HostTensorND",
  572. ndarray_shared_from_tensor_py_capsule_dtor);
  573. mgb_assert(capsule, "failed to create PyCapsule");
  574. auto err =
  575. PyArray_SetBaseObject(reinterpret_cast<PyArrayObject*>(ret), capsule);
  576. mgb_assert(!err);
  577. }
  578. return ret;
  579. }
  580. HostTensorND np2tensor(PyObject* obj, const Meth& meth, DType dtype) {
  581. auto ret_full = np2tensor_try_borrow(obj, meth, dtype);
  582. if (meth.must_borrow_) {
  583. mgb_assert(
  584. ret_full.second,
  585. "can not borrow from numpy array as contig array with dtype "
  586. "%s; src=%s",
  587. dtype.name(), repr_pyobj(obj).c_str());
  588. }
  589. return ret_full.first;
  590. }
  591. PyObject* dtype_mgb2np(mgb::DType dtype) {
  592. PYTHON_GIL;
  593. // According to
  594. // https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_TypeObjectFromType
  595. // the following is equivalent to PyArray_TypeObjectFromType for built-in
  596. // types.
  597. if (!dtype.valid()) {
  598. Py_XINCREF(Py_None);
  599. return Py_None;
  600. }
  601. auto descr = dtype_mgb2np_descr(dtype);
  602. if (descr == nullptr) {
  603. Py_XINCREF(Py_None);
  604. return Py_None;
  605. }
  606. if (dtype.has_param()) {
  607. return reinterpret_cast<PyObject*>(descr.release());
  608. }
  609. PyObject* typeobj = reinterpret_cast<PyObject*>(descr->typeobj);
  610. Py_XINCREF(typeobj);
  611. return typeobj;
  612. }
  613. mgb::DType dtype_np2mgb(PyObject* obj) {
  614. mgb_assert(obj && obj != Py_None, "can not convert null PyObject to numpy dtype");
  615. // see
  616. // http://stackoverflow.com/questions/8477122/numpy-c-api-convert-type-object-to-type-number
  617. PYTHON_GIL;
  618. PyArray_Descr* dtype;
  619. if (!PyArray_DescrConverter(obj, &dtype)) {
  620. throw ConversionError(ssprintf(
  621. "can not convert to np.dtype from %s", repr_pyobj(obj).c_str()));
  622. }
  623. mgb::DType result = dtype_np2mgb_descr(dtype);
  624. Py_DECREF(dtype);
  625. return result;
  626. }
  627. PyObject* to_mgb_supported_dtype(PyObject* dtype) {
  628. PYTHON_GIL;
  629. PyArray_Descr* descr;
  630. if (!PyArray_DescrConverter(dtype, &descr)) {
  631. throw ConversionError(ssprintf(
  632. "can not convert to np.dtype from %s", repr_pyobj(dtype).c_str()));
  633. }
  634. mgb_assert(
  635. !descr->metadata,
  636. "unexpected metadata in dtype: "
  637. "dtype_obj=%s metadata=%s",
  638. repr_pyobj(dtype).c_str(), repr_pyobj(descr->metadata).c_str());
  639. int type_num = to_mgb_supported_dtype_raw(descr->type_num);
  640. return PyArray_TypeObjectFromType(type_num);
  641. }
  642. TensorShape vec2shape(const std::vector<size_t>& vec) {
  643. TensorShape shape;
  644. mgb_assert(
  645. vec.size() <= TensorShape::MAX_NDIM, "dim too large: %zd (max %zd)",
  646. vec.size(), TensorShape::MAX_NDIM);
  647. shape.ndim = vec.size();
  648. for (size_t i = 0; i < vec.size(); i++) {
  649. if (!vec[i]) {
  650. shape.ndim = 0;
  651. break;
  652. }
  653. shape[i] = vec[i];
  654. }
  655. mgb_assert(shape.ndim, "shape should not be empty");
  656. return shape;
  657. }
  658. } // namespace npy