You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.cpp 63 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659
  1. /**
  2. * \file imperative/python/src/tensor.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/common.h"
  12. #include "megbrain/dtype.h"
  13. #include "megbrain/imperative/ops/autogen.h"
  14. #include "megbrain/imperative/ops/backward_graph.h"
  15. #include "megbrain/imperative/ops/utility.h"
  16. #include "megbrain/imperative/profiler.h"
  17. #include "megbrain/imperative/transformations/eval.h"
  18. #include "megbrain/imperative/transformations/lazy.h"
  19. #include "megbrain/imperative/transformations/scalar.h"
  20. #include "megbrain/imperative/transformations/symbol.h"
  21. #include "megbrain/imperative/transformations/trace.h"
  22. #include "megbrain/imperative/utils/map.h"
  23. #include "megbrain/opr/io.h"
  24. #include "megbrain/plugin/profiler.h"
  25. #include "./common.h"
  26. #include "./grad.h"
  27. #include "./graph_rt.h"
  28. #include "./helper.h"
  29. #include "./module_trace.h"
  30. #include "./numpy_dtypes.h"
  31. #include "./tensor.h"
  32. #include "./transformation.h"
  33. #include <object.h>
  34. #include <pybind11/numpy.h>
  35. #include <pybind11/operators.h>
  36. #include <pybind11/pytypes.h>
  37. #include <pyerrors.h>
  38. #include <range/v3/all.hpp>
  39. #include <string>
  40. #include <unordered_map>
  41. #include "../../src/impl/mgb_cg_impl.h"
  42. namespace py = pybind11;
  43. namespace views = ranges::views;
  44. namespace mgb::imperative::python {
  // File-local storage for per-value module-trace info (a py::object keyed by
  // ValueWeakRef); read and written by TensorWrapper::module_trace_info() and
  // TensorWrapper::set_module_trace_info().
  45. namespace {
  46. WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
  47. }
  // Interpreter channel pointer; starts out null and is not assigned anywhere
  // in this part of the file.
  48. interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
  // Python type object passed to TensorWrapper::make when wrapping results.
  49. PyTypeObject* py_tensor_type = nullptr;
  // Python callables installed through the set_<name>() functions generated
  // below; both are called with PyObject_CallObject in _expand_bool_dim.
  50. PyObject *cpp_use_symbolic_shape, *cpp_astensor1d;
  // REGISTE_APPLY_FUNC(name) defines `void set_<name>(py::object)` which stores
  // the raw pointer of its argument into the global `name`.
  // NOTE(review): only pyf.ptr() is stored and no reference is taken here, so
  // the callable presumably has to be kept alive by the Python side - confirm.
  51. #define REGISTE_APPLY_FUNC(mode) \
  52. void set_##mode(py::object pyf) { mode = pyf.ptr(); }
  53. REGISTE_APPLY_FUNC(cpp_use_symbolic_shape)
  54. REGISTE_APPLY_FUNC(cpp_astensor1d)
  55. #undef REGISTE_APPLY_FUNC
  56. PyObject* py_apply(
  57. PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) {
  58. try {
  59. // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
  60. // PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
  61. // return nullptr;
  62. // }
  63. if (nargs < 2) {
  64. PyErr_SetString(
  65. PyExc_TypeError,
  66. "py_apply expects one Op and at least one tensor "
  67. "as argument");
  68. return nullptr;
  69. }
  70. auto* py_op = args[0];
  71. ++args;
  72. --nargs;
  73. auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
  74. SmallVector<ValueRef, 64> tensors(nargs);
  75. if (py::isinstance<PySymbolVar>(py::handle(args[0]))) {
  76. // swap to a special context to reuse scalar handle
  77. TransformationContext symbol_var_context;
  78. Transformation::swap_context(symbol_var_context);
  79. CleanupGuard _{[&] { Transformation::swap_context(symbol_var_context); }};
  80. auto* graph =
  81. py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph();
  82. std::make_shared<SymbolTransformation>(graph)->register_at(
  83. Transformation::top());
  84. std::make_shared<ScalarTransformation>()->register_at(
  85. Transformation::top());
  86. SmallVector<ValueRef> inputs(nargs);
  87. for (size_t i = 0; i < nargs; ++i) {
  88. auto* py_input = py::handle(args[i]).cast<PySymbolVar*>();
  89. ValueRef input = SymbolValue::make(py_input->m_node);
  90. if (py_input->is_scalar) {
  91. input = ScalarValue::make(input);
  92. }
  93. inputs[i] = input;
  94. }
  95. auto outputs = imperative::apply(*op, inputs);
  96. auto ret = pybind11::tuple(outputs.size());
  97. auto typeobj = py::handle(args[0]).get_type();
  98. for (size_t i = 0; i < outputs.size(); ++i) {
  99. bool is_scalar = false;
  100. if (auto* scalar_value = outputs[i].as<ScalarValue>()) {
  101. outputs[i] = scalar_value->value();
  102. is_scalar = true;
  103. }
  104. auto* node = outputs[i].cast<SymbolValue>().node();
  105. ret[i] = typeobj(
  106. pybind11::cast(node, pybind11::return_value_policy::automatic));
  107. py::handle(ret[i]).cast<PySymbolVar*>()->is_scalar = is_scalar;
  108. }
  109. return ret.release().ptr();
  110. }
  111. for (size_t i = 0; i < nargs; ++i) {
  112. if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
  113. tensors[i] = tw->m_tensor->data();
  114. } else {
  115. PyErr_SetString(
  116. PyExc_TypeError,
  117. ssprintf(
  118. "op %s expect type Tensor as inputs, got %s actually",
  119. op->make_name().c_str(), Py_TYPE(args[i])->tp_name)
  120. .c_str());
  121. return nullptr;
  122. }
  123. }
  124. auto outputs = imperative::apply(ApplyOp(*op), {tensors.data(), nargs});
  125. size_t nout = outputs.size();
  126. auto ret = py::tuple(nout);
  127. for (size_t i = 0; i < nout; ++i) {
  128. ret[i] = TensorWrapper::make(py_tensor_type, std::move(outputs[i]));
  129. }
  130. return ret.release().ptr();
  131. }
  132. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  133. }
  134. TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
  135. if (kwargs && PyDict_Size(kwargs)) {
  136. throw py::type_error("keyword argument not allowed");
  137. }
  138. auto nargs = PyTuple_Size(args);
  139. auto tup = py::reinterpret_borrow<py::tuple>(args);
  140. if (nargs == 0) {
  141. throw py::type_error("too few arguments");
  142. }
  143. if (auto* t = try_cast(tup[0].ptr())) {
  144. if (nargs > 1) {
  145. throw py::type_error("expect 1 argument");
  146. }
  147. m_tensor = t->m_tensor->copy();
  148. } else {
  149. if (nargs == 1) {
  150. auto arg0 = PyTuple_GetItem(args, 0);
  151. // for DeviceTensorND
  152. if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
  153. auto dv = py::handle(arg0).cast<DeviceTensorND>();
  154. m_tensor = std::make_shared<Tensor>(imperative::apply(
  155. CreateTensor(CreateTensor::Common, dv.comp_node(), dv.layout()),
  156. DeviceStorage::make(dv.storage()))[0]);
  157. } else {
  158. throw py::type_error(
  159. "single argument is not tensor, varnode or devicetensor");
  160. }
  161. } else {
  162. py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType
  163. if (nargs != 5 && nargs != 6) {
  164. throw py::type_error("expect 5 or 6 arguments");
  165. }
  166. auto data = tup[0].cast<py::array>();
  167. DType dtype = tup[1].cast<DType>();
  168. CompNode cn = tup[2].cast<CompNode>();
  169. bool is_const = tup[3].cast<bool>();
  170. bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false;
  171. std::string name;
  172. if (tup[nargs - 1].ptr() != Py_None)
  173. name = tup[nargs - 1].cast<std::string>();
  174. // const op
  175. {
  176. CreateTensor::Kind kind = is_const ? CreateTensor::Const
  177. : no_cache ? CreateTensor::Unique
  178. : CreateTensor::Common;
  179. HostTensorND ret(cn);
  180. ret = npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype);
  181. mgb_assert(
  182. ret.layout().is_empty() || ret.layout().is_contiguous(),
  183. "host value should be continuous");
  184. ValueShape shape;
  185. for (size_t i = 0; i < data.ndim(); ++i) {
  186. shape[shape.ndim++] = data.shape(i);
  187. }
  188. m_tensor = std::make_shared<Tensor>(imperative::apply(
  189. CreateTensor(kind, cn, ret.dtype(), shape),
  190. HostStorage::make(ret.storage()))[0]);
  191. }
  192. if (!name.empty()) {
  193. m_tensor->reset(
  194. imperative::apply(RenameValue(name), m_tensor->data())[0]);
  195. mgb_assert(
  196. ((std::string&)*m_tensor->data().name()) == name,
  197. "result name incorrect");
  198. }
  199. if (data.ndim() == 0) {
  200. mgb_assert(m_tensor->is_scalar(), "result should be scalar");
  201. }
  202. }
  203. }
  204. }
  205. PyObject* TensorWrapper::module_trace_info() {
  206. if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) {
  207. if (module_trace_info->ptr()) {
  208. return module_trace_info->inc_ref().ptr();
  209. }
  210. }
  211. PyErr_SetString(
  212. PyExc_AttributeError,
  213. "Has no attribute named \'_NodeMixin__node\', please "
  214. "set it first");
  215. return nullptr;
  216. }
  217. void TensorWrapper::set_module_trace_info(PyObject* obj) {
  218. // TODO: erase when obj == nullptr
  219. module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj);
  220. }
  221. void TensorWrapper::_set_name(PyObject* dest) {
  222. auto py_dest = py::reinterpret_borrow<py::object>(dest);
  223. auto name = py_dest.cast<std::string>();
  224. m_tensor->set_name(name);
  225. }
  226. PyObject* TensorWrapper::_detail() {
  227. return py::str(m_tensor->data().unwrap().to_string()).release().ptr();
  228. }
  229. void TensorWrapper::_watch() {
  230. m_tensor->data().watch();
  231. }
  232. PyObject* TensorWrapper::shape() {
  233. auto shape = m_tensor->shape();
  234. if (!shape) {
  235. Py_RETURN_NONE;
  236. }
  237. py::tuple ret(shape->ndim);
  238. for (size_t i = 0; i < shape->ndim; ++i) {
  239. ret[i] = shape->at(i);
  240. }
  241. return ret.release().ptr();
  242. }
  243. PyObject* TensorWrapper::dtype() {
  244. return py::cast(m_tensor->dtype()).release().ptr();
  245. }
  246. PyObject* TensorWrapper::device() {
  247. return py::cast(m_tensor->comp_node()).release().ptr();
  248. }
  249. PyObject* TensorWrapper::numpy() {
  250. auto hv = m_tensor->numpy();
  251. if (!hv) {
  252. PyErr_SetString(PyExc_ValueError, "tensor invalid");
  253. return nullptr;
  254. }
  255. auto arr = py::reinterpret_steal<py::array>(
  256. npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
  257. if (hv->shape().is_scalar()) {
  258. mgb_assert(PyArray_Check(arr.ptr()));
  259. return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
  260. }
  261. return arr.release().ptr();
  262. }
  263. void TensorWrapper::reset(PyObject* tensor) {
  264. TensorWrapper* t = TensorWrapper::try_cast(tensor);
  265. if (!t) {
  266. throw py::type_error("expect Tensor");
  267. }
  268. m_tensor->reset(t->m_tensor->data());
  269. }
  270. PyObject* TensorWrapper::detach() {
  271. auto detached = imperative::apply(DetachGrad(), m_tensor->data())[0];
  272. return TensorWrapper::make(py_tensor_type, detached).release().ptr();
  273. }
  274. PyObject* TensorWrapper::_dev_tensor() {
  275. auto dv = m_tensor->data().dev_tensor();
  276. // TODO: handle scalar
  277. return py::cast(dv->as_nd(true)).release().ptr();
  278. }
  279. void TensorWrapper::_drop() {
  280. imperative::apply(DTRCommand(DTRCommand::Drop), m_tensor->data());
  281. }
  282. PyObject* TensorWrapper::isscalar() {
  283. if (m_tensor->is_scalar()) {
  284. Py_RETURN_TRUE;
  285. } else {
  286. Py_RETURN_FALSE;
  287. }
  288. }
  289. struct TensorWeakRef {
  290. std::weak_ptr<Tensor> wptr;
  291. TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {}
  292. py::object operator()() {
  293. if (auto p = wptr.lock()) {
  294. return TensorWrapper::make(py_tensor_type, p);
  295. }
  296. return py::none();
  297. }
  298. int _use_cnt() { return wptr.use_count(); }
  299. };
  /* ============== convert inputs ============== */

  /**
   * \brief Map a numpy dtype kind character to its promotion priority.
   *
   * 'f' (floating-point) -> 3, 'i' / 'u' (signed / unsigned integer) -> 2,
   * 'b' (boolean) -> 1, anything else -> 0.
   */
  inline uint8_t category_priority(char c) {
      if (c == 'f') {
          return 3;
      }
      if (c == 'i' || c == 'u') {
          return 2;
      }
      if (c == 'b') {
          return 1;
      }
      return 0;
  }
  316. // Returns the maximum value of the priority of each type in the list `types`.
  317. uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
  318. if (types.size() == 0) {
  319. return 0;
  320. } else {
  321. uint8_t max_p = 0;
  322. for (auto&& desc : types) {
  323. max_p = std::max(max_p, category_priority(desc->kind));
  324. }
  325. return max_p;
  326. }
  327. }
  328. // Returns the data type with sufficient size to hold all types of
  329. // category `cat` in the list `types`.
  330. PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
  331. // Return value: New reference
  332. SmallVector<PyArray_Descr*> used_types;
  333. for (auto&& desc : types) {
  334. auto&& v = category_priority(desc->kind);
  335. if (v == cat) {
  336. used_types.emplace_back(desc);
  337. }
  338. }
  339. mgb_assert(used_types.size() > 0, "size of used_types is 0");
  340. PyArray_Descr* res = used_types[0];
  341. Py_INCREF(res);
  342. for (size_t i = 1; i < used_types.size(); ++i) {
  343. PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
  344. Py_DECREF(res);
  345. res = tmp;
  346. }
  347. return res;
  348. }
  349. PyArray_Descr* scalar2dtype(PyObject* arg) {
  350. // Return value: New reference
  351. if (PyBool_Check(arg)) {
  352. auto&& descr = PyArray_DescrFromType(NPY_BOOL);
  353. return descr;
  354. }
  355. if (PyLong_CheckExact(arg)) {
  356. auto&& descr = PyArray_DescrFromType(NPY_INT32);
  357. return descr;
  358. }
  359. if (PyFloat_CheckExact(arg)) {
  360. auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
  361. return descr;
  362. }
  363. return nullptr;
  364. }
  365. PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
  366. // Return value: New reference
  367. SmallVector<PyArray_Descr*> tensors;
  368. SmallVector<PyArray_Descr*> scalars;
  369. bool is_tuple = false;
  370. PyObject* tuple = nullptr;
  371. if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
  372. if (PyList_Check(args[0])) {
  373. tuple = PyList_AsTuple(args[0]);
  374. } else {
  375. tuple = args[0];
  376. Py_INCREF(tuple);
  377. }
  378. nargs = PyTuple_Size(tuple);
  379. is_tuple = true;
  380. }
  381. for (size_t i = 0; i < nargs; ++i) {
  382. PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
  383. if (handle == Py_None)
  384. continue;
  385. TensorWrapper* tw = TensorWrapper::try_cast(handle);
  386. if (tw) {
  387. mgb::DType type = tw->m_tensor->dtype();
  388. auto&& descr = npy::dtype_mgb2np_descr(type);
  389. Py_INCREF(descr.get());
  390. tensors.emplace_back(descr.get());
  391. } else {
  392. if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
  393. auto&& descr = PyArray_DescrFromObject(handle, nullptr);
  394. tensors.emplace_back(descr);
  395. continue;
  396. }
  397. if (py::isinstance<PySymbolVar>(py::handle(handle))) {
  398. auto var = py::handle(handle).cast<PySymbolVar*>();
  399. mgb::DType type = var->m_node->dtype();
  400. auto&& descr = npy::dtype_mgb2np_descr(type);
  401. Py_INCREF(descr.get());
  402. tensors.emplace_back(descr.get());
  403. continue;
  404. }
  405. PyArray_Descr* descr = scalar2dtype(handle);
  406. if (descr) {
  407. scalars.emplace_back(descr);
  408. continue;
  409. }
  410. }
  411. }
  412. auto max_pri_scalars = max_priority(scalars);
  413. auto max_pri_tensors = max_priority(tensors);
  414. if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
  415. throw py::value_error("invalid input, no dtype avaliable");
  416. }
  417. PyArray_Descr* res;
  418. if (max_pri_scalars > max_pri_tensors) {
  419. res = promote_types(scalars, max_pri_scalars);
  420. } else {
  421. res = promote_types(tensors, max_pri_tensors);
  422. }
  423. for (auto* p : tensors) {
  424. Py_DECREF(p);
  425. }
  426. for (auto* p : scalars) {
  427. Py_DECREF(p);
  428. }
  429. Py_XDECREF(tuple);
  430. return res;
  431. }
  432. CompNode _get_device(PyObject* const* args, size_t nargs) {
  433. bool is_tuple = false;
  434. PyObject* tuple = nullptr;
  435. if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
  436. if (PyList_Check(args[0])) {
  437. tuple = PyList_AsTuple(args[0]);
  438. } else {
  439. tuple = args[0];
  440. Py_INCREF(tuple);
  441. }
  442. nargs = PyTuple_Size(tuple);
  443. is_tuple = true;
  444. }
  445. bool valid = false;
  446. CompNode cn;
  447. for (size_t i = 0; i < nargs; ++i) {
  448. PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
  449. TensorWrapper* tw = TensorWrapper::try_cast(handle);
  450. bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
  451. if (tw || is_symvar) {
  452. if (!valid) {
  453. cn = tw ? tw->m_tensor->comp_node()
  454. : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
  455. valid = true;
  456. } else {
  457. CompNode cn1 = tw ? tw->m_tensor->comp_node()
  458. : py::handle(handle)
  459. .cast<PySymbolVar*>()
  460. ->m_node->comp_node();
  461. if (cn1 != cn) {
  462. throw py::value_error(ssprintf(
  463. "ambiguous device: %s vs %s", cn.to_string().c_str(),
  464. cn1.to_string().c_str()));
  465. }
  466. }
  467. }
  468. }
  469. if (!valid) {
  470. return CompNode::load(get_default_device());
  471. }
  472. Py_XDECREF(tuple);
  473. return cn;
  474. }
  475. bool is_scalar(PyObject* tensor) {
  476. if (py::isinstance<PySymbolVar>(py::handle(tensor))) {
  477. auto var = py::handle(tensor).cast<PySymbolVar*>();
  478. return var->is_scalar;
  479. }
  480. auto* tw = TensorWrapper::try_cast(tensor);
  481. if (tw) {
  482. return tw->m_tensor->is_scalar();
  483. }
  484. return PyArray_CheckAnyScalar(tensor);
  485. }
  486. bool is_bool_list(PyObject* arg) {
  487. if (!PyList_Check(arg)) {
  488. return false;
  489. }
  490. size_t sz = PyList_Size(arg);
  491. if (!sz) {
  492. return false;
  493. }
  494. for (size_t i = 0; i < sz; ++i) {
  495. PyObject* handle = PyList_GetItem(arg, i);
  496. if (!PyBool_Check(handle)) {
  497. return false;
  498. }
  499. }
  500. return true;
  501. }
  502. bool is_bool_dtype(PyObject* args) {
  503. if (!PyObject_HasAttrString(args, "dtype"))
  504. return false;
  505. PyObject* dobj = PyObject_GetAttrString(args, "dtype");
  506. PyArray_Descr* dtype;
  507. PyArray_DescrConverter(dobj, &dtype);
  508. bool ret = (dtype->kind == 'b');
  509. Py_XDECREF(dtype);
  510. Py_XDECREF(dobj);
  511. return ret;
  512. }
  513. py::object _Const(
  514. py::handle value, py::handle dtype, py::handle device, py::handle ref) {
  515. py::object val = py::reinterpret_borrow<py::object>(value);
  516. if (PyArray_Check(value.ptr())) {
  517. py::tuple strides =
  518. py::reinterpret_borrow<py::tuple>(getattr(value, "strides"));
  519. bool need_squeeze = false;
  520. for (size_t i = 0; i < strides.size(); ++i) {
  521. if (strides[i].cast<ptrdiff_t>() == 0) {
  522. need_squeeze = true;
  523. }
  524. }
  525. if (need_squeeze) {
  526. val = py::reinterpret_borrow<py::array>(value);
  527. val = val.attr("squeeze")();
  528. val = val.attr("reshape")(val.attr("shape"));
  529. }
  530. }
  531. if (py::isinstance<PySymbolVar>(ref)) {
  532. auto ref_var = ref.cast<PySymbolVar*>();
  533. auto* graph = ref_var->m_node->owner_graph();
  534. auto cn = device.cast<CompNode>();
  535. OperatorNodeConfig config(cn);
  536. auto hv = npy::np2tensor(
  537. val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
  538. auto typeobj = ref.get_type();
  539. return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
  540. }
  541. py::tuple tup = py::make_tuple(val, dtype, device, true, false, py::none());
  542. return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
  543. }
  544. py::tuple _make_shape_tuple(py::handle shape) {
  545. py::list orig;
  546. py::list ret(0);
  547. auto solve_one = [&](py::handle val) {
  548. if (TensorWrapper::try_cast(val.ptr()) || py::isinstance<PySymbolVar>(val)) {
  549. py::object np = getattr(val, "numpy")();
  550. PyArrayObject* arr = (PyArrayObject*)np.ptr();
  551. PyObject* maybe_list = PyArray_ToList(arr);
  552. if (PyList_Check(maybe_list)) {
  553. py::list may = py::reinterpret_steal<py::list>(maybe_list);
  554. for (size_t i = 0; i < may.size(); ++i) {
  555. ret.append(may[i]);
  556. }
  557. } else {
  558. mgb_assert(PyLong_Check(maybe_list));
  559. ret.append(PyLong_AsLong(maybe_list));
  560. Py_XDECREF(maybe_list);
  561. }
  562. } else if (PyArray_Check(val.ptr())) {
  563. ret.append(PyArray_PyIntAsInt(val.ptr()));
  564. } else {
  565. ret.append(PyLong_AsLong(val.ptr()));
  566. }
  567. };
  568. if (PyArray_Check(shape.ptr()) && !PyArray_CheckAnyScalar(shape.ptr())) {
  569. orig = py::reinterpret_steal<py::list>(
  570. PyArray_ToList((PyArrayObject*)shape.ptr()));
  571. for (size_t i = 0; i < orig.size(); ++i) {
  572. solve_one(orig[i]);
  573. }
  574. } else if (PyList_Check(shape.ptr())) {
  575. orig = py::reinterpret_borrow<py::list>(shape);
  576. for (size_t i = 0; i < orig.size(); ++i) {
  577. solve_one(orig[i]);
  578. }
  579. } else if (PyTuple_Check(shape.ptr())) {
  580. py::tuple tup = py::reinterpret_borrow<py::tuple>(shape);
  581. for (size_t i = 0; i < tup.size(); ++i) {
  582. solve_one(tup[i]);
  583. }
  584. } else {
  585. solve_one(shape);
  586. }
  587. return py::reinterpret_steal<py::tuple>(PyList_AsTuple(ret.ptr()));
  588. }
  589. py::object _get_index(py::object tensor, py::object src) {
  590. if (!TensorWrapper::try_cast(tensor.ptr()) &&
  591. !py::isinstance<PySymbolVar>(tensor)) {
  592. auto get_const = [&](mgb::DType dtype) -> py::object {
  593. return _Const(tensor, py::cast(dtype), src.attr("device"), src);
  594. };
  595. if (is_bool_list(tensor.ptr()) || is_bool_dtype(tensor.ptr())) {
  596. tensor = get_const(dtype::Bool());
  597. } else {
  598. tensor = get_const(dtype::Int32());
  599. }
  600. if (!is_bool_dtype(tensor.ptr())) {
  601. return tensor;
  602. }
  603. } else {
  604. if (!is_bool_dtype(tensor.ptr())) {
  605. return tensor;
  606. }
  607. }
  608. static std::shared_ptr<OpDef> op = CondTake::make();
  609. std::vector<PyObject*> p;
  610. p.resize(3);
  611. py::object Op = py::cast(op);
  612. p[0] = Op.ptr();
  613. p[1] = tensor.ptr();
  614. p[2] = tensor.ptr();
  615. py::tuple ret =
  616. py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
  617. return ret[1];
  618. }
  619. py::tuple _try_cond_take(py::handle tensor, py::handle index) {
  620. if (!hasattr(index, "dtype") || !hasattr(index, "shape")) {
  621. return py::tuple();
  622. }
  623. if (!is_bool_dtype(index.ptr()) ||
  624. _make_shape_tuple(getattr(index, "shape"))
  625. .not_equal(_make_shape_tuple(getattr(tensor, "shape")))) {
  626. return py::tuple();
  627. }
  628. py::object iobj;
  629. if (PyArray_Check(index.ptr())) {
  630. iobj =
  631. _Const(index, py::cast((mgb::DType)dtype::Bool()),
  632. getattr(tensor, "device"), tensor);
  633. } else {
  634. iobj = py::reinterpret_borrow<py::object>(index);
  635. }
  636. static std::shared_ptr<OpDef> op = CondTake::make();
  637. std::vector<PyObject*> p;
  638. p.resize(3);
  639. py::object Op = py::cast(op);
  640. p[0] = Op.ptr();
  641. p[1] = tensor.ptr();
  642. p[2] = iobj.ptr();
  643. py::tuple ret =
  644. py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
  645. return ret;
  646. }
  647. py::tuple _remove_ellipsis(py::object tensor, py::tuple tuple_val) {
  648. size_t tuple_size = tuple_val.size();
  649. size_t ndim_sum = 0, cur_sum = 0;
  650. int pos = -1;
  651. bool has_unknown_ndim_bool_index = false;
  652. for (size_t i = 0; i < tuple_size; ++i) {
  653. py::object handle = tuple_val[i];
  654. if (handle.ptr() == Py_Ellipsis) {
  655. pos = static_cast<int>(i);
  656. for (size_t j = 0; j < i; ++j) {
  657. py::object t = tuple_val[j];
  658. if (t.ptr() == Py_Ellipsis) {
  659. throw py::index_error("only one ellipsis is allowed.");
  660. }
  661. }
  662. } else {
  663. size_t ndim_incr = 1;
  664. if (hasattr(handle, "dtype") && is_bool_dtype(handle.ptr()) &&
  665. hasattr(handle, "ndim")) {
  666. py::object ndim = getattr(handle, "ndim");
  667. if (PyLong_Check(ndim.ptr())) {
  668. ndim_incr = PyLong_AsLong(ndim.ptr());
  669. } else {
  670. has_unknown_ndim_bool_index = true;
  671. }
  672. }
  673. cur_sum += ndim_incr;
  674. }
  675. }
  676. if (pos == -1) {
  677. return tuple_val;
  678. } else {
  679. if (has_unknown_ndim_bool_index) {
  680. throw py::index_error(
  681. "does not support bool index with unknown shape when using "
  682. "Ellipsis.");
  683. }
  684. try {
  685. ndim_sum = getattr(tensor, "ndim").cast<size_t>();
  686. } catch (py::error_already_set& err) {
  687. throw py::index_error(
  688. "does not support Ellipsis when tensor's ndim is unknown.");
  689. }
  690. py::tuple ret(ndim_sum - cur_sum + tuple_size - 1);
  691. size_t idx = 0;
  692. for (size_t i = 0; i < tuple_size; ++i) {
  693. if (i == pos) {
  694. for (size_t j = cur_sum; j < ndim_sum; ++j) {
  695. ret[idx++] = PySlice_New(NULL, NULL, NULL);
  696. }
  697. } else {
  698. ret[idx++] = tuple_val[i];
  699. }
  700. }
  701. return ret;
  702. }
  703. }
  // Prepares `tensor` for indexing with multi-dimensional boolean masks.
  //
  // For every entry of `tuple_val` that has a bool dtype and ndim > 1, the
  // mask's shape is checked against the matching tensor dimensions and the
  // tensor is reshaped so that those dimensions are merged into one of the
  // flattened mask's length. The index entries themselves are passed through
  // unchanged.
  //
  // Returns a 2-tuple (possibly reshaped tensor, index entries).
  // Throws py::index_error when a mask dimension differs from the tensor's.
  704. py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
  705. py::tuple cur_shape = _make_shape_tuple(py::handle(getattr(tensor, "shape")));
  706. py::list new_tuple_val(0);
  // offset: incremented once per expanded mask; tdim: running dimension
  // counter. cur_shape is indexed with `tdim + j - offset` below.
  707. size_t offset = 0;
  708. size_t tdim = 0;
  709. for (size_t i = 0; i < tuple_val.size(); ++i) {
  710. py::handle k = tuple_val[i];
  711. if (is_bool_dtype(k.ptr())) {
  712. size_t ndim = getattr(k, "ndim").cast<size_t>();
  713. if (ndim > 1) {
  // Every mask dimension has to equal the tensor dimension it covers.
  714. py::tuple ishape = _make_shape_tuple(py::handle(getattr(k, "shape")));
  715. for (size_t j = 0; j < ndim; ++j) {
  716. if (cur_shape[tdim + j - offset].cast<size_t>() !=
  717. ishape[j].cast<size_t>()) {
  718. std::string msg =
  719. "boolean index did not match tensor along dimension " +
  720. std::to_string(tdim + j) + "; dimension is " +
  721. std::to_string(
  722. cur_shape[tdim + j - offset].cast<size_t>()) +
  723. " but corresponding boolean dimension is " +
  724. std::to_string(ishape[j].cast<size_t>());
  725. throw py::index_error(msg.c_str());
  726. }
  727. }
  // The flattened mask is only used to obtain its length (kshape[0]); the
  // original mask `k` is what gets appended to the result further down.
  728. py::object new_k = getattr(k, "reshape")(-1);
  729. py::object kshape = getattr(new_k, "shape");
  730. py::list new_shape(0);
  // Ask the Python-side hook whether symbolic shapes are in use; only an
  // exact Py_True result selects the symbolic branch.
  731. PyObject* sym = PyObject_CallObject(cpp_use_symbolic_shape, nullptr);
  732. bool is_sym = (sym == Py_True);
  733. Py_XDECREF(sym);
  734. if (is_sym) {
  // Symbolic branch: the leading dimensions are taken by indexing
  // tensor.shape, and the new shape is turned into a tensor through the
  // cpp_astensor1d hook before calling tensor.reshape.
  735. py::object tshape = getattr(tensor, "shape");
  736. for (size_t j = 0; j < i; ++j) {
  737. new_shape.append(tshape[py::int_(j)]);
  738. }
  739. new_shape.append(kshape[py::int_(0)]);
  740. for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) {
  741. new_shape.append(cur_shape[j]);
  742. }
  743. py::tuple args = py::make_tuple(new_shape);
  744. PyObject* shape_tensor =
  745. PyObject_CallObject(cpp_astensor1d, args.ptr());
  746. py::object reshape_func = getattr(tensor, "reshape");
  // PyTuple_SetItem steals a reference, so one is added first to keep
  // shape_tensor usable after Args is released.
  747. Py_INCREF(shape_tensor);
  748. PyObject* Args = PyTuple_New(1);
  749. PyTuple_SetItem(Args, 0, shape_tensor);
  750. PyObject* new_tensor =
  751. PyObject_CallObject(reshape_func.ptr(), Args);
  752. Py_XDECREF(Args);
  753. tensor = py::reinterpret_steal<py::object>(new_tensor);
  754. cur_shape = _make_shape_tuple(py::handle(shape_tensor));
  755. Py_XDECREF(shape_tensor);
  756. } else {
  // Non-symbolic branch: build the new shape from cur_shape and reshape.
  757. for (size_t j = 0; j < i; ++j) {
  758. new_shape.append(cur_shape[j]);
  759. }
  760. new_shape.append(py::reinterpret_borrow<py::tuple>(kshape)[0]);
  761. for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) {
  762. new_shape.append(cur_shape[j]);
  763. }
  764. cur_shape = new_shape;
  765. tensor = getattr(tensor, "reshape")(cur_shape);
  766. }
  767. offset++;
  768. tdim += ndim;
  769. }
  // NOTE(review): a bool index with ndim <= 1 is appended here without
  // advancing tdim - confirm this is intended.
  770. new_tuple_val.append(k);
  771. } else {
  772. new_tuple_val.append(k);
  773. tdim++;
  774. }
  775. }
  // NOTE(review): new_tuple_val is a py::list and reinterpret_borrow performs
  // no conversion, so the second element is presumably still a list object -
  // verify that callers convert it before using tuple APIs.
  776. return py::make_tuple(tensor, py::reinterpret_borrow<py::tuple>(new_tuple_val));
  777. }
  778. py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
  779. py::object inp = py::reinterpret_borrow<py::object>(inp_hdl);
  780. py::tuple tuple_val;
  781. if (py::isinstance<py::tuple>(idx_hdl)) {
  782. tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
  783. } else {
  784. tuple_val = py::make_tuple(idx_hdl);
  785. }
  786. bool use_subtensor = true;
  787. bool need_remove_ellipsis = false;
  788. bool need_expand_bool_dim = false;
  789. size_t idx_ndim = 0;
  790. for (size_t i = 0; i < tuple_val.size(); ++i) {
  791. py::object k = tuple_val[i];
  792. if (k.ptr() == Py_None) {
  793. throw py::index_error("newaxis is not allowed here");
  794. } else if (k.ptr() == Py_Ellipsis) {
  795. need_remove_ellipsis = true;
  796. } else {
  797. if (is_bool_dtype(k.ptr()) && hasattr(k, "ndim")) {
  798. size_t ndim = getattr(k, "ndim").cast<size_t>();
  799. idx_ndim += ndim;
  800. if (ndim > 1) {
  801. need_expand_bool_dim = true;
  802. }
  803. } else {
  804. idx_ndim++;
  805. }
  806. }
  807. }
  808. try {
  809. size_t inp_ndim = getattr(inp, "ndim").cast<size_t>();
  810. if (idx_ndim > inp_ndim) {
  811. std::string msg = "too many indices for tensor: tensor is " +
  812. std::to_string(inp_ndim) + "-dimensional, but " +
  813. std::to_string(idx_ndim) + " were indexed";
  814. throw py::index_error(msg.c_str());
  815. }
  816. } catch (py::error_already_set& err) {
  817. ; // ignore
  818. }
  819. if (need_remove_ellipsis) {
  820. tuple_val = _remove_ellipsis(inp, tuple_val);
  821. }
  822. if (need_expand_bool_dim) {
  823. py::object shape = getattr(inp, "shape");
  824. if (shape.ptr() != Py_None) {
  825. py::tuple ret = _expand_bool_dim(inp, tuple_val);
  826. inp = ret[0];
  827. tuple_val = ret[1];
  828. }
  829. }
  830. py::list items;
  831. py::list tensors;
  832. int cur_axis = -1;
  833. for (size_t i = 0; i < tuple_val.size(); ++i) {
  834. py::object handle = tuple_val[i];
  835. cur_axis++;
  836. if (!is_scalar(handle.ptr()) && !PySlice_Check(handle.ptr())) {
  837. use_subtensor = false;
  838. }
  839. py::list item;
  840. item.append(cur_axis);
  841. auto push = [&](PyObject* v) {
  842. if (v == Py_None) {
  843. item.append(false);
  844. } else {
  845. item.append(true);
  846. tensors.append(_get_index(py::reinterpret_borrow<py::object>(v), inp));
  847. }
  848. };
  849. if (PySlice_Check(handle.ptr())) {
  850. PySliceObject* s = (PySliceObject*)handle.ptr();
  851. if (s->start == Py_None && s->stop == Py_None && s->step == Py_None) {
  852. continue;
  853. }
  854. push(s->start);
  855. push(s->stop);
  856. push(s->step);
  857. item.append(false);
  858. } else {
  859. for (size_t j = 0; j < 3; j++)
  860. item.append(false);
  861. push(handle.ptr());
  862. }
  863. items.append(item);
  864. }
  865. return py::make_tuple(inp, tensors, items, use_subtensor, need_expand_bool_dim);
  866. }
  867. py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
  868. py::tuple try_res = _try_cond_take(inp_hdl, idx_hdl);
  869. if (try_res.size() == 2) {
  870. return try_res[0];
  871. }
  872. py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
  873. py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
  874. py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
  875. py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
  876. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
  877. for (size_t i = 0; i < py_items.size(); ++i) {
  878. py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
  879. cpp_items.push_back(
  880. {item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
  881. item[3].cast<bool>(), item[4].cast<bool>()});
  882. }
  883. static std::shared_ptr<OpDef> op;
  884. if (up[3].cast<bool>()) {
  885. op = Subtensor::make(cpp_items);
  886. } else {
  887. op = IndexingMultiAxisVec::make(cpp_items);
  888. }
  889. std::vector<PyObject*> p;
  890. p.resize(tensors.size() + 2);
  891. py::object Op = py::cast(op);
  892. p[0] = Op.ptr();
  893. p[1] = tensor.ptr();
  894. for (size_t i = 0; i < tensors.size(); ++i) {
  895. p[i + 2] = tensors[i].ptr();
  896. }
  897. py::tuple ret =
  898. py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
  899. return ret[0];
  900. }
  901. py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) {
  902. py::object org_shape = getattr(inp_hdl, "shape");
  903. py::object val = py::reinterpret_borrow<py::object>(val_hdl);
  904. if (!TensorWrapper::try_cast(val.ptr()) && !py::isinstance<PySymbolVar>(val)) {
  905. val =
  906. _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"),
  907. inp_hdl);
  908. }
  909. py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
  910. py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
  911. py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
  912. py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
  913. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
  914. for (size_t i = 0; i < py_items.size(); ++i) {
  915. py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
  916. cpp_items.push_back(
  917. {item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
  918. item[3].cast<bool>(), item[4].cast<bool>()});
  919. }
  920. static std::shared_ptr<OpDef> op, set_op;
  921. if (up[3].cast<bool>()) {
  922. op = Subtensor::make(cpp_items);
  923. } else {
  924. op = IndexingMultiAxisVec::make(cpp_items);
  925. }
  926. std::vector<PyObject*> p;
  927. p.resize(tensors.size() + 2);
  928. py::object Op = py::cast(op);
  929. p[0] = Op.ptr();
  930. p[1] = tensor.ptr();
  931. for (size_t i = 0; i < tensors.size(); ++i) {
  932. p[i + 2] = tensors[i].ptr();
  933. }
  934. py::tuple ret =
  935. py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
  936. py::object tmp_result = ret[0];
  937. try {
  938. py::object value_tuple_shape = val.attr("_tuple_shape");
  939. py::object tmp_result_tuple_shape = tmp_result.attr("_tuple_shape");
  940. py::tuple value_shape = py::reinterpret_borrow<py::tuple>(value_tuple_shape);
  941. py::tuple tmp_result_shape =
  942. py::reinterpret_borrow<py::tuple>(tmp_result_tuple_shape);
  943. for (size_t i = 0; i < value_shape.size() && i < tmp_result_shape.size(); ++i) {
  944. size_t vs = value_shape[value_shape.size() - i - 1].cast<size_t>();
  945. size_t ts =
  946. tmp_result_shape[tmp_result_shape.size() - i - 1].cast<size_t>();
  947. if (vs != 1 && vs != ts) {
  948. std::string lhs = "", rhs = "";
  949. for (size_t j = 0; j < tmp_result_shape.size(); ++j) {
  950. lhs += std::to_string(tmp_result_shape[j].cast<size_t>());
  951. if (j)
  952. lhs += ",";
  953. }
  954. for (size_t j = 0; j < value_shape.size(); ++j) {
  955. rhs += std::to_string(value_shape[j].cast<size_t>());
  956. if (j)
  957. rhs += ",";
  958. }
  959. throw py::value_error(
  960. "cannot copy tensor with shape (" + rhs +
  961. ") to subtensor with shape (" + lhs + ")");
  962. }
  963. }
  964. } catch (py::error_already_set& err) {
  965. ;
  966. }
  967. py::object broadcast_func = getattr(val, "_broadcast");
  968. PyObject* Args = PyTuple_New(1);
  969. PyTuple_SetItem(Args, 0, getattr(tmp_result, "shape").release().ptr());
  970. PyObject* new_val = PyObject_CallObject(broadcast_func.ptr(), Args);
  971. Py_XDECREF(Args);
  972. val = py::reinterpret_steal<py::object>(new_val);
  973. if (up[3].cast<bool>()) {
  974. set_op = SetSubtensor::make(cpp_items);
  975. } else {
  976. set_op = IndexingSetMultiAxisVec::make(cpp_items);
  977. }
  978. std::vector<PyObject*> q;
  979. q.resize(tensors.size() + 3);
  980. py::object Set_Op = py::cast(set_op);
  981. q[0] = Set_Op.ptr();
  982. q[1] = tensor.ptr();
  983. q[2] = val.ptr();
  984. for (size_t i = 0; i < tensors.size(); ++i) {
  985. q[i + 3] = tensors[i].ptr();
  986. }
  987. py::tuple result =
  988. py::reinterpret_steal<py::object>(py_apply(NULL, q.data(), q.size()));
  989. py::object res = result[0];
  990. if (up[4].cast<bool>()) {
  991. py::object reshape_func = getattr(res, "reshape");
  992. PyObject* Args = PyTuple_New(1);
  993. PyTuple_SetItem(Args, 0, org_shape.release().ptr());
  994. PyObject* new_tensor = PyObject_CallObject(reshape_func.ptr(), Args);
  995. Py_XDECREF(Args);
  996. res = py::reinterpret_steal<py::object>(new_tensor);
  997. }
  998. return res;
  999. }
  1000. // Returns the dtype that would result from performing an arithmetic
  1001. // operation on the provided input tensors and scalars.
  1002. PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
  1003. if (!nargs) {
  1004. PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
  1005. return nullptr;
  1006. }
  1007. try {
  1008. PyArray_Descr* res = _dtype_promotion(args, nargs);
  1009. return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
  1010. }
  1011. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1012. }
  1013. PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
  1014. if (!nargs) {
  1015. PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
  1016. return nullptr;
  1017. }
  1018. try {
  1019. CompNode cn = _get_device(args, nargs);
  1020. return py::cast(cn).release().ptr();
  1021. }
  1022. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1023. }
  1024. PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
  1025. try {
  1026. return _make_shape_tuple(py::handle(args[0])).release().ptr();
  1027. }
  1028. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1029. }
  1030. PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1031. try {
  1032. return _getitem_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
  1033. }
  1034. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1035. }
  1036. PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1037. try {
  1038. return _setitem_cpp(
  1039. py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
  1040. .release()
  1041. .ptr();
  1042. }
  1043. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1044. }
  1045. #ifdef METH_FASTCALL
  1046. #define MGE_PY_INTERFACE(NAME, FUNC) \
  1047. { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
  1048. #else
  1049. #define WRAP_FUNC_PY35(FUNC) \
  1050. PyObject* py35_##FUNC(PyObject* self, PyObject* args) { \
  1051. auto* arr = &PyTuple_GET_ITEM(args, 0); \
  1052. auto size = PyTuple_GET_SIZE(args); \
  1053. return FUNC(self, arr, size); \
  1054. }
  1055. WRAP_FUNC_PY35(py_apply);
  1056. WRAP_FUNC_PY35(dtype_promotion);
  1057. WRAP_FUNC_PY35(get_device);
  1058. WRAP_FUNC_PY35(make_shape_tuple);
  1059. WRAP_FUNC_PY35(getitem_cpp);
  1060. WRAP_FUNC_PY35(setitem_cpp);
  1061. #undef WRAP_FUNC_PY35
  1062. #define MGE_PY_INTERFACE(NAME, FUNC) \
  1063. { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
  1064. #endif
  1065. void init_tensor(py::module m) {
  1066. imperative::Tensor::static_initialize();
  1067. static auto& transformations = TransformationManager::get_instance();
  1068. using Segment = TransformationManager::Segment;
  1069. auto* channel = interpreter::Interpreter::inst().create_channel().release();
  1070. interpreter_for_py = channel;
  1071. transformations.register_at<Segment::Eval>(
  1072. std::make_shared<InterpreterTransformation>(
  1073. std::unique_ptr<interpreter::Interpreter::Channel>(channel)));
  1074. transformations.register_at<Segment::Scalar>(
  1075. std::make_shared<ScalarTransformation>());
  1076. static py::exception<interpreter::AsyncError> py_async_error(
  1077. m, "AsyncError", PyExc_RuntimeError);
  1078. py::register_exception_translator([](std::exception_ptr p) {
  1079. try {
  1080. if (p)
  1081. std::rethrow_exception(p);
  1082. } catch (const interpreter::AsyncError& e) {
  1083. pyext17::pybind11_translate_exception(e.nested_ptr());
  1084. if (PyErr_Occurred()) {
  1085. PyObject *exc, *val, *tb;
  1086. PyErr_Fetch(&exc, &val, &tb);
  1087. PyErr_NormalizeException(&exc, &val, &tb);
  1088. if (tb) {
  1089. PyException_SetTraceback(val, tb);
  1090. }
  1091. auto val2 = py_async_error.py::object::operator()(
  1092. "An async error is reported. See above for the actual cause."
  1093. " Hint: This is where it is reported, not where it happened."
  1094. " You may call `megengine.config.async_level = 0 "
  1095. "to get better error reporting.");
  1096. PyException_SetCause(
  1097. val2.ptr(), val); // PyException_SetCause steals reference
  1098. Py_XDECREF(exc);
  1099. Py_XDECREF(tb);
  1100. PyErr_Restore(
  1101. py_async_error.inc_ref().ptr(), val2.release().ptr(), nullptr);
  1102. } else {
  1103. py_async_error("Unkown async error");
  1104. }
  1105. }
  1106. });
  1107. auto* tensor_type =
  1108. TensorWrapper::wrap_t::type()
  1109. .def<&TensorWrapper::numpy>("numpy")
  1110. .def_getset<&TensorWrapper::shape>("shape")
  1111. .def_getset<&TensorWrapper::dtype>("dtype")
  1112. .def_getset<&TensorWrapper::device>("device")
  1113. .def<&TensorWrapper::reset>("_reset")
  1114. .def<&TensorWrapper::isscalar>("_isscalar")
  1115. .def<&TensorWrapper::detach>("detach")
  1116. // TODO: remove this
  1117. .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
  1118. .def<&TensorWrapper::_drop>("_drop")
  1119. .def<&TensorWrapper::_use_cnt>("_use_cnt")
  1120. .def<&TensorWrapper::_detail>("_detail")
  1121. .def<&TensorWrapper::_set_name>("_set_name")
  1122. .def<&TensorWrapper::_watch>("_watch")
  1123. .def_getset<
  1124. &TensorWrapper::module_trace_info,
  1125. &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
  1126. .finalize();
  1127. if (!tensor_type)
  1128. throw py::error_already_set();
  1129. py::setattr(m, "Tensor", tensor_type);
  1130. py::class_<TensorWeakRef>(m, "TensorWeakRef")
  1131. .def(py::init<const TensorWrapper&>())
  1132. .def("__call__", &TensorWeakRef::operator())
  1133. .def("_use_cnt", &TensorWeakRef::_use_cnt);
  1134. py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
  1135. .def_property_readonly(
  1136. "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
  1137. .def_property(
  1138. "var", [](PySymbolVar* v) { return v->m_node; },
  1139. [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
  1140. .def_property_readonly(
  1141. "device", [](PySymbolVar* v) { return v->m_node->comp_node(); })
  1142. .def_property_readonly(
  1143. "graph", [](PySymbolVar* v) { return v->m_node->owner_graph(); })
  1144. .def_property_readonly(
  1145. "shape",
  1146. [](PySymbolVar* v) -> const TensorShape* {
  1147. auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
  1148. return mgr.infer_shape_fallible(v->m_node);
  1149. })
  1150. .def("numpy",
  1151. [](PySymbolVar* v) {
  1152. auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
  1153. auto&& type = mgr.get_infer_type(v->m_node);
  1154. using InferType = cg::static_infer::InferType;
  1155. if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
  1156. throw py::value_error("value invalid!");
  1157. }
  1158. auto* val = mgr.infer_value_fallible(v->m_node);
  1159. if (!val) {
  1160. throw py::value_error("value invalid!");
  1161. }
  1162. auto np_val = py::cast(*val).attr("numpy")();
  1163. return np_val;
  1164. })
  1165. .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
  1166. .def(py::init([](cg::VarNode* node) {
  1167. return std::make_shared<PySymbolVar>(node);
  1168. }),
  1169. py::arg() = nullptr);
  1170. static PyMethodDef method_defs[] = {
  1171. MGE_PY_INTERFACE(apply, py_apply),
  1172. MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
  1173. MGE_PY_INTERFACE(get_device, get_device),
  1174. MGE_PY_INTERFACE(make_shape_tuple, make_shape_tuple),
  1175. MGE_PY_INTERFACE(getitem_cpp, getitem_cpp),
  1176. MGE_PY_INTERFACE(setitem_cpp, setitem_cpp),
  1177. {nullptr, nullptr, 0, nullptr}};
  1178. for (auto&& def : method_defs) {
  1179. if (def.ml_meth != nullptr) {
  1180. auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
  1181. if (!func)
  1182. throw py::error_already_set();
  1183. py::setattr(m, def.ml_name, func);
  1184. }
  1185. }
  1186. static constexpr auto sync_py_task_q = [] {
  1187. py::gil_scoped_release _;
  1188. py_task_q.wait_all_task_finish();
  1189. };
  1190. m.def("clear_candidates", [channel]() { channel->clear_candidates(); });
  1191. m.def("set_option", [channel](std::string name, size_t value) {
  1192. channel->set_option(name, value);
  1193. });
  1194. m.def("get_option",
  1195. [channel](std::string name) { return channel->get_option(name); });
  1196. m.def("push_scope", [channel](std::string name) {
  1197. Transformation::push_scope(name);
  1198. channel->push_scope(name);
  1199. });
  1200. m.def("pop_scope", [channel](std::string name) {
  1201. channel->pop_scope(name);
  1202. Transformation::pop_scope(name);
  1203. });
  1204. m.def("start_profile", [channel](imperative::Profiler::options_t options) {
  1205. channel->sync();
  1206. imperative::Profiler::load_options(std::move(options));
  1207. imperative::Profiler::start_profile();
  1208. channel->start_profile();
  1209. });
  1210. m.def("stop_profile", [channel]() -> std::function<void(std::string, std::string)> {
  1211. channel->stop_profile();
  1212. channel->sync();
  1213. imperative::Profiler::stop_profile();
  1214. auto results = std::make_shared<imperative::Profiler::bundle_t>(
  1215. imperative::Profiler::collect());
  1216. return [results = results](std::string basename, std::string format) mutable {
  1217. imperative::Profiler::dump_profile(basename, format, std::move(*results));
  1218. results = nullptr;
  1219. };
  1220. });
  1221. m.def("sync", [channel]() {
  1222. if (channel->check_available()) {
  1223. channel->sync();
  1224. }
  1225. sync_py_task_q();
  1226. });
  1227. m.def("full_sync", [channel]() {
  1228. if (channel->check_available()) {
  1229. channel->sync();
  1230. }
  1231. CompNode::sync_all();
  1232. CompNode::foreach ([](CompNode cn) {
  1233. auto err = cn.check_async_error();
  1234. mgb_assert(!err, "%s", err->what());
  1235. });
  1236. sync_py_task_q();
  1237. });
  1238. m.def("close", [channel]() {
  1239. channel->close();
  1240. sync_py_task_q();
  1241. });
  1242. py::handle grad_key_type =
  1243. GradKeyWrapper::wrap_t::type()
  1244. .def<&GradKeyWrapper::attach>("attach")
  1245. .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
  1246. .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>(
  1247. "name")
  1248. .def<&GradKeyWrapper::enter>("enter")
  1249. .def<&GradKeyWrapper::exit>("exit")
  1250. .def<&GradKeyWrapper::suppress>("suppress")
  1251. .def<&GradKeyWrapper::resume>("resume")
  1252. .finalize();
  1253. if (!grad_key_type)
  1254. throw py::error_already_set();
  1255. py::setattr(m, "GradKey", grad_key_type);
  1256. m.def("backward", &GradKeyWrapper::backward);
  1257. m.def("get_backward_closure", &GradKeyWrapper::get_backward_closure);
  1258. m.def("set_py_tensor_type", [](py::object type_obj) {
  1259. py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
  1260. });
  1261. /**
  1262. * \brief trace proxy
  1263. *
  1264. */
  1265. struct Trace {
  1266. bool symbolic = false;
  1267. bool no_exec = false;
  1268. bool capture_as_const = false;
  1269. bool profile = false;
  1270. bool record_input_shapes = false;
  1271. py::function options_visitor;
  1272. std::shared_ptr<TracingTransformation> tracing;
  1273. std::shared_ptr<CompiledTransformation> compiled;
  1274. std::shared_ptr<LazyEvalTransformation> lazy_eval;
  1275. std::pair<size_t, std::shared_ptr<GraphProfiler>> profiler;
  1276. std::optional<TraceResult> trace_result;
  1277. std::function<bool(py::object, py::object)> array_comparator;
  1278. bool compare_value(ValueRef lhs, ValueRef rhs) {
  1279. if (!lhs.shape()->eq(*rhs.shape())) {
  1280. return false;
  1281. }
  1282. HostTensorND lvalue = lhs.numpy()->as_nd(true);
  1283. HostTensorND rvalue = rhs.numpy()->as_nd(true);
  1284. auto larr = py::reinterpret_steal<py::array>(
  1285. npy::ndarray_from_tensor(lvalue, npy::ShareType::TRY_SHARE));
  1286. auto rarr = py::reinterpret_steal<py::array>(
  1287. npy::ndarray_from_tensor(rvalue, npy::ShareType::TRY_SHARE));
  1288. return array_comparator(larr, rarr);
  1289. }
  1290. void enter() {
  1291. auto& self = *this;
  1292. if (!self.trace_result) { // untraced
  1293. self.tracing = std::make_shared<TracingTransformation>(
  1294. self.capture_as_const, self.record_input_shapes);
  1295. if (self.symbolic) {
  1296. self.lazy_eval =
  1297. std::make_shared<LazyEvalTransformation>(self.no_exec);
  1298. self.options_visitor(py::cast(&self.lazy_eval->options()));
  1299. }
  1300. } else if (!self.compiled) { // traced but not compiled
  1301. using namespace std::placeholders;
  1302. self.compiled = std::make_shared<CompiledTransformation>(
  1303. *self.trace_result, self.record_input_shapes);
  1304. self.compiled->set_value_comparator(
  1305. std::bind(&Trace::compare_value, this, _1, _2));
  1306. self.options_visitor(py::cast(&self.compiled->options()));
  1307. self.compiled->compile();
  1308. }
  1309. // register transformations
  1310. if (self.compiled) {
  1311. if (self.profile) {
  1312. auto& current_graph = self.compiled->graph();
  1313. if (self.profiler.first != self.compiled->graph().id()) {
  1314. // graph changed
  1315. self.profiler = std::make_pair(
  1316. current_graph.id(),
  1317. std::make_shared<GraphProfiler>(&current_graph));
  1318. }
  1319. }
  1320. transformations.register_at<Segment::Trace>(self.compiled);
  1321. // start execute because InputCallback depends
  1322. self.compiled->execute();
  1323. } else if (self.tracing) {
  1324. transformations.register_at<Segment::Trace>(self.tracing);
  1325. if (self.lazy_eval) {
  1326. transformations.register_at<Segment::Eval>(self.lazy_eval);
  1327. }
  1328. } else {
  1329. mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
  1330. }
  1331. }
  1332. void exit() {
  1333. auto& self = *this;
  1334. if (self.tracing) {
  1335. transformations.unregister<Segment::Trace>(self.tracing);
  1336. self.trace_result = self.tracing->get_result();
  1337. self.tracing.reset();
  1338. if (self.lazy_eval) {
  1339. auto lazy_eval = std::move(self.lazy_eval);
  1340. transformations.unregister<Segment::Eval>(lazy_eval);
  1341. lazy_eval->check_exception();
  1342. }
  1343. } else if (self.compiled) {
  1344. transformations.unregister<Segment::Trace>(self.compiled);
  1345. self.compiled->wait();
  1346. } else {
  1347. mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
  1348. }
  1349. }
  1350. VarNodeArray dump(
  1351. std::shared_ptr<ComputingGraph> graph,
  1352. std::vector<std::tuple<std::string, std::string, TensorShape>> inputs,
  1353. std::vector<std::pair<std::string, std::string>> outputs,
  1354. bool prefer_input_names) {
  1355. auto& self = *this;
  1356. mgb_assert(self.trace_result);
  1357. // mark is like "arg_0", "kwarg_xxx", "output_0" ...
  1358. std::unordered_map<std::string, size_t> mark2var;
  1359. for (size_t i = 0; i < self.trace_result->vars.size(); ++i) {
  1360. auto& name = self.trace_result->vars[i].mark;
  1361. if (!name.empty()) {
  1362. mark2var[name] = i;
  1363. }
  1364. }
  1365. std::vector<std::tuple<size_t, std::string, TensorShape>> input_vars;
  1366. std::vector<std::pair<size_t, std::string>> output_vars;
  1367. for (auto&& [input_mark, input_name, input_shape] : inputs) {
  1368. mgb_assert(input_shape.ndim, "input shape invalid");
  1369. input_vars.push_back(
  1370. {mark2var.at(input_mark), input_name, input_shape});
  1371. }
  1372. for (auto&& [output_name, repr] : outputs) {
  1373. output_vars.push_back({mark2var.at(output_name), repr});
  1374. }
  1375. self.options_visitor(py::cast(&graph->options()));
  1376. auto vars = self.trace_result->dump(
  1377. *graph, input_vars, output_vars, prefer_input_names);
  1378. return vars;
  1379. }
  1380. };
  1381. py::class_<Trace>(m, "Trace")
  1382. .def(py::init<>())
  1383. .def_readwrite("record_input_shapes", &Trace::record_input_shapes)
  1384. .def_readwrite("array_comparator", &Trace::array_comparator)
  1385. .def_readwrite("profile", &Trace::profile)
  1386. .def_property_readonly(
  1387. "options",
  1388. [](Trace& self) {
  1389. if (self.compiled) {
  1390. return &self.compiled->options();
  1391. } else {
  1392. return (ComputingGraph::Options*)nullptr;
  1393. }
  1394. })
  1395. .def("get_profile",
  1396. [](Trace& self) -> py::object {
  1397. if (self.profiler.second && self.compiled) {
  1398. auto json = self.profiler.second->to_json_full(
  1399. self.compiled->graph().current_comp_seq());
  1400. return py::str(json->to_string());
  1401. } else {
  1402. return py::none();
  1403. }
  1404. })
  1405. .def_readwrite("symbolic", &Trace::symbolic)
  1406. .def_readwrite("capture_as_const", &Trace::capture_as_const)
  1407. .def_readwrite("no_exec", &Trace::no_exec)
  1408. .def_readwrite("options_visitor", &Trace::options_visitor)
  1409. .def("enter", &Trace::enter)
  1410. .def("exit", &Trace::exit)
  1411. .def("dump", &Trace::dump)
  1412. .def("begin_excluded_region",
  1413. [](Trace& self) {
  1414. mgb_assert(bool(self.tracing) ^ bool(self.compiled));
  1415. if (self.tracing) {
  1416. transformations.unregister<Segment::Trace>(self.tracing);
  1417. } else if (self.compiled) {
  1418. transformations.unregister<Segment::Trace>(self.compiled);
  1419. }
  1420. })
  1421. .def("end_excluded_region", [](Trace& self) {
  1422. mgb_assert(bool(self.tracing) ^ bool(self.compiled));
  1423. if (self.tracing) {
  1424. transformations.register_at<Segment::Trace>(self.tracing);
  1425. } else if (self.compiled) {
  1426. transformations.register_at<Segment::Trace>(self.compiled);
  1427. }
  1428. });
  1429. m.def("name_tensor", [](std::string name, py::object tensor) {
  1430. auto* tw = TensorWrapper::try_cast(tensor.ptr());
  1431. auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
  1432. tw->m_tensor->reset(output);
  1433. });
  1434. m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool {
  1435. SmallVector<ValueRef> values;
  1436. for (auto&& tensor : tensors) {
  1437. values.push_back(tensor.cast<TensorWrapper>().m_tensor->data());
  1438. }
  1439. auto outputs = imperative::apply(GetGradKey(), values);
  1440. if (outputs[0].is<GradKeyValue>()) {
  1441. return true;
  1442. } else {
  1443. return false;
  1444. }
  1445. });
  1446. m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object {
  1447. SmallVector<ValueRef> values;
  1448. for (auto&& tensor : tensors) {
  1449. values.push_back(tensor.cast<TensorWrapper>().m_tensor->data());
  1450. }
  1451. auto outputs = imperative::apply(GetGradKey(), values);
  1452. if (auto* grad_key_val = outputs[0].as<GradKeyValue>()) {
  1453. return py::reinterpret_borrow<py::object>(
  1454. GradKeyWrapper::wrap_t::pycast(GradKeyWrapper::get(*grad_key_val)));
  1455. } else {
  1456. return py::none();
  1457. }
  1458. });
  1459. m.def("set_grad", [](py::object py_key, py::function backward_fn,
  1460. std::vector<py::object> inputs,
  1461. std::vector<py::object> outputs) {
  1462. mgb_assert(GradKeyWrapper::wrap_t::type().isinstance(py_key.ptr()));
  1463. auto* key = reinterpret_cast<GradKeyWrapper::wrap_t*>(py_key.ptr())->inst();
  1464. GenericFunction generic_backward_fn =
  1465. [backward_fn](Span<ValueRef> output_grads) -> std::vector<ValueRef> {
  1466. py::list output_grad_tws;
  1467. for (auto&& output_grad : output_grads) {
  1468. if (output_grad) {
  1469. output_grad_tws.append(
  1470. TensorWrapper::make(py_tensor_type, output_grad));
  1471. } else {
  1472. output_grad_tws.append(py::none());
  1473. }
  1474. }
  1475. py::tuple input_grad_tws = backward_fn(*output_grad_tws);
  1476. std::vector<ValueRef> input_grads;
  1477. for (auto&& input_grad_tw : input_grad_tws) {
  1478. if (!input_grad_tw.is_none()) {
  1479. input_grads.push_back(
  1480. py::cast<TensorWrapper>(input_grad_tw).m_tensor->data());
  1481. } else {
  1482. input_grads.push_back({});
  1483. }
  1484. }
  1485. return input_grads;
  1486. };
  1487. SmallVector<ValueRef> values;
  1488. for (auto&& input : inputs) {
  1489. values.push_back(input.cast<TensorWrapper>().m_tensor->data());
  1490. }
  1491. for (auto&& output : outputs) {
  1492. values.push_back(output.cast<TensorWrapper>().m_tensor->data());
  1493. }
  1494. auto wrapped_output_values = imperative::apply(
  1495. SetGrad(key->m_key, generic_backward_fn, inputs.size()), values);
  1496. std::vector<py::object> wrapped_outputs;
  1497. mgb_assert(wrapped_output_values.size() == outputs.size());
  1498. for (auto&& output_value : wrapped_output_values) {
  1499. wrapped_outputs.push_back(
  1500. TensorWrapper::make(py_tensor_type, output_value));
  1501. }
  1502. return wrapped_outputs;
  1503. });
  1504. static py::function module_trace_hook;
  1505. static auto get_module_trace = [] {
  1506. static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation;
  1507. if (!module_trace_transformation) {
  1508. mgb_assert(module_trace_hook);
  1509. module_trace_transformation =
  1510. std::make_shared<ModuleTraceTransformation>(module_trace_hook);
  1511. transformations.register_at<Segment::ModuleTrace>(
  1512. module_trace_transformation);
  1513. }
  1514. return module_trace_transformation;
  1515. };
  1516. m.def("set_cpp_use_symbolic_shape", &set_cpp_use_symbolic_shape);
  1517. m.def("set_cpp_astensor1d", &set_cpp_astensor1d);
  1518. m.def("set_module_tracing", [=] { get_module_trace()->enable(); });
  1519. m.def("unset_module_tracing", [=] { get_module_trace()->disable(); });
  1520. m.def("is_tracing_module", [=] { return get_module_trace()->enabled(); });
  1521. m.def("set_module_trace_hook",
  1522. [](py::function function) { module_trace_hook = function; });
  1523. m.def("begin_record_values", [] { Value::begin_record_values(); });
  1524. m.def("end_record_values", [] {
  1525. std::vector<std::pair<size_t, std::string>> reprs;
  1526. auto values = Value::end_record_values();
  1527. for (auto&& value : values) {
  1528. reprs.push_back({value.id(), value.to_string()});
  1529. }
  1530. return reprs;
  1531. });
  1532. py::register_exception<TraceError>(m, "TraceError");
  1533. }
  1534. #undef MGE_PY_INTERFACE
  1535. } // namespace mgb::imperative::python