You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

python_helper.cpp 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911
  1. /**
  2. * \file python_module/src/cpp/python_helper.cpp
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * \brief helper utilities for python integration
  7. *
  8. * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  9. *
  10. */
  11. #include "./python_helper.h"
  12. #include "megbrain/graph/exc_extra_info.h"
  13. #include "megbrain/graph/event.h"
  14. #include "megbrain/graph/cg.h"
  15. #include "megbrain/utils/mempool.h"
  16. #include "./numpy_incl.h"
  /*
   * demangle typeid, see
   * http://stackoverflow.com/questions/281818/unmangling-the-result-of-stdtype-infoname
   */
  #ifdef __GNUG__
  #include <cstdlib>
  #include <memory>
  #include <cxxabi.h>
  namespace {
  //! Demangle a name produced by std::type_info::name(). Returns \p name
  //! unchanged when abi::__cxa_demangle does not report success.
  std::string demangle_typeid(const char* name) {
      int status = -4;  // some arbitrary value to eliminate the compiler warning
      // __cxa_demangle returns a malloc()ed buffer (or nullptr); free() it
      std::unique_ptr<char, void (*)(void*)> res{
              abi::__cxa_demangle(name, nullptr, nullptr, &status), std::free};
      return (status == 0) ? res.get() : name;
  }
  }  // anonymous namespace
  #else
  namespace {
  // does nothing if not g++. The function must be declared unqualified here:
  // a qualified name such as mgb::demangle_typeid cannot be defined inside an
  // anonymous namespace.
  std::string demangle_typeid(const char* name) {
      return name;
  }
  }  // anonymous namespace
  #endif
  44. using namespace mgb;
  45. using namespace cg;
  46. PyStackExtracter* PyStackExtracter::ins = nullptr;
  47. namespace {
  48. std::string repr_pyobj(PyObject *obj) {
  49. if (!obj)
  50. return "<null PyObject>";
  51. PYTHON_GIL;
  52. auto str = PyObject_Repr(obj);
  53. if (!str)
  54. return ssprintf("<PyObject at %p (repr failed)>", obj);
  55. std::string ret{PyUnicode_AsUTF8(str)};
  56. Py_DECREF(str);
  57. return ret;
  58. }
  59. template<typename T>
  60. std::string typeid_name(const T &t) {
  61. return demangle_typeid(typeid(t).name());
  62. }
  63. } // anonymous namespace
  64. /* ============== OprPyTracker ============== */
  65. class OprPyTracker::TrackerStorage final : public UserDataContainer::UserData,
  66. public NonCopyableObj {
  67. MGB_TYPEINFO_OBJ_DECL;
  68. PyObject* m_cur_tracker = nullptr;
  69. size_t m_refcnt_to_add = 0;
  70. SyncEventConnecter::ReceiverHandler m_opr_insert_handler;
  71. ThinHashMap<OperatorNodeBase*, PyObject*> m_opr2tracker;
  72. public:
  73. explicit TrackerStorage(ComputingGraph& graph) {
  74. auto on_new_opr = [this](const event::OprInserted& ev) {
  75. if (!ev.is_dedup && !ev.exc) {
  76. if (m_cur_tracker) {
  77. ++m_refcnt_to_add;
  78. m_opr2tracker[ev.opr] = m_cur_tracker;
  79. }
  80. }
  81. };
  82. m_opr_insert_handler =
  83. graph.event().register_receiver<event::OprInserted>(on_new_opr);
  84. }
  85. ~TrackerStorage() {
  86. if (m_cur_tracker) {
  87. // manage refcnt of cur tracker
  88. disable();
  89. }
  90. PYTHON_GIL;
  91. for (auto&& i : m_opr2tracker) {
  92. Py_DecRef(i.second);
  93. }
  94. }
  95. //! get the instance
  96. static TrackerStorage& inst(ComputingGraph& graph) {
  97. auto make = [&graph]() {
  98. return std::make_shared<TrackerStorage>(graph);
  99. };
  100. return *graph.options()
  101. .user_data.get_user_data_or_create<TrackerStorage>(
  102. make);
  103. }
  104. //! get the tracker associated with an opr, or nullptr
  105. PyObject* get(OperatorNodeBase* opr) const {
  106. auto iter = m_opr2tracker.find(opr);
  107. return iter == m_opr2tracker.end() ? nullptr : iter->second;
  108. }
  109. void enable(PyObject* obj) {
  110. mgb_assert(!m_cur_tracker,
  111. "multiple calls to begin_set_tracker() on the same graph");
  112. m_cur_tracker = obj;
  113. }
  114. void disable() {
  115. mgb_assert(m_cur_tracker,
  116. "call end_set_tracker() before begin_set_tracker()");
  117. if (m_refcnt_to_add) {
  118. PYTHON_GIL;
  119. for (size_t i = 0; i < m_refcnt_to_add; ++i) {
  120. Py_IncRef(m_cur_tracker);
  121. }
  122. }
  123. m_cur_tracker = nullptr;
  124. }
  125. };
  126. MGB_TYPEINFO_OBJ_IMPL(OprPyTracker::TrackerStorage);
  127. void OprPyTracker::begin_set_tracker(ComputingGraph& graph, PyObject* obj) {
  128. TrackerStorage::inst(graph).enable(obj);
  129. }
  130. void OprPyTracker::end_set_tracker(ComputingGraph& graph) {
  131. TrackerStorage::inst(graph).disable();
  132. }
  133. OprPyTracker::TrackerResult OprPyTracker::get_tracker(mgb::MegBrainError& exc) {
  134. auto ptr = dynamic_cast<const OperatorNodeExcExtraInfo*>(exc.extra_info());
  135. if (!ptr)
  136. return {};
  137. return get_tracker(ptr->opr());
  138. }
  139. OprPyTracker::TrackerResult OprPyTracker::get_tracker(
  140. mgb::cg::OperatorNodeBase* opr) {
  141. TrackerResult ret;
  142. mgb_assert(opr);
  143. ret.exc_opr = opr;
  144. opr = cg::get_opr_root_source_opr(opr);
  145. ret.unopt_opr = opr;
  146. auto&& storage = TrackerStorage::inst(*opr->owner_graph());
  147. ret.tracker = storage.get(opr);
  148. {
  149. auto&& grad_info = opr->node_prop().attribute().grad_tracker;
  150. if (grad_info.valid()) {
  151. ret.opr_grad_src = cg::get_opr_root_source_opr(grad_info->orig_opr);
  152. ret.tracker_grad_src = storage.get(ret.opr_grad_src);
  153. }
  154. }
  155. return ret;
  156. }
  157. PyObject* OprPyTracker::TrackerResult::as_tuple(const char *leading_msg) const {
  158. std::string msg;
  159. if (leading_msg)
  160. msg = leading_msg;
  161. auto print_opr = [&](const char *otype, cg::OperatorNodeBase *opr) {
  162. if (!opr)
  163. return;
  164. msg += ssprintf("\n%s: id=%zu name=%s type=%s\n",
  165. otype, opr->id(), opr->cname(),
  166. typeid_name(*opr).c_str());
  167. msg += " input variables: \n";
  168. size_t idx = 0;
  169. for (auto i: opr->input()) {
  170. msg += ssprintf(" %zu: ", idx ++);
  171. msg += cg::dump_var_info({i});
  172. msg += "\n";
  173. }
  174. msg += " output variables: \n";
  175. idx = 0;
  176. for (auto i: opr->output()) {
  177. msg += ssprintf(" %zu: ", idx ++);
  178. msg += cg::dump_var_info({i});
  179. msg += "\n";
  180. }
  181. };
  182. print_opr("Associated operator", exc_opr);
  183. if (unopt_opr != exc_opr) {
  184. print_opr("Unoptimized equivalent of associated operator", unopt_opr);
  185. }
  186. print_opr("Associated operator created by taking grad of", opr_grad_src);
  187. PYTHON_GIL;
  188. PyObject *py_msg = PyUnicode_FromString(msg.c_str()),
  189. *py_tuple = PyTuple_Pack(3, py_msg,
  190. tracker ? tracker : Py_None,
  191. tracker_grad_src ? tracker_grad_src : Py_None);
  192. Py_DECREF(py_msg);
  193. return py_tuple;
  194. }
  195. std::string blame(mgb::cg::OperatorNodeBase* opr) {
  196. mgb_assert(PyMGBExceptionMaker::py_exc_class,
  197. "Python exception class is not set yet");
  198. PyObject* args = OprPyTracker::get_tracker(opr).as_tuple();
  199. PYTHON_GIL;
  200. PyObject* py_exc = PyObject_CallObject(PyMGBExceptionMaker::py_exc_class, args);
  201. Py_DECREF(args);
  202. mgb_assert(py_exc);
  203. PyObject* py_str = PyObject_Str(py_exc);
  204. Py_DECREF(py_exc);
  205. mgb_assert(py_str);
  206. int err = PyUnicode_READY(py_str);
  207. if (err) {
  208. Py_DECREF(py_str);
  209. mgb_assert(!err);
  210. }
  211. Py_ssize_t c_str_size;
  212. const char* c_str = PyUnicode_AsUTF8AndSize(py_str, &c_str_size);
  213. if (!c_str) {
  214. Py_DECREF(py_str);
  215. mgb_assert(c_str);
  216. }
  217. std::string ret(c_str, c_str_size);
  218. Py_DECREF(py_str);
  219. return ret;
  220. }
  221. /* ============== PyMGBExceptionMaker ============== */
  222. PyObject *PyMGBExceptionMaker::py_exc_class = nullptr;
  223. void PyMGBExceptionMaker::setup_py_exception(std::exception &exc) {
  224. mgb_assert(py_exc_class);
  225. if (auto cbexc = dynamic_cast<PyExceptionForward*>(&exc)) {
  226. cbexc->restore();
  227. return;
  228. }
  229. std::string msg;
  230. try {
  231. msg = ssprintf("MegBrain core throws exception: %s\n%s",
  232. typeid_name(exc).c_str(), exc.what());
  233. auto mgbexc = dynamic_cast<MegBrainError*>(&exc);
  234. OprPyTracker::TrackerResult tracker;
  235. if (mgbexc) {
  236. tracker = OprPyTracker::get_tracker(*mgbexc);
  237. }
  238. PYTHON_GIL;
  239. PyObject *py_exc_arg = tracker.as_tuple(msg.c_str());
  240. PyErr_SetObject(py_exc_class, py_exc_arg);
  241. Py_DECREF(py_exc_arg);
  242. } catch (std::exception &newexc) {
  243. auto newmsg = ssprintf(
  244. "caught exception during handling exception: %s\n%s\n"
  245. "original message: %s",
  246. typeid_name(newexc).c_str(), newexc.what(),
  247. msg.c_str());
  248. PyErr_SetString(PyExc_RuntimeError, newmsg.c_str());
  249. } catch (...) {
  250. auto newmsg = ssprintf(
  251. "caught unknown exception during handling exception\n"
  252. "original message: %s", msg.c_str());
  253. PyErr_SetString(PyExc_RuntimeError, newmsg.c_str());
  254. }
  255. }
  256. /* ============== PyExceptionForward ============== */
  257. PyExceptionForward::~PyExceptionForward() {
  258. PYTHON_GIL;
  259. PyObjRefKeeper::deleter(m_type);
  260. PyObjRefKeeper::deleter(m_value);
  261. PyObjRefKeeper::deleter(m_traceback);
  262. }
  263. void PyExceptionForward::restore() {
  264. PyErr_Restore(m_type, m_value, m_traceback);
  265. m_type = m_value = m_traceback = nullptr;
  266. }
  267. void PyExceptionForward::throw_() {
  268. PyObject *etype, *obj, *trace;
  269. PyErr_Fetch(&etype, &obj, &trace);
  270. PyErr_NormalizeException(&etype, &obj, &trace);
  271. std::string msg{"python exception"};
  272. bool succ = false;
  273. if (etype && obj && trace) {
  274. auto run = [&]() {
  275. #define DEF(name, expr) \
  276. PyObjRefKeeper name{expr}; \
  277. if (!name.get()) \
  278. return
  279. DEF(mod, PyImport_ImportModule("traceback"));
  280. DEF(result, PyObject_CallMethod(mod.get(), "format_exception",
  281. "(OOO)", etype, obj, trace));
  282. if (!PyList_Check(result.get()))
  283. return;
  284. auto size = PyList_Size(result.get());
  285. msg.append(":\n");
  286. for (Py_ssize_t i = 0; i < size; ++i) {
  287. msg.append(" ");
  288. msg.append(PyUnicode_AsUTF8(PyList_GetItem(result.get(), i)));
  289. }
  290. msg.pop_back(); // remove last \n
  291. succ = true;
  292. #undef DEF
  293. };
  294. run();
  295. }
  296. if (!succ) {
  297. PyObject* obj_str_py;
  298. if (obj && (obj_str_py = PyObject_Repr(obj))) {
  299. msg.append(" with message ");
  300. msg.append(PyUnicode_AsUTF8(obj_str_py));
  301. Py_DECREF(obj_str_py);
  302. } else {
  303. msg.append(" with unknown message");
  304. }
  305. }
  306. // throwing exception may cause abort due to unknown reasons; so we first
  307. // log the message
  308. mgb_log_error("caught exception from python callback: %s", msg.c_str());
  309. fflush(stdout);
  310. fflush(stderr);
  311. throw PyExceptionForward{etype, obj, trace, msg};
  312. }
  313. /* ============== namespace npy ============== */
  314. namespace {
  315. int to_mgb_supported_dtype_raw(int dtype) {
  316. if (dtype == NPY_INT64)
  317. return NPY_INT32;
  318. if (dtype == NPY_FLOAT64)
  319. return NPY_FLOAT32;
  320. return dtype;
  321. }
  // FOREACH_NPY_DTYPE_PAIR(cb) invokes cb(<mgb DTypeEnum / dtype name>,
  // <numpy type enum>) for each dtype with a direct numpy counterpart.
  #define FOREACH_NPY_DTYPE_PAIR(cb) \
      cb(Uint8, NPY_UINT8)           \
      cb(Int8, NPY_INT8)             \
      cb(Int16, NPY_INT16)           \
      cb(Int32, NPY_INT32)           \
      cb(Float16, NPY_FLOAT16)       \
      cb(Float32, NPY_FLOAT32)
  // Same as above plus the pairs from FOREACH_MGB_DTYPE_PAIR, which is not
  // defined in this file (presumably provided by an included header).
  #define FOREACH_NPY_MGB_DTYPE_PAIR(cb) \
      FOREACH_NPY_DTYPE_PAIR(cb)         \
      FOREACH_MGB_DTYPE_PAIR(cb)
  332. //! convert megbrain dtype to numpy dtype
  333. int dtype_mgb2np_raw(DType dtype) {
  334. mgb_assert(dtype.valid(), "attempt to convert from invalid dtype");
  335. switch (dtype.enumv()) {
  336. #define cb(_m, _n) \
  337. case DTypeEnum::_m: \
  338. return _n;
  339. FOREACH_NPY_MGB_DTYPE_PAIR(cb)
  340. #undef cb
  341. default:
  342. break;
  343. }
  344. throw ConversionError(ssprintf(
  345. "can not convert dtype %s to numpy dtype", dtype.name()));
  346. }
  347. struct PyArrayDescrDeleter {
  348. void operator()(PyArray_Descr* obj) {
  349. Py_XDECREF(obj);
  350. }
  351. };
  352. //! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new
  353. //! reference to the descriptor.
  354. std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(
  355. DType dtype) {
  356. PYTHON_GIL;
  357. mgb_assert(dtype.valid(), "attempt to convert from invalid dtype");
  358. auto build_mgb_dtype_dict =
  359. [](const char* name,
  360. const std::vector<std::pair<const char*, PyObject*>>& data) {
  361. PyObject* metadata = PyDict_New();
  362. PyObject* mgb_dtype_metadata = PyDict_New();
  363. PyDict_SetItemString(mgb_dtype_metadata, "name",
  364. PyUnicode_FromString(name));
  365. for (const auto& d : data) {
  366. PyDict_SetItemString(mgb_dtype_metadata, d.first, d.second);
  367. }
  368. PyDict_SetItemString(metadata, "mgb_dtype", mgb_dtype_metadata);
  369. return metadata;
  370. };
  371. if (dtype.has_param()) {
  372. PyArray_Descr* type_descr;
  373. switch (dtype.enumv()) {
  374. case DTypeEnum::Quantized8Asymm: {
  375. auto& param = dtype.param<dtype::Quantized8Asymm>();
  376. type_descr = PyArray_DescrNewFromType(NPY_UINT8);
  377. type_descr->metadata = build_mgb_dtype_dict(
  378. DTypeTrait<dtype::Quantized8Asymm>::name,
  379. {{"scale", PyFloat_FromDouble(param.scale)},
  380. {"zero_point", PyLong_FromLong(param.zero_point)}});
  381. break;
  382. }
  383. case DTypeEnum::QuantizedS8: {
  384. auto& param = dtype.param<dtype::QuantizedS8>();
  385. type_descr = PyArray_DescrNewFromType(NPY_INT8);
  386. type_descr->metadata = build_mgb_dtype_dict(
  387. DTypeTrait<dtype::QuantizedS8>::name,
  388. {{"scale", PyFloat_FromDouble(param.scale)}});
  389. break;
  390. }
  391. case DTypeEnum::Quantized4Asymm: {
  392. auto& param = dtype.param<dtype::Quantized4Asymm>();
  393. type_descr = PyArray_DescrNewFromType(NPY_UINT8);
  394. type_descr->metadata = build_mgb_dtype_dict(
  395. DTypeTrait<dtype::Quantized4Asymm>::name,
  396. {{"scale", PyFloat_FromDouble(param.scale)},
  397. {"zero_point", PyLong_FromLong(param.zero_point)}});
  398. break;
  399. }
  400. case DTypeEnum::QuantizedS4: {
  401. auto& param = dtype.param<dtype::QuantizedS4>();
  402. type_descr = PyArray_DescrNewFromType(NPY_INT8);
  403. type_descr->metadata = build_mgb_dtype_dict(
  404. DTypeTrait<dtype::QuantizedS4>::name,
  405. {{"scale", PyFloat_FromDouble(param.scale)}});
  406. break;
  407. }
  408. case DTypeEnum::QuantizedS32: {
  409. auto& param = dtype.param<dtype::QuantizedS32>();
  410. type_descr = PyArray_DescrNewFromType(NPY_INT32);
  411. type_descr->metadata = build_mgb_dtype_dict(
  412. DTypeTrait<dtype::QuantizedS32>::name,
  413. {{"scale", PyFloat_FromDouble(param.scale)}});
  414. break;
  415. }
  416. default:
  417. mgb_throw(ConversionError, "unhandled parameterized DType %s",
  418. dtype.name());
  419. }
  420. return std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter>(type_descr);
  421. }
  422. PyArray_Descr* basic_descr = PyArray_DescrFromType(dtype_mgb2np_raw(dtype));
  423. mgb_assert(basic_descr != nullptr,
  424. "failed to convert expected dtype to numpy type descriptor");
  425. return std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter>(basic_descr);
  426. }
  427. DType dtype_np2mgb_raw(int npt) {
  428. switch (npt) {
  429. #define cb(_m, _n) \
  430. case _n: \
  431. return dtype::_m();
  432. FOREACH_NPY_DTYPE_PAIR(cb)
  433. #undef cb
  434. }
  435. #define cb(_m, _n) \
  436. if (_n == npt) return dtype::_m();
  437. FOREACH_MGB_DTYPE_PAIR(cb)
  438. #undef cb
  439. PYTHON_GIL;
  440. std::string msg;
  441. auto py_obj = PyArray_TypeObjectFromType(npt);
  442. if (!py_obj) {
  443. msg = ssprintf("unknown numpy dtype enum %d", npt);
  444. } else {
  445. msg = ssprintf("unsupported numpy dtype %s",
  446. repr_pyobj(py_obj).c_str());
  447. }
  448. Py_DECREF(py_obj);
  449. throw ConversionError(msg);
  450. }
  451. DType dtype_np2mgb_descr(PyArray_Descr* descr) {
  452. PYTHON_GIL;
  453. auto handle_parameterized_dtype = [](PyObject* metadata) -> DType {
  454. mgb_assert(PyDict_Check(metadata),
  455. "Invalid parameterized DType metadata: should be a dict");
  456. PyObject* dtype_name_py = PyDict_GetItemString(metadata, "name");
  457. mgb_assert(
  458. PyUnicode_Check(dtype_name_py),
  459. "Invalid parameterized DType metadata: name should be a str");
  460. std::string dtype_name(PyUnicode_AsUTF8(dtype_name_py));
  461. if (dtype_name == "Quantized8Asymm") {
  462. PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
  463. PyObject* zero_point_py =
  464. PyDict_GetItemString(metadata, "zero_point");
  465. mgb_assert(scale_py && zero_point_py,
  466. "Invalid Quantized8Asymm metadata: missing scale or "
  467. "zero_point.");
  468. mgb_assert(
  469. PyFloat_Check(scale_py),
  470. "Invalid Quantized8Asymm metadata: scale should be float");
  471. mgb_assert(PyLong_Check(zero_point_py),
  472. "Invalid Quantized8Asymm metadata: zero_point should be "
  473. "integer");
  474. auto zero_point = PyLong_AS_LONG(zero_point_py);
  475. mgb_assert(zero_point >= 0 && zero_point < 256,
  476. "Invalid Quantized8Asymm metadata: zero_point should be "
  477. "in [0, 256)");
  478. return dtype::Quantized8Asymm(
  479. static_cast<float>(PyFloat_AS_DOUBLE(scale_py)),
  480. static_cast<uint8_t>(zero_point));
  481. }
  482. if (dtype_name == "Quantized4Asymm") {
  483. PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
  484. PyObject* zero_point_py =
  485. PyDict_GetItemString(metadata, "zero_point");
  486. mgb_assert(scale_py && zero_point_py,
  487. "Invalid Quantized4Asymm metadata: missing scale or "
  488. "zero_point.");
  489. mgb_assert(
  490. PyFloat_Check(scale_py),
  491. "Invalid Quantized4Asymm metadata: scale should be float");
  492. mgb_assert(PyLong_Check(zero_point_py),
  493. "Invalid Quantized4Asymm metadata: zero_point should be "
  494. "integer");
  495. auto zero_point = PyLong_AS_LONG(zero_point_py);
  496. mgb_assert(zero_point >= 0 && zero_point < 15,
  497. "Invalid Quantized4Asymm metadata: zero_point should be "
  498. "in [0, 15)");
  499. return dtype::Quantized4Asymm(
  500. static_cast<float>(PyFloat_AS_DOUBLE(scale_py)),
  501. static_cast<uint8_t>(zero_point));
  502. }
  503. if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" ||
  504. dtype_name == "QuantizedS4") {
  505. PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
  506. mgb_assert(scale_py, "Invalid metadata: missing scale");
  507. mgb_assert(PyFloat_Check(scale_py),
  508. "Invalid metadata: scale should be float");
  509. float scale = static_cast<float>(PyFloat_AS_DOUBLE(scale_py));
  510. if (dtype_name == "QuantizedS32") {
  511. return dtype::QuantizedS32(scale);
  512. } else if (dtype_name == "QuantizedS8"){
  513. return dtype::QuantizedS8(scale);
  514. } else {
  515. return dtype::QuantizedS4(scale);
  516. }
  517. }
  518. throw ConversionError(
  519. ssprintf("Unknown parameterized DType: %s", dtype_name.c_str())
  520. .c_str());
  521. };
  522. PyObject* dtype_metadata;
  523. if (descr->metadata && PyDict_Check(descr->metadata) &&
  524. (dtype_metadata = PyDict_GetItemString(descr->metadata, "mgb_dtype"))) {
  525. return handle_parameterized_dtype(dtype_metadata);
  526. }
  527. return dtype_np2mgb_raw(descr->type_num);
  528. }
  529. HostTensorND lowbit_ndarray_to_host_tensor(
  530. CompNode comp_node, TensorLayout &layout, PyArrayObject *input) {
  531. auto src_ptr = reinterpret_cast<dt_byte*>(PyArray_DATA(input));
  532. if (!layout.ndim) {
  533. // numpy scalar
  534. mgb_assert(src_ptr, "can not convert from null numpy array");
  535. layout.init_contiguous_stride({1});
  536. } else {
  537. mgb_assert(layout.ndim && layout.ndim <= TensorShape::MAX_NDIM,
  538. "unsupported ndim %zu", layout.ndim);
  539. for (size_t i = 0; i < layout.ndim; ++ i) {
  540. layout.shape[i] = PyArray_SHAPE(input)[i];
  541. layout.stride[i] = PyArray_STRIDE(input, i);
  542. mgb_assert(layout.shape[i], "zero shape not supported");
  543. }
  544. mgb_assert(layout.is_contiguous());
  545. }
  546. HostTensorND ret{comp_node, layout};
  547. lowbit_memcpy_byte2compact(layout.dtype, ret.raw_ptr(), src_ptr,
  548. layout.total_nr_elems());
  549. return ret;
  550. }
  551. /*!
  552. * \brief convert a python object to tensor and try to borrow memory if the
  553. * original object is a contiguous numpy array
  554. * \param dtype see np2tensor
  555. * \return the megbrain tensor, and whether memory is borrowed
  556. */
  557. std::pair<HostTensorND, bool> np2tensor_try_borrow(
  558. PyObject *obj, CompNode dest_cn, DType dtype) {
  559. mgb_assert(dest_cn.valid());
  560. PYTHON_GIL;
  561. PyArray_Descr* expected_descr = nullptr;
  562. if (dtype.valid()) {
  563. // The reference to expected_descr will be stealed later.
  564. expected_descr = dtype_mgb2np_descr(dtype).release();
  565. }
  566. // make result from PyArrayObject; its reference would be stolen
  567. auto make_from_arr = [&](PyArrayObject *input, bool is_borrow) {
  568. PyObjRefKeeper ref_obj_cvt{reinterpret_cast<PyObject*>(input)};
  569. TensorLayout layout;
  570. layout.dtype = dtype_np2mgb_descr(PyArray_DESCR(input));
  571. if (dtype.valid())
  572. mgb_assert(dtype == layout.dtype);
  573. layout.ndim = PyArray_NDIM(input);
  574. if (layout.dtype.is_low_bit()) {
  575. auto ret = lowbit_ndarray_to_host_tensor(dest_cn, layout, input);
  576. // decref(input) would be handled by ref_obj_cvt
  577. return std::make_pair(ret, false);
  578. }
  579. auto data = reinterpret_cast<dt_byte*>(PyArray_DATA(input));
  580. if (!layout.ndim) {
  581. // numpy scalar
  582. mgb_assert(data, "can not convert from null numpy array");
  583. layout.init_contiguous_stride({1});
  584. } else {
  585. mgb_assert(layout.ndim && layout.ndim <= TensorShape::MAX_NDIM,
  586. "unsupported ndim %zu", layout.ndim);
  587. auto dsize = layout.dtype.size();
  588. bool is_empty = false;
  589. for (size_t i = 0; i < layout.ndim; ++ i) {
  590. layout.shape[i] = PyArray_SHAPE(input)[i];
  591. layout.stride[i] = PyArray_STRIDE(input, i);
  592. if (!layout.shape[i]) {
  593. is_empty = true;
  594. }
  595. mgb_assert(layout.stride[i] % dsize == 0,
  596. "bad stride %zd", layout.stride[i]);
  597. layout.stride[i] /= dsize;
  598. }
  599. mgb_assert(is_empty || layout.is_contiguous());
  600. }
  601. HostTensorStorage storage;
  602. auto input_ptr = ref_obj_cvt.make_shared(data);
  603. storage.reset(dest_cn, layout.span().high_byte, input_ptr);
  604. HostTensorND ret;
  605. ret.reset(storage, layout);
  606. return std::make_pair(ret, is_borrow);
  607. };
  608. PyArrayObject *obj_as_arr = nullptr;
  609. do {
  610. // check contiguous and dtype, and borrow mem if ok
  611. if (!PyArray_Check(obj))
  612. break;
  613. obj_as_arr = reinterpret_cast<PyArrayObject*>(obj);
  614. int typenum = PyArray_DTYPE(obj_as_arr)->type_num;
  615. // We have to check dtype.valid() and typenum first to avoid
  616. // accidentally trigger ConversionError on incompatible dtypes which can
  617. // be automatically converted into comptaible ones (e.g. float64).
  618. if (dtype.valid() &&
  619. (expected_descr->type_num != typenum ||
  620. dtype_np2mgb_descr(PyArray_DTYPE(obj_as_arr)) != dtype))
  621. break;
  622. if (typenum != to_mgb_supported_dtype_raw(typenum)) {
  623. mgb_assert(!dtype.valid() && expected_descr == nullptr);
  624. expected_descr =
  625. PyArray_DescrFromType(to_mgb_supported_dtype_raw(typenum));
  626. break;
  627. }
  628. if (PyArray_ISCARRAY_RO(obj_as_arr)) {
  629. Py_INCREF(obj_as_arr);
  630. return make_from_arr(obj_as_arr, true);
  631. }
  632. } while(0);
  633. constexpr auto NP_FLAGS = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_FORCECAST;
  634. PyObject *obj_cvt;
  635. if (obj_as_arr) {
  636. obj_cvt = PyArray_FromArray(obj_as_arr, expected_descr, NP_FLAGS);
  637. } else {
  638. obj_cvt = PyArray_FromAny(obj, expected_descr, 0, 0, NP_FLAGS, nullptr);
  639. }
  640. if (obj_cvt) {
  641. // convert to mgb supported dtype
  642. auto arr = reinterpret_cast<PyArrayObject*>(obj_cvt);
  643. int dt0 = PyArray_TYPE(arr), dt1 = to_mgb_supported_dtype_raw(dt0);
  644. if (dt0 != dt1) {
  645. mgb_assert(expected_descr == nullptr);
  646. expected_descr = PyArray_DescrFromType(dt1);
  647. mgb_assert(expected_descr);
  648. auto obj_cvt_new = PyArray_FromAny(
  649. obj_cvt, expected_descr, 0, 0, NP_FLAGS, nullptr);
  650. Py_DECREF(obj_cvt);
  651. obj_cvt = obj_cvt_new;
  652. }
  653. }
  654. if (!obj_cvt) {
  655. if (PyErr_Occurred()) {
  656. PyExceptionForward::throw_();
  657. }
  658. throw ConversionError(ssprintf("can not convert to numpy array from %s",
  659. repr_pyobj(obj).c_str()));
  660. }
  661. return make_from_arr(reinterpret_cast<PyArrayObject*>(obj_cvt), false);
  662. }
  663. //! hold a reference to HostTensorND
  664. class HostTensorNDRefHolder final: public NonCopyableObj {
  665. HostTensorND m_val;
  666. static MemPool<HostTensorNDRefHolder> sm_mem_pool;
  667. friend class MemPool<HostTensorNDRefHolder>;
  668. HostTensorNDRefHolder(const HostTensorND &v):
  669. m_val{v}
  670. {
  671. }
  672. public:
  673. static HostTensorNDRefHolder* alloc(const HostTensorND &v) {
  674. return sm_mem_pool.alloc(v);
  675. }
  676. static void free(HostTensorNDRefHolder *p) {
  677. return sm_mem_pool.free(p);
  678. }
  679. };
  680. MemPool<HostTensorNDRefHolder> HostTensorNDRefHolder::sm_mem_pool;
  681. void ndarray_shared_from_tensor_py_capsule_dtor(PyObject *cap) {
  682. auto ptr = PyCapsule_GetPointer(cap, "HostTensorND");
  683. mgb_assert(ptr, "not a PyCapsule: %s", repr_pyobj(cap).c_str());
  684. HostTensorNDRefHolder::free(static_cast<HostTensorNDRefHolder*>(ptr));
  685. }
  686. } // anonymous namespace
  687. PyObject* npy::ndarray_from_tensor(
  688. const HostTensorND &val, ShareType share_type) {
  689. if (!val.layout().is_contiguous() && !val.shape().is_empty()) {
  690. mgb_assert(share_type != ShareType::MUST_SHARE);
  691. HostTensorND contig;
  692. contig.copy_from(val);
  693. return ndarray_from_tensor(contig, ShareType::TRY_SHARE);
  694. }
  695. PYTHON_GIL;
  696. npy_intp dims[TensorLayout::MAX_NDIM];
  697. for (size_t i = 0; i < val.layout().ndim; ++ i)
  698. dims[i] = val.shape()[i];
  699. PyObject* ret = nullptr;
  700. auto alloc_new_ret = [&]() {
  701. mgb_assert(!ret);
  702. ret = PyArray_NewFromDescr(
  703. &PyArray_Type, dtype_mgb2np_descr(val.dtype()).release(),
  704. val.layout().ndim, dims, nullptr, nullptr, 0, nullptr);
  705. mgb_assert(ret, "failed to allocate array");
  706. mgb_assert(PyArray_Check(ret));
  707. return PyArray_DATA(reinterpret_cast<PyArrayObject*>(ret));
  708. };
  709. if (val.dtype().is_low_bit()) {
  710. mgb_assert(share_type != ShareType::MUST_SHARE,
  711. "can not share memory for lowbit dtype");
  712. lowbit_memcpy_compact2byte(val.dtype(), alloc_new_ret(), val.raw_ptr(),
  713. val.layout().total_nr_elems());
  714. } else if (share_type == ShareType::MUST_UNSHARE) {
  715. memcpy(alloc_new_ret(), val.raw_ptr(), val.layout().span().dist_byte());
  716. } else {
  717. // share data
  718. ret = PyArray_NewFromDescr(
  719. &PyArray_Type, dtype_mgb2np_descr(val.dtype()).release(),
  720. val.layout().ndim, dims, nullptr,
  721. const_cast<dt_byte*>(val.raw_ptr()), 0, nullptr);
  722. mgb_assert(ret, "failed to alloc ndarray");
  723. auto capsule = PyCapsule_New(HostTensorNDRefHolder::alloc(val),
  724. "HostTensorND", ndarray_shared_from_tensor_py_capsule_dtor);
  725. mgb_assert(capsule, "failed to create PyCapsule");
  726. auto err = PyArray_SetBaseObject(
  727. reinterpret_cast<PyArrayObject*>(ret), capsule);
  728. mgb_assert(!err);
  729. }
  730. return ret;
  731. }
  732. HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) {
  733. auto ret_full = np2tensor_try_borrow(obj, meth.dest_cn_, dtype);
  734. if (meth.dest_tensor_) {
  735. meth.dest_tensor_->copy_from(ret_full.first);
  736. return *meth.dest_tensor_;
  737. }
  738. if (meth.must_borrow_) {
  739. mgb_assert(ret_full.second,
  740. "can not borrow from numpy array as contig array with dtype "
  741. "%s; src=%s",
  742. dtype.name(), repr_pyobj(obj).c_str());
  743. }
  744. return ret_full.first;
  745. }
  746. PyObject* npy::dtype_mgb2np(mgb::DType dtype) {
  747. PYTHON_GIL;
  748. // According to
  749. // https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_TypeObjectFromType
  750. // the following is equivalent to PyArray_TypeObjectFromType for built-in
  751. // types.
  752. auto descr = dtype_mgb2np_descr(dtype);
  753. if (descr == nullptr) {
  754. return nullptr;
  755. }
  756. if (dtype.has_param()) {
  757. return reinterpret_cast<PyObject*>(descr.release());
  758. }
  759. PyObject* typeobj = reinterpret_cast<PyObject*>(descr->typeobj);
  760. Py_XINCREF(typeobj);
  761. return typeobj;
  762. }
  763. mgb::DType npy::dtype_np2mgb(PyObject *obj) {
  764. mgb_assert(obj && obj != Py_None,
  765. "can not convert null PyObject to numpy dtype");
  766. // see
  767. // http://stackoverflow.com/questions/8477122/numpy-c-api-convert-type-object-to-type-number
  768. PYTHON_GIL;
  769. PyArray_Descr* dtype;
  770. if(!PyArray_DescrConverter(obj, &dtype)) {
  771. throw ConversionError(ssprintf("can not convert to np.dtype from %s",
  772. repr_pyobj(obj).c_str()));
  773. }
  774. mgb::DType result = dtype_np2mgb_descr(dtype);
  775. Py_DECREF(dtype);
  776. return result;
  777. }
  778. PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) {
  779. PYTHON_GIL;
  780. PyArray_Descr* descr;
  781. if (!PyArray_DescrConverter(dtype, &descr)) {
  782. throw ConversionError(ssprintf("can not convert to np.dtype from %s",
  783. repr_pyobj(dtype).c_str()));
  784. }
  785. mgb_assert(!descr->metadata,
  786. "unexpected metadata in dtype: "
  787. "dtype_obj=%s metadata=%s",
  788. repr_pyobj(dtype).c_str(), repr_pyobj(descr->metadata).c_str());
  789. int type_num = to_mgb_supported_dtype_raw(descr->type_num);
  790. return PyArray_TypeObjectFromType(type_num);
  791. }
  792. TensorShape npy::vec2shape(const std::vector<size_t> &vec) {
  793. TensorShape shape;
  794. mgb_assert(vec.size() <= TensorShape::MAX_NDIM,
  795. "dim too large: %zd (max %zd)",
  796. vec.size(), TensorShape::MAX_NDIM);
  797. shape.ndim = vec.size();
  798. for (size_t i = 0; i < vec.size(); i ++) {
  799. if (!vec[i]) {
  800. shape.ndim = 0;
  801. break;
  802. }
  803. shape[i] = vec[i];
  804. }
  805. mgb_assert(shape.ndim, "shape should not be empty");
  806. return shape;
  807. }
  808. void mgb_init_numpy() {
  809. import_array1( );
  810. }
  811. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台