tensor.cpp

/**
 * \file imperative/python/src/tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/dtype.h"
#include "megbrain/common.h"
#include "megbrain/imperative/ops/utility.h"

#include "./tensor.h"
#include "./grad.h"
#include "./trace.h"
#include "./common.h"
#include "./numpy_dtypes.h"
#include "./graph_rt.h"
#include "./helper.h"

#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <range/v3/all.hpp>

#include <unordered_map>

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

interpreter::Interpreter::Channel* interpreter_for_py;

PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing,
        *cpp_apply_compiled_mode, *cpp_apply_const_compiled_mode;
PyObject *cpp_apply_backward_varnode;

#define REGISTE_APPLY_FUNC(mode)      \
    void set_##mode(py::object pyf) { \
        mode = pyf.ptr();             \
    }

REGISTE_APPLY_FUNC(cpp_apply_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_compiled_mode)
REGISTE_APPLY_FUNC(cpp_apply_const_compiled_mode)
REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)

#undef REGISTE_APPLY_FUNC

bool is_tracing = false;
bool is_compiled = false;

#define SET_UNSET_PROP(mode)  \
    void set_##mode() {       \
        is_##mode = true;     \
    }                         \
    void unset_##mode() {     \
        is_##mode = false;    \
    }                         \

SET_UNSET_PROP(tracing)
SET_UNSET_PROP(compiled)

#undef SET_UNSET_PROP

bool skip_tracing = false;

Tensor::flags_t ApplyContext::global_disable = 0;
apply_result_t apply(ApplyContext& ctx) {
    // emulating scalar should be put to specific op's apply, e.g.,
    // elementwise, reduce, typecvt. Currently it's still handled at python
    // side. It could be moved to the C++ side if it has an impact on performance
    auto flags = ctx.flags & ~ApplyContext::global_disable;

    if (flags & Tensor::Flags::SCALAR) {
        // TODO: emulate scalar
    }

    if (flags & Tensor::Flags::GRAD) {
        return apply_grad(ctx);
    }

    if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) {
        py::tuple pyin(ctx.nargs);
        for (size_t i = 0; i < ctx.nargs; ++i) {
            pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
        }
        auto f = py::getattr(op->obj, "_default_rule");
        auto pyout = py::reinterpret_steal<py::object>(PyObject_Call(f.ptr(), pyin.ptr(), nullptr));
        if (!pyout) throw py::error_already_set();
        if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) {
            return {tw->m_tensor};
        }
        apply_result_t ret;
        ret.reserve(py::len(pyout));
        for (auto&& i : pyout) {
            auto* tw = TensorWrapper::try_cast(i.ptr());
            mgb_assert(tw);
            ret.push_back(tw->m_tensor);
        }
        return ret;
    }

    if (flags & Tensor::Flags::TRACE) {
        return apply_trace(ctx);
    } else {
        SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs);
        for (size_t i = 0; i < ctx.nargs; ++i) {
            handles[i] = ctx.args[i]->m_handle.get();
        }

        auto output_handles = interpreter_for_py->apply_op(ctx.op, handles);

        apply_result_t outputs;
        outputs.reserve(output_handles.size());
        for (auto h : output_handles) {
            outputs.emplace_back(std::make_shared<Tensor>(h));
        }
        return outputs;
    }

    mgb_assert(0);
}

PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) {
    try {
        // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
        //     PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
        //     return nullptr;
        // }
        if (nargs < 2) {
            PyErr_SetString(PyExc_TypeError,
                            "py_apply expects one Op and at least one tensor "
                            "as argument");
            return nullptr;
        }

        auto* op = args[0];

        PyTypeObject* pytype = args[1]->ob_type;
        ++args;
        --nargs;

        ApplyContext ctx;
        ctx.flags = 0;
        ctx.op = py::handle(op).cast<std::shared_ptr<OpDef>>();
        SmallVector<Tensor*, 64> tensors(nargs);
        ctx.args = &tensors[0];
        ctx.nargs = nargs;
        ctx.pytype = pytype;
        if (strstr(op->ob_type->tp_name, "BackwardGraph")) {
            ctx.backward = true;
        }

        for (size_t i = 0; i < nargs; ++i) {
            if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                auto* t = tensors[i] = tw->m_tensor.get();
                ctx.flags |= t->m_flags;
            } else {
                PyErr_SetString(PyExc_TypeError, "expect Tensor");
                return nullptr;
            }
        }

        if (is_tracing) {
            ctx.flags |= Tensor::Flags::TRACE;
        }

        auto outputs = apply(ctx);
        size_t nout = outputs.size();
        auto ret = py::tuple(nout);
        for (size_t i = 0; i < nout; ++i) {
            ret[i] = TensorWrapper::make(pytype, std::move(outputs[i]));
        }
        return ret.release().ptr();
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
    if (kwargs && PyDict_Size(kwargs)) {
        throw py::type_error("keyword argument not allowed");
    }
    auto nargs = PyTuple_Size(args);
    auto tup = py::reinterpret_borrow<py::tuple>(args);
    if (nargs == 0) {
        throw py::type_error("too few arguments");
    }
    if (auto* t = try_cast(tup[0].ptr())) {
        if (nargs > 1) {
            throw py::type_error("expect 1 argument");
        }
        m_tensor = t->m_tensor;
    } else {
        if (nargs == 1) {
            auto arg0 = PyTuple_GetItem(args, 0);
            // for lazy_eval_tensor
            if (strstr(arg0->ob_type->tp_name, "VarNode")) {
                if (PyObject_HasAttrString(arg0, "_node")) {
                    arg0 = PyObject_GetAttrString(arg0, "_node");
                }
                m_tensor = std::make_shared<Tensor>(py::handle(arg0).cast<cg::VarNode *>());
            } else {
                // for DeviceTensorND
                if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
                    auto dv = py::handle(arg0).cast<DeviceTensorND>();
                    interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv);
                    m_tensor = std::make_shared<Tensor>(handle);
                } else {
                    throw py::type_error("single argument is not tensor, varnode or devicetensor");
                }
            }
        } else {
            py::detail::loader_life_support life_sup; // FIXME!!! required to cast DType
            if (nargs != 4 && nargs != 5) {
                throw py::type_error("expect 4 or 5 arguments");
            }
            auto data = tup[0].cast<py::array>();
            DType dtype = tup[1].cast<DType>();
            CompNode cn = tup[2].cast<CompNode>();
            bool is_const = tup[3].cast<bool>();
            bool no_cache = nargs == 5 ? tup[4].cast<bool>() : false;

            // const op
            if (is_const && is_tracing) {
                PyObject *pyf;
                if (is_compiled) {
                    pyf = cpp_apply_const_compiled_mode;
                } else {
                    pyf = cpp_apply_const_with_tracing;
                }

                auto ret = py::reinterpret_steal<py::object>(
                        PyObject_Call(pyf, tup.ptr(), nullptr));
                auto py_ret = py::reinterpret_borrow<py::list>(ret);
                if (auto* t = try_cast(py_ret[0].ptr())) {
                    m_tensor = t->m_tensor;
                }
                return;
            }

            interpreter::Interpreter::Handle handle;
            constexpr auto size_threshhold = TensorShape::MAX_NDIM;
            if (data.size() > size_threshhold) {
                handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype), no_cache);
            } else {
                HostTensorND ret(cn);
                handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype), no_cache);
            }

            m_tensor = std::make_shared<Tensor>(handle);

            if (data.ndim() == 0) {
                m_tensor->m_flags |= Tensor::Flags::SCALAR;
            }
        }
    }
}
#define REGISTE_TENSORWRAPPER_FUNC(type, member)                            \
    PyObject* TensorWrapper::member() {                                     \
        return py::cast(m_tensor->m_trace_info.member).release().ptr();     \
    }                                                                       \
    void TensorWrapper::set_##member(PyObject* dest) {                      \
        auto py_dest = py::reinterpret_borrow<py::object>(dest);            \
        type real_dest = py_dest.cast<type>();                              \
        m_tensor->m_trace_info.member = real_dest;                          \
    }

REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle)
REGISTE_TENSORWRAPPER_FUNC(bool, recording)

#undef REGISTE_TENSORWRAPPER_FUNC

PyObject* TensorWrapper::copied() {
    return py::cast(m_tensor->m_trace_info.copied).release().ptr();
}

#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member)                         \
    PyObject* TensorWrapper::member() {                                     \
        if (m_tensor->m_trace_info.member) {                                \
            return m_tensor->m_trace_info.member;                           \
        } else {                                                            \
            Py_RETURN_NONE;                                                 \
        }                                                                   \
    }                                                                       \
    void TensorWrapper::set_##member(PyObject* dest) {                      \
        if (dest == Py_None) {                                              \
            Py_XDECREF(m_tensor->m_trace_info.member);                      \
            m_tensor->m_trace_info.member = nullptr;                        \
        } else {                                                            \
            Py_INCREF(dest);                                                \
            m_tensor->m_trace_info.member = dest;                           \
        }                                                                   \
    }

REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info)
REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info)

#undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC

PyObject* TensorWrapper::handle() {
    return py::cast(m_tensor->m_handle).release().ptr();
}

void TensorWrapper::set_handle(PyObject* dest) {
    auto py_dest = py::reinterpret_borrow<py::object>(dest);
    SharedHandle real_dest = py_dest.cast<SharedHandle>();
    m_tensor->m_handle = std::move(real_dest);
}
PyObject* TensorWrapper::shape() {
    // if it's tracing compiled mode, get value from compiled_info
    if (m_tensor->m_trace_info.compiled_info != nullptr) {
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            return PyTuple_New(0);
        }
        PyObject *shp = PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
        if (shp == Py_None) {
            throw TraceReadError("shape of this tensor is not read in trace");
        }
        return shp;
    }

    // inside trace, if tensor shape is useful for other operations, set shape_read = true
    if (m_tensor->m_trace_info.recording && !skip_tracing) {
        PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr());
    }

    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        return PyTuple_New(0);
    }

    TensorShape shape;
    if (m_tensor->m_var) { // get shape from m_var
        auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
        auto *tshp = mgr.infer_shape_fallible(m_tensor->m_var);
        if (!tshp) {
            Py_RETURN_NONE;
        }
        shape = *tshp;
    } else {
        shape = m_tensor->shape();
    }

    if (!shape.ndim) {
        Py_RETURN_NONE;
    }
    py::tuple ret(shape.ndim);
    for (size_t i = 0; i < shape.ndim; ++i) {
        ret[i] = shape[i];
    }
    return ret.release().ptr();
}

PyObject* TensorWrapper::dtype() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var->dtype()).release().ptr();
    }
    return py::cast(m_tensor->dtype()).release().ptr();
}

PyObject* TensorWrapper::device() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var->comp_node()).release().ptr();
    }
    return py::cast(m_tensor->comp_node()).release().ptr();
}

PyObject* TensorWrapper::numpy() {
    if (m_tensor->m_trace_info.compiled_info != nullptr) {
        PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr);
        if (np_val == Py_None) {
            throw TraceReadError("value of this tensor is not read in trace");
        }
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            PyObject *np_scalar = PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val));
            Py_DECREF(np_val);
            return np_scalar;
        }
        return np_val;
    }

    if (m_tensor->m_trace_info.recording && !skip_tracing) {
        PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "value_read", py::cast(true).release().ptr());
    }

    if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) {
        auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
        auto&& type = mgr.get_infer_type(m_tensor->m_var);
        using InferType = cg::static_infer::InferType;
        if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
            PyErr_SetString(PyExc_ValueError, "tensor invalid");
            return nullptr;
        }
        auto* val = mgr.infer_value_fallible(m_tensor->m_var);
        if (!val) {
            PyErr_SetString(PyExc_ValueError, "tensor invalid");
            return nullptr;
        }
        auto np_val = py::cast(*val).attr("numpy")();
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val.release().ptr()));
        }
        return np_val.release().ptr();
    }

    auto&& hv = [&]() {
        py::gil_scoped_release _;
        return interpreter_for_py->get_value(m_tensor->m_handle.get());
    }();
    auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
    if (!arr) {
        PyErr_SetString(PyExc_ValueError, "tensor invalid");
        return nullptr;
    }

    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        mgb_assert(PyArray_Check(arr.ptr()));
        return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
    }
    return arr.release().ptr();
}
PyObject* TensorWrapper::varnode() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var).release().ptr();
    }
    Py_RETURN_NONE;
}

void TensorWrapper::reset(PyObject* tensor) {
    TensorWrapper* t = TensorWrapper::try_cast(tensor);
    if (!t) {
        throw py::type_error("expect Tensor");
    }
    m_tensor = t->m_tensor;
}

void TensorWrapper::reset_varnode() {
    m_tensor->m_var = nullptr;
}

PyObject* TensorWrapper::detach() {
    PyObject* self = wrap_t::pycast(this);
    PyTypeObject* pytype = self->ob_type;

    std::shared_ptr<Tensor> new_tensor;
    if (m_tensor->m_handle.get()) {
        new_tensor = std::make_shared<Tensor>(m_tensor->m_handle);
    } else {
        new_tensor = std::make_shared<Tensor>(m_tensor->m_var);
    }
    new_tensor->m_trace_info = m_tensor->m_trace_info;
    auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
    return ret.release().ptr();
}

PyObject* TensorWrapper::_dev_tensor() {
    if (m_tensor->m_trace_info.compiled_info != nullptr) {
        auto *dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr);
        if (dev_tensor == Py_None) {
            throw TraceReadError("raw data of this tensor is not read in trace");
        }

        // set m_handle to make it a real tensor
        auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
        auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
        m_tensor->m_handle = std::move(SharedHandle(sh));

        // compiled info is useless after m_handle is set
        Py_DECREF(m_tensor->m_trace_info.compiled_info);
        m_tensor->m_trace_info.compiled_info = nullptr;

        return dev_tensor;
    }
    if (m_tensor->m_trace_info.recording && !skip_tracing) {
        PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr());
    }
    auto dev_tensor = [&]() {
        py::gil_scoped_release _;
        return interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
    }();
    return py::cast(dev_tensor).release().ptr();
}

void TensorWrapper::_swap_out() {
    interpreter_for_py->swap_out(m_tensor->m_handle.get());
}

void TensorWrapper::_swap_in() {
    interpreter_for_py->swap_in(m_tensor->m_handle.get());
}

void TensorWrapper::_drop() {
    interpreter_for_py->drop(m_tensor->m_handle.get());
}

PyObject* TensorWrapper::isscalar() {
    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        Py_RETURN_TRUE;
    } else {
        Py_RETURN_FALSE;
    }
}

void TensorWrapper::setscalar() {
    m_tensor->m_flags |= Tensor::Flags::SCALAR;
}

struct TensorWeakRef {
    std::weak_ptr<Tensor> wptr;

    TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {}

    py::object operator()() {
        if (auto p = wptr.lock()) {
            return TensorWrapper::make(p);
        }
        return py::none();
    }
    int _use_cnt() { return wptr.use_count(); }
};
/* ============== convert inputs ============== */

// map numpy.dtype.kind to priority
inline uint8_t category_priority(char c) {
    switch (c) {
        case 'f': return 3; // floating-point
        case 'i': return 2; // signed integer
        case 'u': return 2; // unsigned integer
        case 'b': return 1; // boolean
        default: return 0;
    }
}

// Returns the maximum value of the priority of each type in the list `types`.
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
    if (types.size() == 0) {
        return 0;
    } else {
        uint8_t max_p = 0;
        for (auto&& desc: types) {
            max_p = std::max(max_p, category_priority(desc->kind));
        }
        return max_p;
    }
}

// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> used_types;
    for (auto&& desc: types) {
        auto&& v = category_priority(desc->kind);
        if (v == cat) {
            used_types.emplace_back(desc);
        }
    }
    mgb_assert(used_types.size() > 0, "size of used_types is 0");
    PyArray_Descr* res = used_types[0];
    Py_INCREF(res);

    for (size_t i = 1; i < used_types.size(); ++i) {
        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
        Py_DECREF(res);
        res = tmp;
    }
    return res;
}

PyArray_Descr* scalar2dtype(PyObject* arg) {
    // Return value: New reference
    if (PyBool_Check(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_BOOL);
        return descr;
    }
    if (PyLong_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_INT32);
        return descr;
    }
    if (PyFloat_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
        return descr;
    }
    return nullptr;
}

PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> tensors;
    SmallVector<PyArray_Descr*> scalars;

    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }

    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        if (handle == Py_None) continue;
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (tw) {
            mgb::DType type = tw->m_tensor->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
            Py_INCREF(descr.get());
            tensors.emplace_back(descr.get());
        } else {
            if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
                auto&& descr = PyArray_DescrFromObject(handle, nullptr);
                tensors.emplace_back(descr);
                continue;
            }

            PyArray_Descr* descr = scalar2dtype(handle);
            if (descr) {
                scalars.emplace_back(descr);
                continue;
            }
        }
    }

    auto max_pri_scalars = max_priority(scalars);
    auto max_pri_tensors = max_priority(tensors);

    if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
        throw py::value_error("invalid input, no dtype available");
    }

    PyArray_Descr* res;
    if (max_pri_scalars > max_pri_tensors) {
        res = promote_types(scalars, max_pri_scalars);
    } else {
        res = promote_types(tensors, max_pri_tensors);
    }

    for (auto *p: tensors) { Py_DECREF(p); }
    for (auto *p: scalars) { Py_DECREF(p); }
    Py_XDECREF(tuple);
    return res;
}
CompNode _get_device(PyObject*const* args, size_t nargs) {
    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }
    bool valid = false;
    CompNode cn;
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (tw) {
            if (!valid) {
                cn = tw->m_tensor->comp_node();
                valid = true;
            } else {
                CompNode cn1 = tw->m_tensor->comp_node();
                if (cn1 != cn) {
                    throw py::value_error(ssprintf("ambiguous device: %s vs %s",
                        cn.to_string().c_str(), cn1.to_string().c_str()));
                }
            }
        }
    }
    if (!valid) {
        mgb_assert(0, "expect at least 1 device");
    }
    Py_XDECREF(tuple);
    return cn;
}

// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject*const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        PyArray_Descr* res = _dtype_promotion(args, nargs);
        return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}

PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        CompNode cn = _get_device(args, nargs);
        return py::cast(cn).release().ptr();
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}

#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
#else
#define WRAP_FUNC_PY35(FUNC)                                    \
    PyObject* py35_##FUNC(PyObject* self, PyObject* args) {     \
        auto* arr = &PyTuple_GET_ITEM(args, 0);                 \
        auto size = PyTuple_GET_SIZE(args);                     \
        return FUNC(self, arr, size);                           \
    }
WRAP_FUNC_PY35(py_apply);
WRAP_FUNC_PY35(dtype_promotion);
WRAP_FUNC_PY35(get_device);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif
void init_tensor(py::module m) {
    imperative::Tensor::static_initialize();
    static auto sl_interpreter_for_py = interpreter::Interpreter::inst().create_channel();
    interpreter_for_py = sl_interpreter_for_py.get();

    auto* tensor_type = TensorWrapper::wrap_t::type()
        .def<&TensorWrapper::numpy>("numpy")
        .def_getset<&TensorWrapper::shape>("shape")
        .def_getset<&TensorWrapper::dtype>("dtype")
        .def_getset<&TensorWrapper::device>("device")
        .def<&TensorWrapper::reset>("_reset")
        .def<&TensorWrapper::isscalar>("isscalar")
        .def<&TensorWrapper::setscalar>("setscalar")
        .def<&TensorWrapper::detach>("detach")
        .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
        .def<&TensorWrapper::_swap_out>("_swap_out")
        .def<&TensorWrapper::_swap_in>("_swap_in")
        .def<&TensorWrapper::_drop>("_drop")
        .def<&TensorWrapper::reset_varnode>("_reset_varnode")
        .def<&TensorWrapper::_use_cnt>("_use_cnt")
        .def_getset<&TensorWrapper::varnode>("_varnode")
        .def_getset<&TensorWrapper::copied>("_copied")
        .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle")
        .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording")
        .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
        .def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info")
        .def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
        .finalize();
    if (!tensor_type) throw py::error_already_set();
    py::setattr(m, "Tensor", tensor_type);

    py::class_<TensorWeakRef>(m, "TensorWeakRef")
        .def(py::init<const TensorWrapper&>())
        .def("__call__", &TensorWeakRef::operator())
        .def("_use_cnt", &TensorWeakRef::_use_cnt);

    static PyMethodDef method_defs[] = {
            MGE_PY_INTERFACE(apply, py_apply),
            MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
            MGE_PY_INTERFACE(get_device, get_device),
            {nullptr, nullptr, 0, nullptr}};
    for (auto&& def: method_defs) {
        if (def.ml_meth != nullptr) {
            auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
            if (!func) throw py::error_already_set();
            py::setattr(m, def.ml_name, func);
        }
    }

    m.def("_set_swap_flag",
          [](bool flag) { interpreter_for_py->set_swap_flag(flag); });
    m.def("_set_drop_flag",
          [](bool flag) { interpreter_for_py->set_drop_flag(flag); });
    m.def("config_async_level",
          [](int level) { interpreter_for_py->config_async_level(level); });
    m.def("get_async_level",
          []() { return interpreter_for_py->get_async_level(); });
    m.def("set_buffer_length",
          [](int length) { interpreter_for_py->set_buffer_length(length); });
    m.def("sync",
          []() {
              interpreter_for_py->sync();
              py_task_q.wait_all_task_finish();
          },
          py::call_guard<py::gil_scoped_release>());
    m.def("full_sync",
          []() {
              interpreter_for_py->sync();
              CompNode::sync_all();
              py_task_q.wait_all_task_finish();
          },
          py::call_guard<py::gil_scoped_release>());

    py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
        .def<&GradKeyWrapper::attach>("attach")
        .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
        .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name")
        .finalize();
    if (!grad_key_type) throw py::error_already_set();
    py::setattr(m, "GradKey", grad_key_type);
    m.def("backward", &GradKeyWrapper::backward);

    m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
    m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
    m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode);
    m.def("set_cpp_apply_const_compiled_mode", &set_cpp_apply_const_compiled_mode);
    m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode);
    m.attr("skip_tracing") = &skip_tracing;

    py::class_<SharedHandle>(m, "SharedHandle")
        .def(py::init<const SharedHandle&>())
        .def("__eq__", [](SharedHandle &thish, SharedHandle &thath) {
            return (thish.get() == thath.get());
        })
        .def("__hash__", [](SharedHandle &sh) {
            return reinterpret_cast<int64_t>(sh.get());
        })
        ;

    m.def("set_tracing", &set_tracing);
    m.def("unset_tracing", &unset_tracing);
    m.def("set_compiled", &set_compiled);
    m.def("unset_compiled", &unset_compiled);
}

#undef MGE_PY_INTERFACE

} // namespace mgb::imperative::python
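
The dtype-promotion helpers above (category_priority, max_priority, promote_types and _dtype_promotion) pick the result dtype by category first (floating point > signed/unsigned integer > bool) and only let scalar dtypes decide when the scalars' category is strictly higher than that of every tensor; within the winning category, PyArray_PromoteTypes does the size promotion. Below is a minimal Python sketch of that rule using plain NumPy only; np.promote_types stands in for PyArray_PromoteTypes, and the function names here are illustrative, not part of this binding.

import numpy as np

# Mirrors category_priority(): numpy dtype "kind" -> category priority.
CATEGORY = {'f': 3, 'i': 2, 'u': 2, 'b': 1}


def sketch_dtype_promotion(tensor_dtypes, scalar_dtypes):
    """Illustrative restatement of _dtype_promotion() in pure NumPy."""
    def max_category(dtypes):
        return max((CATEGORY.get(np.dtype(d).kind, 0) for d in dtypes), default=0)

    def promote_within(dtypes, cat):
        used = [np.dtype(d) for d in dtypes
                if CATEGORY.get(np.dtype(d).kind, 0) == cat]
        res = used[0]
        for d in used[1:]:
            res = np.promote_types(d, res)  # counterpart of PyArray_PromoteTypes
        return res

    cat_scalars = max_category(scalar_dtypes)
    cat_tensors = max_category(tensor_dtypes)
    if cat_scalars > cat_tensors:  # scalars win only on a strictly higher category
        return promote_within(scalar_dtypes, cat_scalars)
    return promote_within(tensor_dtypes, cat_tensors)


# int32 tensor + Python float scalar (scalar2dtype maps it to float32) -> float32
print(sketch_dtype_promotion(['int32'], ['float32']))         # float32
# same category on both sides -> the tensor dtypes decide
print(sketch_dtype_promotion(['int16', 'int32'], ['int32']))  # int32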

The MegEngine installation package bundles the CUDA environment needed to run code on the GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
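
As a quick sanity check after installation, the following minimal sketch picks a GPU device when one is visible and falls back to CPU otherwise. It assumes the standard megengine Python API (is_cuda_available, tensor and functional.matmul); exact names may vary slightly between versions.

import megengine as mge
import megengine.functional as F

# Use the GPU if the bundled CUDA runtime can see one, otherwise fall back to CPU.
device = "gpu0" if mge.is_cuda_available() else "cpu0"

x = mge.tensor([[1.0, 2.0], [3.0, 4.0]], device=device)
y = F.matmul(x, x)
print("running on", device)
print(y.numpy())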