You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_utils.cpp 61 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710
  1. /**
  2. * \file imperative/python/src/tensor.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/common.h"
  12. #include "megbrain/dtype.h"
  13. #include "megbrain/imperative/ops/autogen.h"
  14. #include "megbrain/imperative/ops/backward_graph.h"
  15. #include "megbrain/imperative/ops/utility.h"
  16. #include "megbrain/imperative/profiler.h"
  17. #include "megbrain/imperative/transformations/eval.h"
  18. #include "megbrain/imperative/transformations/lazy.h"
  19. #include "megbrain/imperative/transformations/scalar.h"
  20. #include "megbrain/imperative/transformations/symbol.h"
  21. #include "megbrain/imperative/transformations/trace.h"
  22. #include "megbrain/imperative/utils/map.h"
  23. #include "megbrain/opr/io.h"
  24. #include "megbrain/plugin/profiler.h"
  25. #include "./common.h"
  26. #include "./grad.h"
  27. #include "./graph_rt.h"
  28. #include "./helper.h"
  29. #include "./module_trace.h"
  30. #include "./numpy_dtypes.h"
  31. #include "./tensor.h"
  32. #include "./tensor_utils.h"
  33. #include "./transformation.h"
  34. #include <object.h>
  35. #include <pybind11/numpy.h>
  36. #include <pybind11/operators.h>
  37. #include <pybind11/pytypes.h>
  38. #include <pyerrors.h>
  39. #include <range/v3/all.hpp>
  40. #include <string>
  41. #include <unordered_map>
  42. #include "../../src/impl/mgb_cg_impl.h"
  43. namespace py = pybind11;
  44. namespace views = ranges::views;
  45. namespace mgb::imperative::python {
  46. /* ============== convert inputs ============== */
  // Ranks a numpy dtype kind character (numpy.dtype.kind) for promotion:
  // floating-point (3) > signed/unsigned integer (2) > boolean (1) > other (0).
  inline uint8_t category_priority(char c) {
      if (c == 'f') {
          return 3;  // floating-point
      }
      if (c == 'i' || c == 'u') {
          return 2;  // signed or unsigned integer
      }
      if (c == 'b') {
          return 1;  // boolean
      }
      return 0;
  }
  62. // Returns the maximum value of the priority of each type in the list `types`.
  63. uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
  64. if (types.size() == 0) {
  65. return 0;
  66. } else {
  67. uint8_t max_p = 0;
  68. for (auto&& desc : types) {
  69. max_p = std::max(max_p, category_priority(desc->kind));
  70. }
  71. return max_p;
  72. }
  73. }
  74. // Returns the data type with sufficient size to hold all types of
  75. // category `cat` in the list `types`.
  76. PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
  77. // Return value: New reference
  78. SmallVector<PyArray_Descr*> used_types;
  79. for (auto&& desc : types) {
  80. auto&& v = category_priority(desc->kind);
  81. if (v == cat) {
  82. used_types.emplace_back(desc);
  83. }
  84. }
  85. mgb_assert(used_types.size() > 0, "size of used_types is 0");
  86. PyArray_Descr* res = used_types[0];
  87. Py_INCREF(res);
  88. for (size_t i = 1; i < used_types.size(); ++i) {
  89. PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
  90. Py_DECREF(res);
  91. res = tmp;
  92. }
  93. return res;
  94. }
  95. PyArray_Descr* scalar2dtype(PyObject* arg) {
  96. // Return value: New reference
  97. if (PyBool_Check(arg)) {
  98. auto&& descr = PyArray_DescrFromType(NPY_BOOL);
  99. return descr;
  100. }
  101. if (PyLong_CheckExact(arg)) {
  102. auto&& descr = PyArray_DescrFromType(NPY_INT32);
  103. return descr;
  104. }
  105. if (PyFloat_CheckExact(arg)) {
  106. auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
  107. return descr;
  108. }
  109. return nullptr;
  110. }
  111. PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
  112. // Return value: New reference
  113. SmallVector<PyArray_Descr*> tensors;
  114. SmallVector<PyArray_Descr*> scalars;
  115. bool is_tuple = false;
  116. PyObject* tuple = nullptr;
  117. if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
  118. if (PyList_Check(args[0])) {
  119. tuple = PyList_AsTuple(args[0]);
  120. } else {
  121. tuple = args[0];
  122. Py_INCREF(tuple);
  123. }
  124. nargs = PyTuple_Size(tuple);
  125. is_tuple = true;
  126. }
  127. for (size_t i = 0; i < nargs; ++i) {
  128. PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
  129. if (handle == Py_None)
  130. continue;
  131. TensorWrapper* tw = TensorWrapper::try_cast(handle);
  132. if (tw) {
  133. mgb::DType type = tw->m_tensor->dtype();
  134. auto&& descr = npy::dtype_mgb2np_descr(type);
  135. Py_INCREF(descr.get());
  136. tensors.emplace_back(descr.get());
  137. } else {
  138. if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
  139. auto&& descr = PyArray_DescrFromObject(handle, nullptr);
  140. tensors.emplace_back(descr);
  141. continue;
  142. }
  143. if (py::isinstance<PySymbolVar>(py::handle(handle))) {
  144. auto var = py::handle(handle).cast<PySymbolVar*>();
  145. mgb::DType type = var->m_node->dtype();
  146. auto&& descr = npy::dtype_mgb2np_descr(type);
  147. Py_INCREF(descr.get());
  148. tensors.emplace_back(descr.get());
  149. continue;
  150. }
  151. PyArray_Descr* descr = scalar2dtype(handle);
  152. if (descr) {
  153. scalars.emplace_back(descr);
  154. continue;
  155. }
  156. }
  157. }
  158. auto max_pri_scalars = max_priority(scalars);
  159. auto max_pri_tensors = max_priority(tensors);
  160. if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
  161. throw py::value_error("invalid input, no dtype avaliable");
  162. }
  163. PyArray_Descr* res;
  164. if (max_pri_scalars > max_pri_tensors) {
  165. res = promote_types(scalars, max_pri_scalars);
  166. } else {
  167. res = promote_types(tensors, max_pri_tensors);
  168. }
  169. for (auto* p : tensors) {
  170. Py_DECREF(p);
  171. }
  172. for (auto* p : scalars) {
  173. Py_DECREF(p);
  174. }
  175. Py_XDECREF(tuple);
  176. return res;
  177. }
  178. CompNode _get_device(PyObject* const* args, size_t nargs) {
  179. bool is_tuple = false;
  180. PyObject* tuple = nullptr;
  181. if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
  182. if (PyList_Check(args[0])) {
  183. tuple = PyList_AsTuple(args[0]);
  184. } else {
  185. tuple = args[0];
  186. Py_INCREF(tuple);
  187. }
  188. nargs = PyTuple_Size(tuple);
  189. is_tuple = true;
  190. }
  191. bool valid = false;
  192. CompNode cn;
  193. for (size_t i = 0; i < nargs; ++i) {
  194. PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
  195. TensorWrapper* tw = TensorWrapper::try_cast(handle);
  196. bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
  197. if (tw || is_symvar) {
  198. if (!valid) {
  199. cn = tw ? tw->m_tensor->comp_node()
  200. : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
  201. valid = true;
  202. } else {
  203. CompNode cn1 = tw ? tw->m_tensor->comp_node()
  204. : py::handle(handle)
  205. .cast<PySymbolVar*>()
  206. ->m_node->comp_node();
  207. if (cn1 != cn) {
  208. throw py::value_error(ssprintf(
  209. "ambiguous device: %s (from %s) vs %s (from %s)",
  210. cn.to_string().c_str(), cn.to_string_logical().c_str(),
  211. cn1.to_string().c_str(), cn1.to_string_logical().c_str()));
  212. }
  213. }
  214. }
  215. }
  216. if (!valid) {
  217. return CompNode::load(get_default_device());
  218. }
  219. Py_XDECREF(tuple);
  220. return cn;
  221. }
  222. // Returns the dtype that would result from performing an arithmetic
  223. // operation on the provided input tensors and scalars.
  224. PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
  225. if (!nargs) {
  226. PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
  227. return nullptr;
  228. }
  229. try {
  230. PyArray_Descr* res = _dtype_promotion(args, nargs);
  231. return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
  232. }
  233. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  234. }
  235. PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
  236. if (!nargs) {
  237. PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
  238. return nullptr;
  239. }
  240. try {
  241. CompNode cn = _get_device(args, nargs);
  242. return py::cast(cn).release().ptr();
  243. }
  244. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  245. }
  246. bool is_scalar(PyObject* tensor) {
  247. if (py::isinstance<PySymbolVar>(py::handle(tensor))) {
  248. auto var = py::handle(tensor).cast<PySymbolVar*>();
  249. return var->is_scalar;
  250. }
  251. auto* tw = TensorWrapper::try_cast(tensor);
  252. if (tw) {
  253. return tw->m_tensor->is_scalar();
  254. }
  255. return PyArray_CheckAnyScalar(tensor);
  256. }
  257. bool is_bool_list(PyObject* arg) {
  258. if (!PyList_Check(arg)) {
  259. return false;
  260. }
  261. size_t sz = PyList_Size(arg);
  262. if (!sz) {
  263. return false;
  264. }
  265. for (size_t i = 0; i < sz; ++i) {
  266. PyObject* handle = PyList_GetItem(arg, i);
  267. if (!PyBool_Check(handle)) {
  268. return false;
  269. }
  270. }
  271. return true;
  272. }
  273. bool is_bool_dtype(PyObject* args) {
  274. if (!PyObject_HasAttrString(args, "dtype"))
  275. return false;
  276. PyObject* dobj = PyObject_GetAttrString(args, "dtype");
  277. PyArray_Descr* dtype;
  278. PyArray_DescrConverter(dobj, &dtype);
  279. bool ret = (dtype->kind == 'b');
  280. Py_XDECREF(dtype);
  281. Py_XDECREF(dobj);
  282. return ret;
  283. }
  284. py::object device2obj(py::handle device, bool mapping = false) {
  285. if (device.ptr() == Py_None) {
  286. return py::cast(CompNode::load(get_default_device()));
  287. } else if (py::isinstance<py::str>(device)) {
  288. if (mapping) {
  289. py::object dmap = getattr(
  290. py::reinterpret_borrow<py::object>((PyObject*)py_tensor_type),
  291. "dmap_callback");
  292. if (dmap.ptr() != Py_None) {
  293. return device2obj(dmap(device), false);
  294. }
  295. }
  296. return py::cast(CompNode::load(device.cast<std::string>()));
  297. } else if (py::isinstance<CompNode>(device)) {
  298. return py::reinterpret_borrow<py::object>(device);
  299. } else {
  300. return getattr(device, "_cn");
  301. }
  302. }
  303. py::object _Const(
  304. py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) {
  305. py::object val = py::reinterpret_borrow<py::object>(value);
  306. if (PyArray_Check(value.ptr())) {
  307. py::tuple strides =
  308. py::reinterpret_borrow<py::tuple>(getattr(value, "strides"));
  309. bool need_squeeze = false;
  310. for (size_t i = 0; i < strides.size(); ++i) {
  311. if (strides[i].cast<ptrdiff_t>() == 0) {
  312. need_squeeze = true;
  313. }
  314. }
  315. if (need_squeeze) {
  316. val = py::reinterpret_borrow<py::array>(value);
  317. py::object orig_shp = val.attr("shape");
  318. val = val.attr("squeeze")();
  319. val = val.attr("reshape")(orig_shp);
  320. }
  321. }
  322. py::object ref;
  323. if (py::isinstance<py::tuple>(ref_hdl)) {
  324. py::tuple tup = py::reinterpret_borrow<py::tuple>(ref_hdl);
  325. if (tup.size()) {
  326. ref = tup[0];
  327. } else {
  328. ref = py::none();
  329. }
  330. } else {
  331. ref = py::reinterpret_borrow<py::object>(ref_hdl);
  332. }
  333. if (py::isinstance<PySymbolVar>(ref)) {
  334. auto ref_var = ref.cast<PySymbolVar*>();
  335. auto* graph = ref_var->m_node->owner_graph();
  336. CompNode cn;
  337. if (device.ptr() == Py_None) {
  338. cn = ref_var->m_node->comp_node();
  339. } else {
  340. cn = device2obj(device).cast<CompNode>();
  341. }
  342. OperatorNodeConfig config(cn);
  343. auto hv = npy::np2tensor(
  344. val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
  345. auto typeobj = ref.get_type();
  346. return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
  347. }
  348. py::object device_obj = device2obj(device, true);
  349. py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none());
  350. return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
  351. }
  352. py::tuple _make_shape_tuple(py::handle shape) {
  353. py::list orig;
  354. py::list ret(0);
  355. auto solve_one = [&](py::handle val) {
  356. if (TensorWrapper::try_cast(val.ptr()) || py::isinstance<PySymbolVar>(val)) {
  357. py::object np = getattr(val, "numpy")();
  358. PyArrayObject* arr = (PyArrayObject*)np.ptr();
  359. PyObject* maybe_list = PyArray_ToList(arr);
  360. if (PyList_Check(maybe_list)) {
  361. py::list may = py::reinterpret_steal<py::list>(maybe_list);
  362. for (size_t i = 0; i < may.size(); ++i) {
  363. ret.append(may[i]);
  364. }
  365. } else {
  366. mgb_assert(PyLong_Check(maybe_list));
  367. ret.append(PyLong_AsLong(maybe_list));
  368. Py_XDECREF(maybe_list);
  369. }
  370. } else if (PyArray_Check(val.ptr())) {
  371. ret.append(PyArray_PyIntAsInt(val.ptr()));
  372. } else {
  373. ret.append(PyLong_AsLong(val.ptr()));
  374. }
  375. };
  376. if (PyArray_Check(shape.ptr()) && !PyArray_CheckAnyScalar(shape.ptr())) {
  377. orig = py::reinterpret_steal<py::list>(
  378. PyArray_ToList((PyArrayObject*)shape.ptr()));
  379. for (size_t i = 0; i < orig.size(); ++i) {
  380. solve_one(orig[i]);
  381. }
  382. } else if (PyList_Check(shape.ptr())) {
  383. orig = py::reinterpret_borrow<py::list>(shape);
  384. for (size_t i = 0; i < orig.size(); ++i) {
  385. solve_one(orig[i]);
  386. }
  387. } else if (PyTuple_Check(shape.ptr())) {
  388. py::tuple tup = py::reinterpret_borrow<py::tuple>(shape);
  389. for (size_t i = 0; i < tup.size(); ++i) {
  390. solve_one(tup[i]);
  391. }
  392. } else {
  393. solve_one(shape);
  394. }
  395. return py::reinterpret_steal<py::tuple>(PyList_AsTuple(ret.ptr()));
  396. }
  397. bool is_tensor_or_symbolvar(py::handle arg) {
  398. return bool(TensorWrapper::try_cast(arg.ptr())) || py::isinstance<PySymbolVar>(arg);
  399. }
  400. bool is_py_sequence(py::handle arg) {
  401. if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr()) ||
  402. py::isinstance<PySymbolVar>(arg)) {
  403. return false;
  404. }
  405. return PySequence_Check(arg.ptr());
  406. }
  407. mgb::DType _get_dtype(py::handle tensor) {
  408. if (auto tw = TensorWrapper::try_cast(tensor.ptr())) {
  409. return tw->m_tensor->dtype();
  410. } else {
  411. auto var = tensor.cast<PySymbolVar*>();
  412. return var->m_node->dtype();
  413. }
  414. }
  415. py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
  416. PyArray_Descr* descr;
  417. if (!PyArray_DescrConverter(dtype_hdl.ptr(), &descr)) {
  418. throw py::value_error(ssprintf(
  419. "can not convert to numpy.dtype from %s",
  420. dtype_hdl.ptr()->ob_type->tp_name));
  421. }
  422. PyArray_Descr* cur = npy::dtype_mgb2np_descr(_get_dtype(tensor)).get();
  423. if (!dtype_equal(cur, descr)) {
  424. std::shared_ptr<OpDef> op = TypeCvt::make(npy::dtype_np2mgb_descr(descr));
  425. py::object Op = py::cast(op);
  426. PyObject* p[2] = {Op.ptr(), tensor.ptr()};
  427. py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
  428. return ret[0];
  429. } else {
  430. return py::reinterpret_borrow<py::object>(tensor);
  431. }
  432. }
  433. py::object _convert_single_value_cpp(
  434. py::handle value, py::handle dtype, py::handle device) {
  435. if (is_tensor_or_symbolvar(value)) {
  436. if (_get_dtype(value).category() != DTypeCategory::QUANTIZED) {
  437. return _astype_cpp(value, dtype);
  438. }
  439. } else {
  440. return _Const(value, dtype, device, py::none());
  441. }
  442. return py::reinterpret_borrow<py::object>(value);
  443. }
  444. py::object _convert_inputs_cpp(
  445. PyObject* const* args, size_t nargs, py::object dtype, py::object device) {
  446. ComputingGraph* graph = nullptr;
  447. py::handle typeobj;
  448. py::list lis;
  449. for (size_t i = 0; i < nargs; ++i) {
  450. py::handle h = py::handle(args[i]);
  451. lis.append(h);
  452. if (py::isinstance<PySymbolVar>(h)) {
  453. auto var = h.cast<PySymbolVar*>();
  454. auto g = var->m_node->owner_graph();
  455. if (!graph) {
  456. graph = g;
  457. typeobj = h.get_type();
  458. } else {
  459. mgb_assert(graph == g);
  460. }
  461. }
  462. }
  463. if (graph) {
  464. CompNode cn = device2obj(device).cast<CompNode>();
  465. for (size_t i = 0; i < nargs; ++i) {
  466. OperatorNodeConfig config(cn);
  467. auto hv = npy::np2tensor(
  468. lis[i].ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
  469. if (!py::isinstance<PySymbolVar>(lis[i])) {
  470. lis[i] = typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
  471. }
  472. }
  473. }
  474. auto convert = [&](py::object value) {
  475. if (value.is_none()) {
  476. return value;
  477. }
  478. return _convert_single_value_cpp(value, dtype, device);
  479. };
  480. for (size_t i = 0; i < lis.size(); ++i) {
  481. lis[i] = convert(lis[i]);
  482. }
  483. return py::reinterpret_steal<py::tuple>(PyList_AsTuple(lis.ptr()));
  484. }
  485. py::object _astensor1d_cpp(
  486. py::handle value, py::handle dtype, py::handle device, py::handle ref) {
  487. py::object ret;
  488. py::object device_obj = py::none();
  489. py::object ndim_obj = py::none();
  490. if (device.ptr() != Py_None) {
  491. device_obj = device2obj(device);
  492. }
  493. if (py::isinstance<PySymbolVar>(value)) {
  494. try {
  495. getattr(value, "ndim");
  496. } catch (py::error_already_set& err) {
  497. if (dtype.ptr() != Py_None) {
  498. ret = _astype_cpp(value, dtype);
  499. } else {
  500. ret = py::reinterpret_borrow<py::object>(value);
  501. }
  502. if (device.ptr() != Py_None) {
  503. std::shared_ptr<OpDef> op = Copy::make(device_obj.cast<CompNode>());
  504. py::object Op = py::cast(op);
  505. PyObject* p[2] = {Op.ptr(), ret.ptr()};
  506. py::tuple copy_ret =
  507. py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
  508. return copy_ret[0];
  509. }
  510. return ret;
  511. }
  512. }
  513. size_t ndim = 999;
  514. if (hasattr(value, "ndim")) {
  515. ndim = getattr(value, "ndim").cast<size_t>();
  516. if (ndim != 0 && ndim != 1) {
  517. throw py::value_error("ndim != 1 or 0, get : " + std::to_string(ndim));
  518. }
  519. if (!is_tensor_or_symbolvar(value)) {
  520. return _Const(value, dtype, device, ref);
  521. } else {
  522. return py::reinterpret_borrow<py::object>(value);
  523. }
  524. }
  525. if (!is_py_sequence(value)) {
  526. throw py::type_error();
  527. }
  528. py::list lis = py::reinterpret_steal<py::list>(PySequence_List(value.ptr()));
  529. bool need_concat = false;
  530. for (size_t i = 0; i < lis.size(); ++i) {
  531. if (is_tensor_or_symbolvar(lis[i])) {
  532. need_concat = true;
  533. break;
  534. }
  535. }
  536. if (!need_concat) {
  537. return _Const(value, dtype, device, ref);
  538. }
  539. if (lis.size() > 1) {
  540. std::vector<PyObject*> c_args(lis.size() + 1);
  541. for (size_t i = 0; i < lis.size(); ++i) {
  542. c_args[i] = lis[i].ptr();
  543. }
  544. c_args[lis.size()] = Py_None;
  545. py::tuple inp_tup = py::reinterpret_steal<py::tuple>(
  546. convert_inputs_cpp(NULL, c_args.data(), c_args.size()));
  547. if (device_obj.is_none()) {
  548. std::vector<PyObject*> inp(inp_tup.size());
  549. for (size_t i = 0; i < inp_tup.size(); ++i) {
  550. inp[i] = inp_tup[i].ptr();
  551. }
  552. device_obj = py::cast(_get_device(inp.data(), inp.size()));
  553. }
  554. std::shared_ptr<OpDef> op = Concat::make(0, device_obj.cast<CompNode>());
  555. py::object Op = py::cast(op);
  556. std::vector<PyObject*> p;
  557. p.resize(inp_tup.size() + 1);
  558. p[0] = Op.ptr();
  559. for (size_t i = 0; i < inp_tup.size(); ++i) {
  560. p[i + 1] = inp_tup[i].ptr();
  561. }
  562. py::tuple concat_ret =
  563. py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
  564. ret = concat_ret[0];
  565. } else {
  566. ret = lis[0];
  567. }
  568. if (dtype.ptr() != Py_None) {
  569. return _astype_cpp(ret, dtype);
  570. } else {
  571. return ret;
  572. }
  573. }
  574. py::object _get_index(py::object tensor, py::object src) {
  575. if (!TensorWrapper::try_cast(tensor.ptr()) &&
  576. !py::isinstance<PySymbolVar>(tensor)) {
  577. auto get_const = [&](mgb::DType dtype) -> py::object {
  578. return _Const(tensor, py::cast(dtype), src.attr("device"), src);
  579. };
  580. if (is_bool_list(tensor.ptr()) || is_bool_dtype(tensor.ptr())) {
  581. tensor = get_const(dtype::Bool());
  582. } else {
  583. tensor = get_const(dtype::Int32());
  584. }
  585. if (!is_bool_dtype(tensor.ptr())) {
  586. return tensor;
  587. }
  588. } else {
  589. if (!is_bool_dtype(tensor.ptr())) {
  590. return tensor;
  591. }
  592. }
  593. std::shared_ptr<OpDef> op = CondTake::make();
  594. py::object Op = py::cast(op);
  595. PyObject* p[3] = {Op.ptr(), tensor.ptr(), tensor.ptr()};
  596. py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
  597. return ret[1];
  598. }
  599. py::tuple _try_cond_take(py::handle tensor, py::handle index) {
  600. if (!hasattr(index, "dtype") || !hasattr(index, "shape")) {
  601. return py::tuple();
  602. }
  603. if (!is_bool_dtype(index.ptr()) ||
  604. _make_shape_tuple(getattr(index, "shape"))
  605. .not_equal(_make_shape_tuple(getattr(tensor, "shape")))) {
  606. return py::tuple();
  607. }
  608. py::object iobj;
  609. if (PyArray_Check(index.ptr())) {
  610. iobj =
  611. _Const(index, py::cast((mgb::DType)dtype::Bool()),
  612. getattr(tensor, "device"), tensor);
  613. } else {
  614. iobj = py::reinterpret_borrow<py::object>(index);
  615. }
  616. std::shared_ptr<OpDef> op = CondTake::make();
  617. py::object Op = py::cast(op);
  618. PyObject* p[3] = {Op.ptr(), tensor.ptr(), iobj.ptr()};
  619. py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
  620. return ret;
  621. }
  622. py::tuple _remove_ellipsis(py::object tensor, py::tuple tuple_val) {
  623. size_t tuple_size = tuple_val.size();
  624. size_t ndim_sum = 0, cur_sum = 0;
  625. int pos = -1;
  626. bool has_unknown_ndim_bool_index = false;
  627. for (size_t i = 0; i < tuple_size; ++i) {
  628. py::object handle = tuple_val[i];
  629. if (handle.is_none()) {
  630. continue;
  631. } else if (handle.ptr() == Py_Ellipsis) {
  632. pos = static_cast<int>(i);
  633. for (size_t j = 0; j < i; ++j) {
  634. py::object t = tuple_val[j];
  635. if (t.ptr() == Py_Ellipsis) {
  636. throw py::index_error("only one ellipsis is allowed.");
  637. }
  638. }
  639. } else {
  640. size_t ndim_incr = 1;
  641. if (hasattr(handle, "dtype") && is_bool_dtype(handle.ptr()) &&
  642. hasattr(handle, "ndim")) {
  643. py::object ndim;
  644. try {
  645. ndim = getattr(handle, "ndim");
  646. } catch (py::error_already_set& err) {
  647. has_unknown_ndim_bool_index = true;
  648. }
  649. if (PyLong_Check(ndim.ptr())) {
  650. ndim_incr = PyLong_AsLong(ndim.ptr());
  651. } else {
  652. has_unknown_ndim_bool_index = true;
  653. }
  654. }
  655. cur_sum += ndim_incr;
  656. }
  657. }
  658. if (pos == -1) {
  659. return tuple_val;
  660. } else {
  661. if (has_unknown_ndim_bool_index) {
  662. throw py::index_error(
  663. "does not support bool index with unknown shape when using "
  664. "Ellipsis.");
  665. }
  666. try {
  667. ndim_sum = getattr(tensor, "ndim").cast<size_t>();
  668. } catch (py::error_already_set& err) {
  669. throw py::index_error(
  670. "does not support Ellipsis when tensor's ndim is unknown.");
  671. }
  672. py::tuple ret(ndim_sum - cur_sum + tuple_size - 1);
  673. size_t idx = 0;
  674. for (size_t i = 0; i < tuple_size; ++i) {
  675. if (i == pos) {
  676. for (size_t j = cur_sum; j < ndim_sum; ++j) {
  677. ret[idx++] = PySlice_New(NULL, NULL, NULL);
  678. }
  679. } else {
  680. ret[idx++] = tuple_val[i];
  681. }
  682. }
  683. return ret;
  684. }
  685. }
  686. py::object _reshape_cpp(py::handle inp_hdl, py::handle args);
  687. py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
  688. py::tuple cur_shape = _make_shape_tuple(py::handle(getattr(tensor, "shape")));
  689. py::list new_tuple_val(0);
  690. size_t offset = 0;
  691. size_t tdim = 0;
  692. size_t nonedim = 0;
  693. for (size_t i = 0; i < tuple_val.size(); ++i) {
  694. py::handle k = tuple_val[i];
  695. if (k.ptr() == Py_None) {
  696. nonedim++;
  697. new_tuple_val.append(k);
  698. continue;
  699. }
  700. if (is_bool_dtype(k.ptr())) {
  701. size_t ndim = getattr(k, "ndim").cast<size_t>();
  702. if (ndim > 1) {
  703. py::tuple ishape = _make_shape_tuple(py::handle(getattr(k, "shape")));
  704. for (size_t j = 0; j < ndim; ++j) {
  705. if (cur_shape[tdim + j - offset].cast<size_t>() !=
  706. ishape[j].cast<size_t>()) {
  707. std::string msg =
  708. "boolean index did not match tensor along "
  709. "dimension " +
  710. std::to_string(tdim + j) + "; dimension is " +
  711. std::to_string(
  712. cur_shape[tdim + j - offset].cast<size_t>()) +
  713. " but corresponding boolean dimension is " +
  714. std::to_string(ishape[j].cast<size_t>());
  715. throw py::index_error(msg.c_str());
  716. }
  717. }
  718. py::object new_k = getattr(k, "reshape")(-1);
  719. py::object kshape = getattr(new_k, "shape");
  720. py::list new_shape(0);
  721. PyObject* sym = PyObject_CallObject(cpp_use_symbolic_shape, nullptr);
  722. bool is_sym = (sym == Py_True);
  723. Py_XDECREF(sym);
  724. if (is_sym) {
  725. py::object tshape = getattr(tensor, "shape");
  726. for (size_t j = 0; j < i - nonedim; ++j) {
  727. new_shape.append(tshape[py::int_(j)]);
  728. }
  729. new_shape.append(kshape[py::int_(0)]);
  730. for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) {
  731. new_shape.append(cur_shape[j]);
  732. }
  733. py::object shape_tensor = _astensor1d_cpp(
  734. new_shape, py::none(), py::none(), py::none());
  735. tensor = _reshape_cpp(tensor, shape_tensor);
  736. cur_shape = _make_shape_tuple(shape_tensor);
  737. } else {
  738. for (size_t j = 0; j < i - nonedim; ++j) {
  739. new_shape.append(cur_shape[j]);
  740. }
  741. new_shape.append(py::reinterpret_borrow<py::tuple>(kshape)[0]);
  742. for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) {
  743. new_shape.append(cur_shape[j]);
  744. }
  745. cur_shape = new_shape;
  746. tensor = _reshape_cpp(tensor, cur_shape);
  747. }
  748. offset++;
  749. tdim += ndim;
  750. }
  751. new_tuple_val.append(k);
  752. } else {
  753. new_tuple_val.append(k);
  754. tdim++;
  755. }
  756. }
  757. return py::make_tuple(tensor, py::reinterpret_borrow<py::tuple>(new_tuple_val));
  758. }
  759. std::pair<size_t, bool> get_ndim_safe(py::handle tensor) {
  760. if (auto p = TensorWrapper::try_cast(tensor.ptr())) {
  761. return {p->m_tensor->shape()->ndim, true};
  762. }
  763. try {
  764. return {getattr(tensor, "ndim").cast<size_t>(), true};
  765. } catch (py::error_already_set& err) {
  766. return {0, false};
  767. }
  768. }
  769. py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
  770. py::object inp = py::reinterpret_borrow<py::object>(inp_hdl);
  771. py::tuple tuple_val;
  772. if (py::isinstance<py::tuple>(idx_hdl)) {
  773. tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
  774. } else {
  775. tuple_val = py::make_tuple(idx_hdl);
  776. }
  777. bool use_subtensor = true;
  778. bool need_remove_ellipsis = false;
  779. bool need_expand_bool_dim = false;
  780. size_t idx_ndim = 0;
  781. for (size_t i = 0; i < tuple_val.size(); ++i) {
  782. py::object k = tuple_val[i];
  783. if (k.is_none()) {
  784. continue;
  785. } else if (k.ptr() == Py_Ellipsis) {
  786. need_remove_ellipsis = true;
  787. } else {
  788. if (is_bool_dtype(k.ptr()) && hasattr(k, "ndim")) {
  789. size_t ndim = get_ndim_safe(k).first;
  790. idx_ndim += ndim;
  791. if (ndim > 1) {
  792. need_expand_bool_dim = true;
  793. }
  794. } else {
  795. idx_ndim++;
  796. }
  797. }
  798. }
  799. try {
  800. size_t inp_ndim = getattr(inp, "ndim").cast<size_t>();
  801. if (idx_ndim > inp_ndim) {
  802. std::string msg = "too many indices for tensor: tensor is " +
  803. std::to_string(inp_ndim) + "-dimensional, but " +
  804. std::to_string(idx_ndim) + " were indexed";
  805. throw py::index_error(msg.c_str());
  806. }
  807. } catch (py::error_already_set& err) {
  808. ; // ignore
  809. }
  810. if (need_remove_ellipsis) {
  811. tuple_val = _remove_ellipsis(inp, tuple_val);
  812. }
  813. if (need_expand_bool_dim) {
  814. py::object shape = getattr(inp, "shape");
  815. if (shape.ptr() != Py_None) {
  816. py::tuple ret = _expand_bool_dim(inp, tuple_val);
  817. inp = ret[0];
  818. tuple_val = ret[1];
  819. }
  820. }
  821. std::vector<int32_t> axis;
  822. for (size_t i = 0; i < tuple_val.size(); ++i) {
  823. if (tuple_val[i].is_none()) {
  824. axis.push_back(i);
  825. }
  826. }
  827. if (axis.size()) {
  828. std::shared_ptr<OpDef> op = AddAxis::make(axis);
  829. py::object Op = py::cast(op);
  830. PyObject* p[2] = {Op.ptr(), inp.ptr()};
  831. py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
  832. inp = ret[0];
  833. }
  834. py::list items;
  835. py::list tensors;
  836. int cur_axis = -1;
  837. for (size_t i = 0; i < tuple_val.size(); ++i) {
  838. py::object handle = tuple_val[i];
  839. cur_axis++;
  840. if (handle.is_none()) {
  841. continue;
  842. }
  843. if (!is_scalar(handle.ptr()) && !PySlice_Check(handle.ptr())) {
  844. use_subtensor = false;
  845. }
  846. py::list item;
  847. item.append(cur_axis);
  848. auto push = [&](PyObject* v) {
  849. if (v == Py_None) {
  850. item.append(false);
  851. } else {
  852. item.append(true);
  853. tensors.append(_get_index(py::reinterpret_borrow<py::object>(v), inp));
  854. }
  855. };
  856. if (PySlice_Check(handle.ptr())) {
  857. PySliceObject* s = (PySliceObject*)handle.ptr();
  858. if (s->start == Py_None && s->stop == Py_None && s->step == Py_None) {
  859. continue;
  860. }
  861. push(s->start);
  862. push(s->stop);
  863. push(s->step);
  864. item.append(false);
  865. } else {
  866. for (size_t j = 0; j < 3; j++)
  867. item.append(false);
  868. push(handle.ptr());
  869. }
  870. items.append(item);
  871. }
  872. return py::make_tuple(inp, tensors, items, use_subtensor, need_expand_bool_dim);
  873. }
  874. py::object _expand_args(py::handle args) {
  875. if (!PyTuple_Check(args.ptr())) {
  876. return py::reinterpret_borrow<py::object>(args);
  877. }
  878. py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr());
  879. if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) ||
  880. is_tensor_or_symbolvar(args_tup[0].ptr()))) {
  881. return py::reinterpret_borrow<py::object>(args_tup[0]);
  882. } else {
  883. return py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr()));
  884. }
  885. }
  886. std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) {
  887. std::vector<int32_t> shp;
  888. if (!PyTuple_Check(shape.ptr())) {
  889. return {shp, false};
  890. }
  891. py::tuple tup = py::reinterpret_borrow<py::tuple>(shape);
  892. for (size_t i = 0; i < tup.size(); ++i) {
  893. if (!PyLong_Check(tup[i].ptr())) {
  894. shp.clear();
  895. return {shp, false};
  896. } else {
  897. shp.push_back(tup[i].cast<int32_t>());
  898. }
  899. }
  900. return {shp, true};
  901. }
  902. bool enable_fastpath(py::handle inp) {
  903. auto&& tm_tr = TransformationManager::get_instance()
  904. .segments[TransformationManager::Segment::ModuleTrace];
  905. if (!TensorWrapper::try_cast(inp.ptr()) ||
  906. TransformationManager::get_instance()
  907. .segments[TransformationManager::Segment::Trace]
  908. .size() > 0 ||
  909. (tm_tr.size() > 0 &&
  910. reinterpret_cast<ModuleTraceTransformation*>(tm_tr[0].get())->enabled())) {
  911. return false;
  912. }
  913. return true;
  914. }
  915. py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) {
  916. py::object shape_hdl = _expand_args(args);
  917. bool auto_infer = false;
  918. py::list lis;
  919. py::list new_shape;
  920. if (PyList_Check(shape_hdl.ptr()) || PyTuple_Check(shape_hdl.ptr())) {
  921. lis = py::reinterpret_steal<py::list>(PySequence_List(shape_hdl.ptr()));
  922. for (size_t i = 0; i < lis.size(); ++i) {
  923. if (lis[i].is_none()) {
  924. auto_infer = true;
  925. size_t right = lis.size() - i;
  926. py::object tshp = getattr(inp_hdl, "_tuple_shape");
  927. if (tshp.is_none()) {
  928. throw py::index_error("does not support `None` with unknown shape");
  929. }
  930. py::tuple inp_shape = py::reinterpret_borrow<py::tuple>(tshp);
  931. if (inp_shape.size() >= right) {
  932. if (enable_fastpath(inp_hdl)) {
  933. lis[i] = inp_shape[inp_shape.size() - right];
  934. }
  935. new_shape.append(inp_shape[inp_shape.size() - right]);
  936. } else {
  937. throw py::value_error("invalid broadcast shape");
  938. }
  939. } else {
  940. new_shape.append(lis[i]);
  941. if (PyLong_Check(lis[i].ptr())) {
  942. int32_t s = lis[i].cast<int32_t>();
  943. if (s < 0) {
  944. throw py::value_error(
  945. "expect shape[" + std::to_string(i) +
  946. "] >= 0 or use `None` to auto infer, got " +
  947. std::to_string(s));
  948. }
  949. }
  950. }
  951. }
  952. }
  953. if (auto_infer) {
  954. if (enable_fastpath(inp_hdl)) {
  955. shape_hdl = py::reinterpret_borrow<py::tuple>(lis);
  956. } else {
  957. shape_hdl = _astensor1d_cpp(
  958. new_shape, py::cast((mgb::DType)dtype::Int32()),
  959. getattr(inp_hdl, "device"), inp_hdl);
  960. }
  961. }
  962. py::object shape_tuple;
  963. try {
  964. shape_tuple = _make_shape_tuple(shape_hdl);
  965. } catch (py::error_already_set& err) {
  966. shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
  967. }
  968. auto [shape, fastpath] = tuple2vector(shape_tuple);
  969. fastpath &= enable_fastpath(inp_hdl);
  970. std::shared_ptr<OpDef> op;
  971. std::vector<PyObject*> p;
  972. py::object shape_tensor;
  973. if (fastpath) {
  974. op = Broadcast::make(shape);
  975. p.resize(2);
  976. } else {
  977. op = Broadcast::make();
  978. shape_tensor = _astensor1d_cpp(
  979. shape_hdl, py::cast((mgb::DType)dtype::Int32()),
  980. getattr(inp_hdl, "device"), inp_hdl);
  981. p.resize(3);
  982. p[2] = shape_tensor.ptr();
  983. }
  984. py::object Op = py::cast(op);
  985. p[0] = Op.ptr();
  986. p[1] = inp_hdl.ptr();
  987. py::tuple ret =
  988. py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
  989. return ret[0];
  990. }
  991. py::object _reshape_cpp(py::handle inp_hdl, py::handle args) {
  992. py::object shape_hdl = _expand_args(args);
  993. py::object shape_tuple;
  994. try {
  995. shape_tuple = _make_shape_tuple(shape_hdl);
  996. } catch (py::error_already_set& err) {
  997. shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
  998. }
  999. int32_t unspec_axis = -1;
  1000. if (PyTuple_Check(shape_tuple.ptr())) {
  1001. py::tuple tup = py::reinterpret_borrow<py::tuple>(shape_tuple);
  1002. for (size_t i = 0; i < tup.size(); ++i) {
  1003. py::object obj = py::reinterpret_borrow<py::object>(tup[i]);
  1004. if (obj < py::int_(0)) {
  1005. if (obj.not_equal(py::int_(-1))) {
  1006. throw py::value_error(
  1007. "expect shape [" + std::to_string(i) + "] >= -1, got " +
  1008. repr(obj).cast<std::string>());
  1009. }
  1010. if (unspec_axis >= 0) {
  1011. throw py::value_error(
  1012. "multiple -1 in shape: " + std::to_string(unspec_axis) +
  1013. " & " + std::to_string(i));
  1014. }
  1015. unspec_axis = i;
  1016. }
  1017. }
  1018. }
  1019. auto [shape, fastpath] = tuple2vector(shape_tuple);
  1020. fastpath &= enable_fastpath(inp_hdl);
  1021. std::shared_ptr<OpDef> op;
  1022. std::vector<PyObject*> p;
  1023. py::object shape_tensor;
  1024. if (fastpath) {
  1025. if (unspec_axis >= 0) {
  1026. op = Reshape::make(unspec_axis, shape);
  1027. } else {
  1028. op = Reshape::make(::megdnn::param::OptionalAxisV1::INVALID_AXIS, shape);
  1029. }
  1030. p.resize(2);
  1031. } else {
  1032. shape.clear();
  1033. if (unspec_axis >= 0) {
  1034. op = Reshape::make(unspec_axis, shape);
  1035. } else {
  1036. op = Reshape::make();
  1037. }
  1038. shape_tensor = _astensor1d_cpp(
  1039. shape_hdl, py::cast((mgb::DType)dtype::Int32()),
  1040. getattr(inp_hdl, "device"), inp_hdl);
  1041. p.resize(3);
  1042. p[2] = shape_tensor.ptr();
  1043. }
  1044. py::object Op = py::cast(op);
  1045. p[0] = Op.ptr();
  1046. p[1] = inp_hdl.ptr();
  1047. py::tuple ret =
  1048. py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
  1049. return ret[0];
  1050. }
  1051. py::object _adaptive_pool2d_cpp(
  1052. py::handle inp_hdl, py::handle shape_val_hdl, py::handle pool_mode_hdl) {
  1053. py::object shape_hdl = py::reinterpret_borrow<py::object>(shape_val_hdl);
  1054. py::list shps(0);
  1055. if (!PyTuple_Check(shape_val_hdl.ptr())) {
  1056. shps.append(PyLong_AsLong(shape_val_hdl.ptr()));
  1057. shps.append(PyLong_AsLong(shape_val_hdl.ptr()));
  1058. shape_hdl = py::reinterpret_borrow<py::object>(shps);
  1059. }
  1060. py::object shape_tuple;
  1061. try {
  1062. shape_tuple = _make_shape_tuple(shape_hdl);
  1063. } catch (py::error_already_set& err) {
  1064. shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
  1065. }
  1066. auto mode_string = pool_mode_hdl.cast<std::string>();
  1067. ::megdnn::param::AdaptivePooling::Mode pool_mode =
  1068. ::megdnn::param::AdaptivePooling::Mode::MAX;
  1069. if (mode_string.compare(std::string("AVERAGE")) == 0) {
  1070. pool_mode = ::megdnn::param::AdaptivePooling::Mode::AVERAGE;
  1071. }
  1072. auto [shape, fastpath] = tuple2vector(shape_tuple);
  1073. fastpath &= enable_fastpath(inp_hdl);
  1074. std::shared_ptr<OpDef> op;
  1075. std::vector<PyObject*> p;
  1076. py::object shape_tensor;
  1077. op = AdaptivePooling::make(
  1078. pool_mode, ::megdnn::param::AdaptivePooling::Format::NCHW, shape);
  1079. if (fastpath) {
  1080. p.resize(2);
  1081. } else {
  1082. p.resize(3);
  1083. shape_tensor = _astensor1d_cpp(
  1084. shape_hdl, py::cast((mgb::DType)dtype::Int32()),
  1085. getattr(inp_hdl, "device"), inp_hdl);
  1086. p[2] = shape_tensor.ptr();
  1087. }
  1088. py::object Op = py::cast(op);
  1089. p[0] = Op.ptr();
  1090. p[1] = inp_hdl.ptr();
  1091. py::tuple ret =
  1092. py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
  1093. return ret[0];
  1094. }
  1095. py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
  1096. py::tuple try_res = _try_cond_take(inp_hdl, idx_hdl);
  1097. if (try_res.size() == 2) {
  1098. return try_res[0];
  1099. }
  1100. py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
  1101. py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
  1102. py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
  1103. py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
  1104. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
  1105. for (size_t i = 0; i < py_items.size(); ++i) {
  1106. py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
  1107. cpp_items.push_back(
  1108. {item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
  1109. item[3].cast<bool>(), item[4].cast<bool>()});
  1110. }
  1111. std::shared_ptr<OpDef> op;
  1112. if (up[3].cast<bool>()) {
  1113. op = Subtensor::make(cpp_items);
  1114. } else {
  1115. op = IndexingMultiAxisVec::make(cpp_items);
  1116. }
  1117. std::vector<PyObject*> p;
  1118. p.resize(tensors.size() + 2);
  1119. py::object Op = py::cast(op);
  1120. p[0] = Op.ptr();
  1121. p[1] = tensor.ptr();
  1122. for (size_t i = 0; i < tensors.size(); ++i) {
  1123. p[i + 2] = tensors[i].ptr();
  1124. }
  1125. py::tuple ret =
  1126. py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
  1127. return ret[0];
  1128. }
  1129. py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) {
  1130. py::object org_shape = getattr(inp_hdl, "shape");
  1131. py::object val = py::reinterpret_borrow<py::object>(val_hdl);
  1132. if (!TensorWrapper::try_cast(val.ptr()) && !py::isinstance<PySymbolVar>(val)) {
  1133. val =
  1134. _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"),
  1135. inp_hdl);
  1136. }
  1137. py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
  1138. py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
  1139. py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
  1140. py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
  1141. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
  1142. for (size_t i = 0; i < py_items.size(); ++i) {
  1143. py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
  1144. cpp_items.push_back(
  1145. {item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
  1146. item[3].cast<bool>(), item[4].cast<bool>()});
  1147. }
  1148. std::shared_ptr<OpDef> op, set_op;
  1149. if (up[3].cast<bool>()) {
  1150. op = Subtensor::make(cpp_items);
  1151. } else {
  1152. op = IndexingMultiAxisVec::make(cpp_items);
  1153. }
  1154. std::vector<PyObject*> p;
  1155. p.resize(tensors.size() + 2);
  1156. py::object Op = py::cast(op);
  1157. p[0] = Op.ptr();
  1158. p[1] = tensor.ptr();
  1159. for (size_t i = 0; i < tensors.size(); ++i) {
  1160. p[i + 2] = tensors[i].ptr();
  1161. }
  1162. py::tuple ret =
  1163. py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
  1164. py::object tmp_result = ret[0];
  1165. try {
  1166. py::tuple value_shape =
  1167. py::reinterpret_borrow<py::tuple>(val.attr("_tuple_shape"));
  1168. py::tuple tmp_result_shape =
  1169. py::reinterpret_borrow<py::tuple>(tmp_result.attr("_tuple_shape"));
  1170. for (size_t i = 0; i < value_shape.size() && i < tmp_result_shape.size(); ++i) {
  1171. size_t vs = value_shape[value_shape.size() - i - 1].cast<size_t>();
  1172. size_t ts =
  1173. tmp_result_shape[tmp_result_shape.size() - i - 1].cast<size_t>();
  1174. if (vs != 1 && vs != ts) {
  1175. std::string lhs = "", rhs = "";
  1176. for (size_t j = 0; j < tmp_result_shape.size(); ++j) {
  1177. lhs += std::to_string(tmp_result_shape[j].cast<size_t>());
  1178. if (j)
  1179. lhs += ",";
  1180. }
  1181. for (size_t j = 0; j < value_shape.size(); ++j) {
  1182. rhs += std::to_string(value_shape[j].cast<size_t>());
  1183. if (j)
  1184. rhs += ",";
  1185. }
  1186. throw py::value_error(
  1187. "cannot copy tensor with shape (" + rhs +
  1188. ") to subtensor with shape (" + lhs + ")");
  1189. }
  1190. }
  1191. } catch (py::error_already_set& err) {
  1192. ;
  1193. }
  1194. val = _broadcast_cpp(val, getattr(tmp_result, "shape"));
  1195. if (up[3].cast<bool>()) {
  1196. set_op = SetSubtensor::make(cpp_items);
  1197. } else {
  1198. set_op = IndexingSetMultiAxisVec::make(cpp_items);
  1199. }
  1200. std::vector<PyObject*> q;
  1201. q.resize(tensors.size() + 3);
  1202. py::object Set_Op = py::cast(set_op);
  1203. q[0] = Set_Op.ptr();
  1204. q[1] = tensor.ptr();
  1205. q[2] = val.ptr();
  1206. for (size_t i = 0; i < tensors.size(); ++i) {
  1207. q[i + 3] = tensors[i].ptr();
  1208. }
  1209. py::tuple result =
  1210. py::reinterpret_steal<py::object>(py_apply(NULL, q.data(), q.size()));
  1211. py::object res = result[0];
  1212. if (up[4].cast<bool>()) {
  1213. res = _reshape_cpp(res, org_shape);
  1214. }
  1215. return res;
  1216. }
  1217. py::object _split_cpp(
  1218. py::handle inp_hdl, py::handle nsplits_or_sections_hdl, py::handle axis_hdl) {
  1219. py::object shape_obj = getattr(inp_hdl, "shape");
  1220. py::object n_total = shape_obj[axis_hdl];
  1221. int ndim = shape_obj.attr("__len__")().cast<int>();
  1222. int axis = axis_hdl.cast<int>();
  1223. if (axis >= ndim) {
  1224. throw py::value_error("Invalid axis " + std::to_string(axis));
  1225. }
  1226. int n_sections;
  1227. bool is_array;
  1228. if (is_py_sequence(nsplits_or_sections_hdl)) {
  1229. n_sections = PySequence_Length(nsplits_or_sections_hdl.ptr()) + 1;
  1230. is_array = true;
  1231. } else {
  1232. n_sections = getattr(nsplits_or_sections_hdl, "__int__")().cast<int>();
  1233. is_array = false;
  1234. }
  1235. py::list partitions;
  1236. std::shared_ptr<OpDef> op;
  1237. std::vector<PyObject*> p;
  1238. if (is_array) {
  1239. py::list div_points;
  1240. py::list sections = py::reinterpret_borrow<py::object>(nsplits_or_sections_hdl);
  1241. div_points.append(0);
  1242. for (size_t i = 0; i < sections.size(); ++i) {
  1243. div_points.append(sections[i]);
  1244. }
  1245. div_points.append(n_total);
  1246. for (size_t i = 1; i < div_points.size(); ++i) {
  1247. if (div_points[i - 1] > div_points[i]) {
  1248. throw py::value_error(
  1249. "Invalid nsplits_or_secions: " +
  1250. repr(nsplits_or_sections_hdl).cast<std::string>());
  1251. }
  1252. py::object pos = div_points[i] - div_points[i - 1];
  1253. if (is_tensor_or_symbolvar(pos)) {
  1254. partitions.append(pos);
  1255. } else {
  1256. partitions.append(
  1257. _Const(pos, py::cast((mgb::DType)dtype::Int32()),
  1258. getattr(inp_hdl, "device"), inp_hdl));
  1259. }
  1260. }
  1261. op = Split::make(axis, 0);
  1262. p.resize(partitions.size() + 2);
  1263. for (size_t i = 0; i < partitions.size(); ++i) {
  1264. p[i + 2] = partitions[i].ptr();
  1265. }
  1266. } else {
  1267. if (n_sections <= 0) {
  1268. throw py::value_error("Number sections must be larger than 0");
  1269. }
  1270. if (py::int_(n_sections) > n_total) {
  1271. throw py::value_error(
  1272. "The size " + repr(n_total).cast<std::string>() + " at dim " +
  1273. std::to_string(axis) + " cannot be split into " +
  1274. std::to_string(n_sections) + " sections");
  1275. }
  1276. op = Split::make(axis, n_sections);
  1277. p.resize(2);
  1278. }
  1279. py::object Op = py::cast(op);
  1280. p[0] = Op.ptr();
  1281. p[1] = inp_hdl.ptr();
  1282. return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
  1283. }
  1284. std::vector<int32_t> list2vector(py::handle li) {
  1285. std::vector<int32_t> axis;
  1286. if (is_py_sequence(li)) {
  1287. py::list tmp_list = py::reinterpret_steal<py::list>(PySequence_List(li.ptr()));
  1288. for (size_t i = 0; i < tmp_list.size(); ++i) {
  1289. axis.push_back(tmp_list[i].attr("__int__")().cast<int32_t>());
  1290. }
  1291. } else {
  1292. axis.push_back(getattr(li, "__int__")().cast<int32_t>());
  1293. }
  1294. return axis;
  1295. }
  1296. py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
  1297. std::vector<int32_t> axis = list2vector(axis_hdl);
  1298. bool unknown_ndim = true;
  1299. size_t ndim = axis.size();
  1300. if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
  1301. auto&& shape = p->m_tensor->shape();
  1302. if (shape) {
  1303. unknown_ndim = false;
  1304. ndim += shape->ndim;
  1305. }
  1306. } else {
  1307. auto&& inp_ndim = get_ndim_safe(inp_hdl);
  1308. ndim += inp_ndim.first;
  1309. unknown_ndim &= !inp_ndim.second;
  1310. }
  1311. for (size_t i = 0; i < axis.size(); ++i) {
  1312. if (axis[i] < 0) {
  1313. if (unknown_ndim) {
  1314. throw py::index_error(
  1315. "Does not support negative index when tensor's ndim is "
  1316. "unknown");
  1317. }
  1318. axis[i] += static_cast<int32_t>(ndim);
  1319. }
  1320. }
  1321. if (!axis.size()) {
  1322. throw py::index_error("axis could not be empty");
  1323. }
  1324. std::sort(axis.begin(), axis.end());
  1325. std::shared_ptr<OpDef> op = AddAxis::make(axis = axis);
  1326. py::object Op = py::cast(op);
  1327. PyObject* p[2] = {Op.ptr(), inp_hdl.ptr()};
  1328. py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
  1329. return ret[0];
  1330. }
  1331. py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
  1332. std::vector<int32_t> axis;
  1333. size_t ndim;
  1334. if (axis_hdl.ptr() != Py_None) {
  1335. axis = list2vector(axis_hdl);
  1336. }
  1337. if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
  1338. auto&& shape = p->m_tensor->shape();
  1339. if (shape) {
  1340. ndim = shape->ndim;
  1341. if (axis_hdl.ptr() == Py_None) {
  1342. for (size_t i = 0; i < shape->ndim; ++i) {
  1343. if (shape->shape[i] == 1) {
  1344. axis.push_back(i);
  1345. }
  1346. }
  1347. }
  1348. }
  1349. } else {
  1350. py::tuple shape =
  1351. py::reinterpret_borrow<py::tuple>(getattr(inp_hdl, "_tuple_shape"));
  1352. ndim = shape.size();
  1353. if (axis_hdl.ptr() == Py_None) {
  1354. for (size_t i = 0; i < shape.size(); ++i) {
  1355. if (shape[i].cast<size_t>() == 1) {
  1356. axis.push_back(i);
  1357. }
  1358. }
  1359. }
  1360. }
  1361. for (size_t i = 0; i < axis.size(); ++i) {
  1362. if (axis[i] < 0) {
  1363. axis[i] += static_cast<int32_t>(ndim);
  1364. }
  1365. }
  1366. std::sort(axis.begin(), axis.end());
  1367. for (size_t i = 0; i < axis.size(); ++i) {
  1368. axis[i] -= static_cast<int32_t>(i);
  1369. }
  1370. std::shared_ptr<OpDef> op = RemoveAxis::make(axis = axis);
  1371. py::object Op = py::cast(op);
  1372. PyObject* p[2] = {Op.ptr(), inp_hdl.ptr()};
  1373. py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
  1374. return ret[0];
  1375. }
  1376. py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
  1377. py::object obj = _expand_args(args);
  1378. py::list lis;
  1379. if (!is_tensor_or_symbolvar(obj.ptr()) && PySequence_Check(obj.ptr())) {
  1380. lis = py::reinterpret_steal<py::list>(PySequence_List(obj.ptr()));
  1381. } else {
  1382. py::object np = getattr(obj, "numpy")();
  1383. PyArrayObject* arr = (PyArrayObject*)np.ptr();
  1384. PyObject* maybe_list = PyArray_ToList(arr);
  1385. if (PyList_Check(maybe_list)) {
  1386. lis = py::reinterpret_steal<py::list>(maybe_list);
  1387. }
  1388. }
  1389. if (get_ndim_safe(inp_hdl).first == 0) {
  1390. if (lis.size() != 0) {
  1391. throw py::index_error(
  1392. "transpose for scalar does not accept additional args");
  1393. }
  1394. return getattr(inp_hdl, "to")(getattr(inp_hdl, "device"));
  1395. }
  1396. std::vector<int32_t> pattern;
  1397. if (!lis.size()) {
  1398. size_t ndim = getattr(inp_hdl, "ndim").cast<size_t>();
  1399. for (size_t i = 0; i < ndim; ++i) {
  1400. pattern.push_back(ndim - i - 1);
  1401. }
  1402. } else {
  1403. for (size_t i = 0; i < lis.size(); ++i) {
  1404. if (PyLong_Check(lis[i].ptr())) {
  1405. pattern.push_back(lis[i].cast<int32_t>());
  1406. } else {
  1407. if (lis[i].cast<std::string>() == "x") {
  1408. pattern.push_back(-1);
  1409. }
  1410. }
  1411. }
  1412. }
  1413. std::shared_ptr<OpDef> op = Dimshuffle::make(pattern);
  1414. py::object Op = py::cast(op);
  1415. PyObject* p[2] = {Op.ptr(), inp_hdl.ptr()};
  1416. py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
  1417. return ret[0];
  1418. }
  1419. py::object _matmul_cpp(
  1420. py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
  1421. py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
  1422. py::handle profile, py::handle determistic) {
  1423. ::megdnn::param::MatrixMul::ComputeMode mode =
  1424. ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
  1425. if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
  1426. mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
  1427. }
  1428. ::megdnn::param::ExecutionPolicy::Strategy cstrategy =
  1429. static_cast<::megdnn::param::ExecutionPolicy::Strategy>(0);
  1430. if (profile.cast<bool>()) {
  1431. cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
  1432. } else {
  1433. cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
  1434. }
  1435. if (determistic.cast<bool>()) {
  1436. cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
  1437. }
  1438. std::shared_ptr<OpDef> op = MatrixMul::make(
  1439. transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
  1440. ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX,
  1441. dim1.cast<uint32_t>(), dim2.cast<uint32_t>());
  1442. py::object Op = py::cast(op);
  1443. PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
  1444. py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
  1445. return ret[0];
  1446. }
  1447. py::object _batched_matmul_cpp(
  1448. py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
  1449. py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
  1450. py::handle profile, py::handle determistic) {
  1451. ::megdnn::param::MatrixMul::ComputeMode mode =
  1452. ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
  1453. if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
  1454. mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
  1455. }
  1456. ::megdnn::param::ExecutionPolicy::Strategy cstrategy =
  1457. static_cast<::megdnn::param::ExecutionPolicy::Strategy>(0);
  1458. if (profile.cast<bool>()) {
  1459. cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
  1460. } else {
  1461. cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
  1462. }
  1463. if (determistic.cast<bool>()) {
  1464. cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
  1465. }
  1466. std::shared_ptr<OpDef> op = BatchedMatrixMul::make(
  1467. transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
  1468. ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX,
  1469. dim1.cast<uint32_t>(), dim2.cast<uint32_t>());
  1470. py::object Op = py::cast(op);
  1471. PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
  1472. py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
  1473. return ret[0];
  1474. }
  1475. py::object _pixel_shuffle_cpp(py::handle inp, py::handle val, py::handle func) {
  1476. if (enable_fastpath(inp) && PyLong_Check(val.ptr())) {
  1477. std::shared_ptr<OpDef> op = PixelShuffle::make(val.cast<int32_t>());
  1478. py::object Op = py::cast(op);
  1479. PyObject* p[2] = {Op.ptr(), inp.ptr()};
  1480. py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
  1481. return ret[0];
  1482. } else {
  1483. // fallback to traceable subgraph implement
  1484. return func(inp, val);
  1485. }
  1486. }
  1487. PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
  1488. try {
  1489. return _make_shape_tuple(args[0]).release().ptr();
  1490. }
  1491. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1492. }
  1493. PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1494. try {
  1495. return _getitem_cpp(args[0], args[1]).release().ptr();
  1496. }
  1497. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1498. }
  1499. PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1500. try {
  1501. return _setitem_cpp(args[0], args[1], args[2]).release().ptr();
  1502. }
  1503. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1504. }
  1505. PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1506. try {
  1507. return _split_cpp(args[0], args[1], args[2]).release().ptr();
  1508. }
  1509. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1510. }
  1511. PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1512. try {
  1513. return _expand_dims_cpp(args[0], args[1]).release().ptr();
  1514. }
  1515. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1516. }
  1517. PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1518. try {
  1519. return _squeeze_cpp(args[0], args[1]).release().ptr();
  1520. }
  1521. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1522. }
  1523. PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1524. try {
  1525. return _transpose_cpp(args[0], args[1]).release().ptr();
  1526. }
  1527. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1528. }
  1529. PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1530. try {
  1531. return _broadcast_cpp(args[0], args[1]).release().ptr();
  1532. }
  1533. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1534. }
  1535. PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1536. try {
  1537. return _reshape_cpp(args[0], args[1]).release().ptr();
  1538. }
  1539. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1540. }
  1541. PyObject* adaptive_pool2d_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1542. try {
  1543. return _adaptive_pool2d_cpp(args[0], args[1], args[2]).release().ptr();
  1544. }
  1545. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1546. }
  1547. PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1548. try {
  1549. return _pixel_shuffle_cpp(args[0], args[1], args[2]).release().ptr();
  1550. }
  1551. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1552. }
  1553. PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
  1554. try {
  1555. return _Const(args[0], args[1], args[2], args[3]).release().ptr();
  1556. }
  1557. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1558. }
  1559. PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1560. try {
  1561. return _astype_cpp(args[0], args[1]).release().ptr();
  1562. }
  1563. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1564. }
  1565. PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1566. try {
  1567. return _matmul_cpp(
  1568. args[0], args[1], args[2], args[3], args[4], args[5], args[6],
  1569. args[7], args[8])
  1570. .release()
  1571. .ptr();
  1572. }
  1573. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1574. }
  1575. PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1576. try {
  1577. return _batched_matmul_cpp(
  1578. args[0], args[1], args[2], args[3], args[4], args[5], args[6],
  1579. args[7], args[8])
  1580. .release()
  1581. .ptr();
  1582. }
  1583. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1584. }
  1585. PyObject* convert_single_value_cpp(
  1586. PyObject* self, PyObject* const* args, size_t nargs) {
  1587. try {
  1588. return _convert_single_value_cpp(args[0], args[1], args[2]).release().ptr();
  1589. }
  1590. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1591. }
  1592. PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1593. try {
  1594. py::object dtype = py::reinterpret_steal<py::object>(
  1595. dtype_promotion(self, args, nargs - 1));
  1596. py::object device;
  1597. if (args[nargs - 1] == Py_None) {
  1598. device = py::reinterpret_steal<py::object>(
  1599. get_device(self, args, nargs - 1));
  1600. } else {
  1601. device = py::reinterpret_borrow<py::object>(args[nargs - 1]);
  1602. }
  1603. return _convert_inputs_cpp(args, nargs - 1, dtype, device).release().ptr();
  1604. }
  1605. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1606. }
  1607. PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
  1608. try {
  1609. return _astensor1d_cpp(args[0], args[1], args[2], args[3]).release().ptr();
  1610. }
  1611. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  1612. }
  1613. } // namespace mgb::imperative::python