You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.cpp 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705
  1. #include "./helper.h"
  2. #include <pybind11/eval.h>
  3. #include "megbrain/graph/exc_extra_info.h"
  4. #include "megbrain/graph/event.h"
  5. #include "megbrain/graph/cg.h"
  6. #include "megbrain/tensor.h"
  7. #include "megbrain/utils/mempool.h"
  8. #include "./numpy_dtypes.h"
  9. /*
  10. * demangle typeid, see
  11. * http://stackoverflow.com/questions/281818/unmangling-the-result-of-stdtype-infoname
  12. */
  13. #ifdef __GNUG__
  14. #include <cstdlib>
  15. #include <memory>
  16. #include <cxxabi.h>
  17. namespace py = pybind11;
  18. PyTaskDipatcher py_task_q = {};
  19. py::module submodule(py::module parent, const char* name, const char* doc) {
  20. auto m = parent.def_submodule(name, doc);
  21. m.attr("__package__") = parent.attr("__name__");
  22. m.attr("__builtins__") = py::module::import("builtins");
  23. return m;
  24. }
  25. py::module rel_import(py::str name, py::module m, int level) {
  26. py::object import = py::module::import("builtins").attr("__import__");
  27. return import(name, m.attr("__dict__"), py::arg("level")=level);
  28. }
  29. namespace {
  30. std::string demangle_typeid(const char* name) {
  31. int status = -4; // some arbitrary value to eliminate the compiler warning
  32. // enable c++11 by passing the flag -std=c++11 to g++
  33. std::unique_ptr<char, void(*)(void*)> res {
  34. abi::__cxa_demangle(name, nullptr, nullptr, &status),
  35. std::free
  36. };
  37. return (status==0) ? res.get() : name ;
  38. }
  39. }
  40. #else
  41. namespace {
  42. // does nothing if not g++
  43. std::string demangle_typeid(const char* name) {
  44. return name;
  45. }
  46. }
  47. #endif
  48. using namespace mgb;
  49. using namespace cg;
  50. namespace {
  51. std::string repr_pyobj(PyObject *obj) {
  52. if (!obj)
  53. return "<null PyObject>";
  54. PYTHON_GIL;
  55. auto str = PyObject_Repr(obj);
  56. if (!str)
  57. return ssprintf("<PyObject at %p (repr failed)>", obj);
  58. std::string ret{PyUnicode_AsUTF8(str)};
  59. Py_DECREF(str);
  60. return ret;
  61. }
  62. template<typename T>
  63. std::string typeid_name(const T &t) {
  64. return demangle_typeid(typeid(t).name());
  65. }
  66. } // anonymous namespace
  67. /* ============== PyExceptionForward ============== */
  68. PyExceptionForward::~PyExceptionForward() {
  69. PYTHON_GIL;
  70. PyObjRefKeeper::deleter(m_type);
  71. PyObjRefKeeper::deleter(m_value);
  72. PyObjRefKeeper::deleter(m_traceback);
  73. }
  74. void PyExceptionForward::restore() {
  75. PyErr_Restore(m_type, m_value, m_traceback);
  76. m_type = m_value = m_traceback = nullptr;
  77. }
  78. void PyExceptionForward::throw_() {
  79. PyObject *etype, *obj, *trace;
  80. PyErr_Fetch(&etype, &obj, &trace);
  81. PyErr_NormalizeException(&etype, &obj, &trace);
  82. std::string msg{"python exception"};
  83. bool succ = false;
  84. if (etype && obj && trace) {
  85. auto run = [&]() {
  86. #define DEF(name, expr) \
  87. PyObjRefKeeper name{expr}; \
  88. if (!name.get()) \
  89. return
  90. DEF(mod, PyImport_ImportModule("traceback"));
  91. DEF(result, PyObject_CallMethod(mod.get(), "format_exception",
  92. "(OOO)", etype, obj, trace));
  93. if (!PyList_Check(result.get()))
  94. return;
  95. auto size = PyList_Size(result.get());
  96. msg.append(":\n");
  97. for (Py_ssize_t i = 0; i < size; ++i) {
  98. msg.append(" ");
  99. msg.append(PyUnicode_AsUTF8(PyList_GetItem(result.get(), i)));
  100. }
  101. msg.pop_back(); // remove last \n
  102. succ = true;
  103. #undef DEF
  104. };
  105. run();
  106. }
  107. if (!succ) {
  108. PyObject* obj_str_py;
  109. if (obj && (obj_str_py = PyObject_Repr(obj))) {
  110. msg.append(" with message ");
  111. msg.append(PyUnicode_AsUTF8(obj_str_py));
  112. Py_DECREF(obj_str_py);
  113. } else {
  114. msg.append(" with unknown message");
  115. }
  116. }
  117. // throwing exception may cause abort due to unknown reasons; so we first
  118. // log the message
  119. mgb_log_error("caught exception from python callback: %s", msg.c_str());
  120. fflush(stdout);
  121. fflush(stderr);
  122. throw PyExceptionForward{etype, obj, trace, msg};
  123. }
  124. /* ============== namespace npy ============== */
  125. namespace {
  126. int to_mgb_supported_dtype_raw(int dtype) {
  127. if (dtype == NPY_INT64)
  128. return NPY_INT32;
  129. if (dtype == NPY_FLOAT64)
  130. return NPY_FLOAT32;
  131. return dtype;
  132. }
  //! invoke cb(MgbDTypeName, NPY_TYPE_ENUM) for every plain megbrain dtype
  //! that has a direct numpy counterpart
  #define FOREACH_NPY_DTYPE_PAIR(cb) \
      cb(Uint8, NPY_UINT8)           \
      cb(Int8, NPY_INT8)             \
      cb(Int16, NPY_INT16)           \
      cb(Int32, NPY_INT32)           \
      cb(Float16, NPY_FLOAT16)       \
      cb(Float32, NPY_FLOAT32)       \
      cb(Bool, NPY_BOOL)

  //! the pairs above followed by those of FOREACH_MGB_DTYPE_PAIR (defined
  //! elsewhere, presumably in numpy_dtypes.h)
  #define FOREACH_NPY_MGB_DTYPE_PAIR(cb) \
      FOREACH_NPY_DTYPE_PAIR(cb)         \
      FOREACH_MGB_DTYPE_PAIR(cb)
  144. //! convert megbrain dtype to numpy dtype
  145. int dtype_mgb2np_raw(DType dtype) {
  146. mgb_assert(dtype.valid(), "attempt to convert from invalid dtype");
  147. switch (dtype.enumv()) {
  148. #define cb(_m, _n) \
  149. case DTypeEnum::_m: \
  150. return _n;
  151. FOREACH_NPY_MGB_DTYPE_PAIR(cb)
  152. #undef cb
  153. default:
  154. break;
  155. }
  156. throw ConversionError(ssprintf(
  157. "can not convert dtype %s to numpy dtype", dtype.name()));
  158. }
  159. struct PyArrayDescrDeleter {
  160. void operator()(PyArray_Descr* obj) {
  161. Py_XDECREF(obj);
  162. }
  163. };
  164. //! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new
  165. //! reference to the descriptor.
  166. std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(
  167. DType dtype) {
  168. PYTHON_GIL;
  169. mgb_assert(dtype.valid(), "attempt to convert from invalid dtype");
  170. auto build_mgb_dtype_dict =
  171. [](const char* name,
  172. const std::vector<std::pair<const char*, PyObject*>>& data) {
  173. PyObject* metadata = PyDict_New();
  174. PyObject* mgb_dtype_metadata = PyDict_New();
  175. PyDict_SetItemString(mgb_dtype_metadata, "name",
  176. PyUnicode_FromString(name));
  177. for (const auto& d : data) {
  178. PyDict_SetItemString(mgb_dtype_metadata, d.first, d.second);
  179. }
  180. PyDict_SetItemString(metadata, "mgb_dtype", mgb_dtype_metadata);
  181. return metadata;
  182. };
  183. if (dtype.has_param()) {
  184. PyArray_Descr* type_descr;
  185. switch (dtype.enumv()) {
  186. case DTypeEnum::Quantized4Asymm: {
  187. auto& param = dtype.param<dtype::Quantized4Asymm>();
  188. type_descr = PyArray_DescrNewFromType(NPY_UINT8);
  189. type_descr->metadata = build_mgb_dtype_dict(
  190. DTypeTrait<dtype::Quantized4Asymm>::name,
  191. {{"scale", PyFloat_FromDouble(param.scale)},
  192. {"zero_point", PyLong_FromLong(param.zero_point)}});
  193. break;
  194. }
  195. case DTypeEnum::QuantizedS4: {
  196. auto& param = dtype.param<dtype::QuantizedS4>();
  197. type_descr = PyArray_DescrNewFromType(NPY_INT8);
  198. type_descr->metadata = build_mgb_dtype_dict(
  199. DTypeTrait<dtype::QuantizedS4>::name,
  200. {{"scale", PyFloat_FromDouble(param.scale)}});
  201. break;
  202. }
  203. case DTypeEnum::Quantized8Asymm: {
  204. auto& param = dtype.param<dtype::Quantized8Asymm>();
  205. type_descr = PyArray_DescrNewFromType(NPY_UINT8);
  206. type_descr->metadata = build_mgb_dtype_dict(
  207. DTypeTrait<dtype::Quantized8Asymm>::name,
  208. {{"scale", PyFloat_FromDouble(param.scale)},
  209. {"zero_point", PyLong_FromLong(param.zero_point)}});
  210. break;
  211. }
  212. case DTypeEnum::QuantizedS8: {
  213. auto& param = dtype.param<dtype::QuantizedS8>();
  214. type_descr = PyArray_DescrNewFromType(NPY_INT8);
  215. type_descr->metadata = build_mgb_dtype_dict(
  216. DTypeTrait<dtype::QuantizedS8>::name,
  217. {{"scale", PyFloat_FromDouble(param.scale)}});
  218. break;
  219. }
  220. case DTypeEnum::QuantizedS32: {
  221. auto& param = dtype.param<dtype::QuantizedS32>();
  222. type_descr = PyArray_DescrNewFromType(NPY_INT32);
  223. type_descr->metadata = build_mgb_dtype_dict(
  224. DTypeTrait<dtype::QuantizedS32>::name,
  225. {{"scale", PyFloat_FromDouble(param.scale)}});
  226. break;
  227. }
  228. default:
  229. mgb_throw(ConversionError, "unhandled parameterized DType %s",
  230. dtype.name());
  231. }
  232. return std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter>(type_descr);
  233. }
  234. PyArray_Descr* basic_descr = PyArray_DescrFromType(dtype_mgb2np_raw(dtype));
  235. mgb_assert(basic_descr != nullptr,
  236. "failed to convert expected dtype to numpy type descriptor");
  237. return std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter>(basic_descr);
  238. }
  239. DType dtype_np2mgb_raw(int npt) {
  240. switch (npt) {
  241. #define cb(_m, _n) \
  242. case _n: \
  243. return dtype::_m();
  244. FOREACH_NPY_DTYPE_PAIR(cb)
  245. #undef cb
  246. }
  247. #define cb(_m, _n) \
  248. if (_n == npt) return dtype::_m();
  249. FOREACH_MGB_DTYPE_PAIR(cb)
  250. #undef cb
  251. PYTHON_GIL;
  252. std::string msg;
  253. auto py_obj = PyArray_TypeObjectFromType(npt);
  254. if (!py_obj) {
  255. msg = ssprintf("unknown numpy dtype enum %d", npt);
  256. } else {
  257. msg = ssprintf("unsupported numpy dtype %s",
  258. repr_pyobj(py_obj).c_str());
  259. }
  260. Py_DECREF(py_obj);
  261. throw ConversionError(msg);
  262. }
  263. DType dtype_np2mgb_descr(PyArray_Descr* descr) {
  264. PYTHON_GIL;
  265. auto handle_parameterized_dtype = [](PyObject* metadata) -> DType {
  266. mgb_assert(PyDict_Check(metadata),
  267. "Invalid parameterized DType metadata: should be a dict");
  268. PyObject* dtype_name_py = PyDict_GetItemString(metadata, "name");
  269. mgb_assert(
  270. PyUnicode_Check(dtype_name_py),
  271. "Invalid parameterized DType metadata: name should be a str");
  272. std::string dtype_name(PyUnicode_AsUTF8(dtype_name_py));
  273. if (dtype_name == "Quantized8Asymm") {
  274. PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
  275. PyObject* zero_point_py =
  276. PyDict_GetItemString(metadata, "zero_point");
  277. mgb_assert(scale_py && zero_point_py,
  278. "Invalid Quantized8Asymm metadata: missing scale or "
  279. "zero_point.");
  280. mgb_assert(
  281. PyFloat_Check(scale_py),
  282. "Invalid Quantized8Asymm metadata: scale should be float");
  283. mgb_assert(PyLong_Check(zero_point_py),
  284. "Invalid Quantized8Asymm metadata: zero_point should be "
  285. "integer");
  286. auto zero_point = PyLong_AS_LONG(zero_point_py);
  287. mgb_assert(zero_point >= 0 && zero_point < 256,
  288. "Invalid Quantized8Asymm metadata: zero_point should be "
  289. "in [0, 256)");
  290. return dtype::Quantized8Asymm(
  291. static_cast<float>(PyFloat_AS_DOUBLE(scale_py)),
  292. static_cast<uint8_t>(zero_point));
  293. }
  294. if (dtype_name == "Quantized4Asymm") {
  295. PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
  296. PyObject* zero_point_py =
  297. PyDict_GetItemString(metadata, "zero_point");
  298. mgb_assert(scale_py && zero_point_py,
  299. "Invalid Quantized4Asymm metadata: missing scale or "
  300. "zero_point.");
  301. mgb_assert(
  302. PyFloat_Check(scale_py),
  303. "Invalid Quantized4Asymm metadata: scale should be float");
  304. mgb_assert(PyLong_Check(zero_point_py),
  305. "Invalid Quantized4Asymm metadata: zero_point should be "
  306. "integer");
  307. auto zero_point = PyLong_AS_LONG(zero_point_py);
  308. mgb_assert(zero_point >= 0 && zero_point < 15,
  309. "Invalid Quantized4Asymm metadata: zero_point should be "
  310. "in [0, 15)");
  311. return dtype::Quantized4Asymm(
  312. static_cast<float>(PyFloat_AS_DOUBLE(scale_py)),
  313. static_cast<uint8_t>(zero_point));
  314. }
  315. if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" ||
  316. dtype_name == "QuantizedS4") {
  317. PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
  318. mgb_assert(scale_py, "Invalid metadata: missing scale");
  319. mgb_assert(PyFloat_Check(scale_py),
  320. "Invalid metadata: scale should be float");
  321. float scale = static_cast<float>(PyFloat_AS_DOUBLE(scale_py));
  322. if (dtype_name == "QuantizedS32") {
  323. return dtype::QuantizedS32(scale);
  324. } else if (dtype_name == "QuantizedS8"){
  325. return dtype::QuantizedS8(scale);
  326. } else {
  327. return dtype::QuantizedS4(scale);
  328. }
  329. }
  330. throw ConversionError(
  331. ssprintf("Unknown parameterized DType: %s", dtype_name.c_str())
  332. .c_str());
  333. };
  334. PyObject* dtype_metadata;
  335. if (descr->metadata && PyDict_Check(descr->metadata) &&
  336. (dtype_metadata = PyDict_GetItemString(descr->metadata, "mgb_dtype"))) {
  337. return handle_parameterized_dtype(dtype_metadata);
  338. }
  339. return dtype_np2mgb_raw(descr->type_num);
  340. }
  341. HostTensorND lowbit_ndarray_to_host_tensor(
  342. CompNode comp_node, TensorLayout &layout, PyArrayObject *input) {
  343. auto src_ptr = reinterpret_cast<dt_byte*>(PyArray_DATA(input));
  344. if (!layout.ndim) {
  345. // numpy scalar
  346. mgb_assert(src_ptr, "can not convert from null numpy array");
  347. layout.init_contiguous_stride({1});
  348. } else {
  349. mgb_assert(layout.ndim && layout.ndim <= TensorShape::MAX_NDIM,
  350. "unsupported ndim %zu", layout.ndim);
  351. for (size_t i = 0; i < layout.ndim; ++ i) {
  352. layout.shape[i] = PyArray_SHAPE(input)[i];
  353. layout.stride[i] = PyArray_STRIDE(input, i);
  354. mgb_assert(layout.shape[i], "zero shape not supported");
  355. }
  356. mgb_assert(layout.is_contiguous());
  357. }
  358. HostTensorND ret{comp_node, layout};
  359. lowbit_memcpy_byte2compact(layout.dtype, ret.raw_ptr(), src_ptr,
  360. layout.total_nr_elems());
  361. return ret;
  362. }
  363. /*!
  364. * \brief convert a python object to tensor and try to borrow memory if the
  365. * original object is a contiguous numpy array
  366. * \param dtype see np2tensor
  367. * \return the megbrain tensor, and whether memory is borrowed
  368. */
  369. std::pair<HostTensorND, bool> np2tensor_try_borrow(
  370. PyObject *obj, const npy::Meth& meth, DType dtype) {
  371. auto dest_cn = meth.dest_cn_;
  372. mgb_assert(dest_cn.valid());
  373. PYTHON_GIL;
  374. PyArray_Descr* expected_descr = nullptr;
  375. if (dtype.valid()) {
  376. // The reference to expected_descr will be stealed later.
  377. expected_descr = dtype_mgb2np_descr(dtype).release();
  378. }
  379. // make result from PyArrayObject; its reference may be stolen
  380. auto make_from_arr = [&](PyArrayObject *input, bool allow_borrow) {
  381. TensorLayout layout;
  382. layout.dtype = dtype_np2mgb_descr(PyArray_DESCR(input));
  383. if (dtype.valid())
  384. mgb_assert(dtype == layout.dtype);
  385. layout.ndim = PyArray_NDIM(input);
  386. if (layout.dtype.is_low_bit()) {
  387. auto ret = lowbit_ndarray_to_host_tensor(dest_cn, layout, input);
  388. if (meth.dest_tensor_) {
  389. meth.dest_tensor_->copy_from(ret);
  390. ret = *meth.dest_tensor_;
  391. }
  392. return std::make_pair(ret, false);
  393. }
  394. auto data = reinterpret_cast<dt_byte*>(PyArray_DATA(input));
  395. if (!layout.ndim) {
  396. // numpy scalar
  397. mgb_assert(data, "can not convert from null numpy array");
  398. layout.init_contiguous_stride({1});
  399. } else {
  400. mgb_assert(layout.ndim && layout.ndim <= TensorShape::MAX_NDIM,
  401. "unsupported ndim %zu", layout.ndim);
  402. auto dsize = layout.dtype.size();
  403. bool is_empty = false;
  404. for (size_t i = 0; i < layout.ndim; ++ i) {
  405. layout.shape[i] = PyArray_SHAPE(input)[i];
  406. layout.stride[i] = PyArray_STRIDE(input, i);
  407. if (!layout.shape[i]) {
  408. is_empty = true;
  409. }
  410. mgb_assert(layout.stride[i] % dsize == 0,
  411. "bad stride %zd", layout.stride[i]);
  412. layout.stride[i] /= dsize;
  413. }
  414. mgb_assert(is_empty || layout.is_contiguous());
  415. }
  416. if (!meth.dest_tensor_ && allow_borrow) {
  417. Py_INCREF(input);
  418. PyObjRefKeeper ref_obj_cvt{reinterpret_cast<PyObject*>(input)};
  419. HostTensorStorage storage;
  420. auto input_ptr = ref_obj_cvt.make_shared(data);
  421. storage.reset(dest_cn, layout.span().high_byte, input_ptr);
  422. HostTensorND ret;
  423. ret.reset(storage, layout);
  424. return std::make_pair(ret, true);
  425. } else {
  426. auto storage = HostTensorStorage(dest_cn);
  427. storage.ensure_size(layout.span().dist_byte());
  428. memcpy(storage.ptr(), data, layout.span().dist_byte());
  429. HostTensorND ret{dest_cn, layout.dtype};
  430. if (meth.dest_tensor_) {
  431. meth.dest_tensor_->reset(storage, layout);
  432. return std::make_pair(*meth.dest_tensor_, false);
  433. } else {
  434. HostTensorND ret;
  435. ret.reset(storage, layout);
  436. return std::make_pair(ret, false);
  437. }
  438. }
  439. };
  440. PyArrayObject *obj_as_arr = nullptr;
  441. do {
  442. // check contiguous and dtype, and borrow mem if ok
  443. if (!PyArray_Check(obj))
  444. break;
  445. obj_as_arr = reinterpret_cast<PyArrayObject*>(obj);
  446. int typenum = PyArray_DTYPE(obj_as_arr)->type_num;
  447. // We have to check dtype.valid() and typenum first to avoid
  448. // accidentally trigger ConversionError on incompatible dtypes which can
  449. // be automatically converted into comptaible ones (e.g. float64).
  450. if (dtype.valid() &&
  451. (expected_descr->type_num != typenum ||
  452. dtype_np2mgb_descr(PyArray_DTYPE(obj_as_arr)) != dtype))
  453. break;
  454. if (typenum != to_mgb_supported_dtype_raw(typenum)) {
  455. mgb_assert(!dtype.valid() && expected_descr == nullptr);
  456. expected_descr =
  457. PyArray_DescrFromType(to_mgb_supported_dtype_raw(typenum));
  458. break;
  459. }
  460. if (PyArray_ISCARRAY_RO(obj_as_arr)) {
  461. return make_from_arr(obj_as_arr, true);
  462. }
  463. } while(0);
  464. constexpr auto NP_FLAGS = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_FORCECAST;
  465. PyObject *obj_cvt;
  466. if (obj_as_arr) {
  467. obj_cvt = PyArray_FromArray(obj_as_arr, expected_descr, NP_FLAGS);
  468. } else {
  469. obj_cvt = PyArray_FromAny(obj, expected_descr, 0, 0, NP_FLAGS, nullptr);
  470. }
  471. if (obj_cvt) {
  472. // convert to mgb supported dtype
  473. auto arr = reinterpret_cast<PyArrayObject*>(obj_cvt);
  474. int dt0 = PyArray_TYPE(arr), dt1 = to_mgb_supported_dtype_raw(dt0);
  475. if (dt0 != dt1) {
  476. mgb_assert(expected_descr == nullptr);
  477. expected_descr = PyArray_DescrFromType(dt1);
  478. mgb_assert(expected_descr);
  479. auto obj_cvt_new = PyArray_FromAny(
  480. obj_cvt, expected_descr, 0, 0, NP_FLAGS, nullptr);
  481. Py_DECREF(obj_cvt);
  482. obj_cvt = obj_cvt_new;
  483. }
  484. }
  485. if (!obj_cvt) {
  486. if (PyErr_Occurred()) {
  487. PyExceptionForward::throw_();
  488. }
  489. throw ConversionError(ssprintf("can not convert to numpy array from %s",
  490. repr_pyobj(obj).c_str()));
  491. }
  492. auto ret = make_from_arr(reinterpret_cast<PyArrayObject*>(obj_cvt), false);
  493. Py_DECREF(obj_cvt);
  494. return ret;
  495. }
  496. //! hold a reference to HostTensorND
  497. class HostTensorNDRefHolder final: public NonCopyableObj {
  498. HostTensorND m_val;
  499. static MemPool<HostTensorNDRefHolder> sm_mem_pool;
  500. friend class MemPool<HostTensorNDRefHolder>;
  501. HostTensorNDRefHolder(const HostTensorND &v):
  502. m_val{v}
  503. {
  504. }
  505. public:
  506. static HostTensorNDRefHolder* alloc(const HostTensorND &v) {
  507. return sm_mem_pool.alloc(v);
  508. }
  509. static void free(HostTensorNDRefHolder *p) {
  510. return sm_mem_pool.free(p);
  511. }
  512. };
  513. MemPool<HostTensorNDRefHolder> HostTensorNDRefHolder::sm_mem_pool;
  514. void ndarray_shared_from_tensor_py_capsule_dtor(PyObject *cap) {
  515. auto ptr = PyCapsule_GetPointer(cap, "HostTensorND");
  516. mgb_assert(ptr, "not a PyCapsule: %s", repr_pyobj(cap).c_str());
  517. HostTensorNDRefHolder::free(static_cast<HostTensorNDRefHolder*>(ptr));
  518. }
  519. } // anonymous namespace
  520. PyObject* npy::ndarray_from_tensor(
  521. const HostTensorND &val, ShareType share_type) {
  522. if (!val.layout().is_contiguous() && !val.shape().is_empty()) {
  523. mgb_assert(share_type != ShareType::MUST_SHARE);
  524. HostTensorND contig;
  525. contig.copy_from(val);
  526. return ndarray_from_tensor(contig, ShareType::TRY_SHARE);
  527. }
  528. PYTHON_GIL;
  529. npy_intp dims[TensorLayout::MAX_NDIM];
  530. for (size_t i = 0; i < val.layout().ndim; ++ i)
  531. dims[i] = val.shape()[i];
  532. PyObject* ret = nullptr;
  533. auto alloc_new_ret = [&]() {
  534. mgb_assert(!ret);
  535. ret = PyArray_NewFromDescr(
  536. &PyArray_Type, dtype_mgb2np_descr(val.dtype()).release(),
  537. val.layout().ndim, dims, nullptr, nullptr, 0, nullptr);
  538. mgb_assert(ret, "failed to allocate array");
  539. mgb_assert(PyArray_Check(ret));
  540. return PyArray_DATA(reinterpret_cast<PyArrayObject*>(ret));
  541. };
  542. if (val.dtype().is_low_bit()) {
  543. mgb_assert(share_type != ShareType::MUST_SHARE,
  544. "can not share memory for lowbit dtype");
  545. lowbit_memcpy_compact2byte(val.dtype(), alloc_new_ret(), val.raw_ptr(),
  546. val.layout().total_nr_elems());
  547. } else if (share_type == ShareType::MUST_UNSHARE) {
  548. memcpy(alloc_new_ret(), val.raw_ptr(), val.layout().span().dist_byte());
  549. } else {
  550. // share data
  551. ret = PyArray_NewFromDescr(
  552. &PyArray_Type, dtype_mgb2np_descr(val.dtype()).release(),
  553. val.layout().ndim, dims, nullptr,
  554. const_cast<dt_byte*>(val.raw_ptr()), 0, nullptr);
  555. mgb_assert(ret, "failed to alloc ndarray");
  556. auto capsule = PyCapsule_New(HostTensorNDRefHolder::alloc(val),
  557. "HostTensorND", ndarray_shared_from_tensor_py_capsule_dtor);
  558. mgb_assert(capsule, "failed to create PyCapsule");
  559. auto err = PyArray_SetBaseObject(
  560. reinterpret_cast<PyArrayObject*>(ret), capsule);
  561. mgb_assert(!err);
  562. }
  563. return ret;
  564. }
  565. HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) {
  566. auto ret_full = np2tensor_try_borrow(obj, meth, dtype);
  567. if (meth.must_borrow_) {
  568. mgb_assert(ret_full.second,
  569. "can not borrow from numpy array as contig array with dtype "
  570. "%s; src=%s",
  571. dtype.name(), repr_pyobj(obj).c_str());
  572. }
  573. return ret_full.first;
  574. }
  575. PyObject* npy::dtype_mgb2np(mgb::DType dtype) {
  576. PYTHON_GIL;
  577. // According to
  578. // https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_TypeObjectFromType
  579. // the following is equivalent to PyArray_TypeObjectFromType for built-in
  580. // types.
  581. auto descr = dtype_mgb2np_descr(dtype);
  582. if (descr == nullptr) {
  583. return nullptr;
  584. }
  585. if (dtype.has_param()) {
  586. return reinterpret_cast<PyObject*>(descr.release());
  587. }
  588. PyObject* typeobj = reinterpret_cast<PyObject*>(descr->typeobj);
  589. Py_XINCREF(typeobj);
  590. return typeobj;
  591. }
  592. mgb::DType npy::dtype_np2mgb(PyObject *obj) {
  593. mgb_assert(obj && obj != Py_None,
  594. "can not convert null PyObject to numpy dtype");
  595. // see
  596. // http://stackoverflow.com/questions/8477122/numpy-c-api-convert-type-object-to-type-number
  597. PYTHON_GIL;
  598. PyArray_Descr* dtype;
  599. if(!PyArray_DescrConverter(obj, &dtype)) {
  600. throw ConversionError(ssprintf("can not convert to np.dtype from %s",
  601. repr_pyobj(obj).c_str()));
  602. }
  603. mgb::DType result = dtype_np2mgb_descr(dtype);
  604. Py_DECREF(dtype);
  605. return result;
  606. }
  607. PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) {
  608. PYTHON_GIL;
  609. PyArray_Descr* descr;
  610. if (!PyArray_DescrConverter(dtype, &descr)) {
  611. throw ConversionError(ssprintf("can not convert to np.dtype from %s",
  612. repr_pyobj(dtype).c_str()));
  613. }
  614. mgb_assert(!descr->metadata,
  615. "unexpected metadata in dtype: "
  616. "dtype_obj=%s metadata=%s",
  617. repr_pyobj(dtype).c_str(), repr_pyobj(descr->metadata).c_str());
  618. int type_num = to_mgb_supported_dtype_raw(descr->type_num);
  619. return PyArray_TypeObjectFromType(type_num);
  620. }
  621. TensorShape npy::vec2shape(const std::vector<size_t> &vec) {
  622. TensorShape shape;
  623. mgb_assert(vec.size() <= TensorShape::MAX_NDIM,
  624. "dim too large: %zd (max %zd)",
  625. vec.size(), TensorShape::MAX_NDIM);
  626. shape.ndim = vec.size();
  627. for (size_t i = 0; i < vec.size(); i ++) {
  628. if (!vec[i]) {
  629. shape.ndim = 0;
  630. break;
  631. }
  632. shape[i] = vec[i];
  633. }
  634. mgb_assert(shape.ndim, "shape should not be empty");
  635. return shape;
  636. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台