
tensor.cpp

/**
 * \file imperative/python/src/tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/dtype.h"
#include "megbrain/common.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/opr/io.h"

#include "./tensor.h"
#include "./grad.h"
#include "./trace.h"
#include "./common.h"
#include "./numpy_dtypes.h"
#include "./graph_rt.h"
#include "./helper.h"

#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <range/v3/all.hpp>

#include <string>
#include <unordered_map>

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

interpreter::Interpreter::Channel* interpreter_for_py;

PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing;
PyObject *cpp_apply_backward_varnode;

std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
    if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) {
        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor()));
    }
    py::tuple tup(6);
    auto data = value->get_value();
    tup[0] = py::reinterpret_steal<py::array>(ndarray_from_tensor(data, npy::ShareType::MUST_SHARE));
    tup[1] = value->dtype();
    tup[2] = value->comp_node();
    tup[3] = true;
    tup[4] = false;
    tup[5] = py::none{};
    auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr);
    if (!py_ret) throw py::error_already_set();
    auto py_list = py::reinterpret_steal<py::list>(py_ret);
    auto* tensor_wrapper = TensorWrapper::try_cast(py_list[0].ptr());
    auto tensor = tensor_wrapper->m_tensor;
    return tensor_wrapper->m_tensor;
}

#define REGISTE_APPLY_FUNC(mode)      \
    void set_##mode(py::object pyf) { \
        mode = pyf.ptr();             \
    }

REGISTE_APPLY_FUNC(cpp_apply_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)

#undef REGISTE_APPLY_FUNC
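
// Note: each REGISTE_APPLY_FUNC(mode) invocation above expands to a setter such as
//     void set_cpp_apply_with_tracing(py::object pyf) { cpp_apply_with_tracing = pyf.ptr(); }
// These setters are exported to Python in init_tensor() below
// (m.def("set_cpp_apply_with_tracing", ...), etc.), which is how the Python-side
// tracing layer registers its callbacks with this module.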
Tensor::flags_t ApplyContext::global_disable = 0;
Tensor::flags_t ApplyContext::global_enable = 0;

void set_tracing() { ApplyContext::global_enable |= Tensor::Flags::TRACE; }
void unset_tracing() { ApplyContext::global_enable &= ~Tensor::Flags::TRACE; }

bool skip_tracing = false;

apply_result_t apply(ApplyContext& ctx) {
    // Scalar emulation should be pushed into the apply of each specific op,
    // e.g. elementwise, reduce, typecvt. Currently it is still handled on the
    // Python side; it could be moved to the C++ side if it turns out to matter
    // for performance.
    auto flags = ctx.flags & ~ApplyContext::global_disable;
    flags = flags | ApplyContext::global_enable;

    if (flags & Tensor::Flags::SCALAR) {
        // TODO: emulate scalar
    }

    if (flags & Tensor::Flags::GRAD) {
        return apply_grad(ctx);
    }

    if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) {
        py::tuple pyin(ctx.nargs);
        for (size_t i = 0; i < ctx.nargs; ++i) {
            pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
        }
        auto f = py::getattr(op->obj, "_default_rule");
        auto pyout = py::reinterpret_steal<py::object>(PyObject_Call(f.ptr(), pyin.ptr(), nullptr));
        if (!pyout) throw py::error_already_set();
        if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) {
            return {tw->m_tensor};
        }
        apply_result_t ret;
        ret.reserve(py::len(pyout));
        for (auto&& i : pyout) {
            auto* tw = TensorWrapper::try_cast(i.ptr());
            mgb_assert(tw);
            ret.push_back(tw->m_tensor);
        }
        return ret;
    }

    if (flags & Tensor::Flags::TRACE) {
        return apply_trace(ctx);
    } else {
        SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs);
        for (size_t i = 0; i < ctx.nargs; ++i) {
            handles[i] = ctx.args[i]->m_handle.get();
        }

        apply_result_t outputs;

        // fast copy without really applying
        if (ctx.op->same_type<FastpathCopy>()) {
            mgb_assert(ctx.nargs == 1);
            outputs.reserve(ctx.nargs);
            outputs.emplace_back(std::make_shared<Tensor>(ctx.args[0]->m_handle));
            return outputs;
        }

        auto output_handles = interpreter_for_py->apply_op(ctx.op, handles);

        outputs.reserve(output_handles.size());
        for (auto h : output_handles) {
            outputs.emplace_back(std::make_shared<Tensor>(h));
        }
        return outputs;
    }

    mgb_assert(0);
}
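
// Dispatch order in apply(): grad-attached tensors go through apply_grad(),
// GenericPyOp falls back to the op's Python-side "_default_rule", traced tensors
// go through apply_trace(), and everything else is executed directly on the
// imperative interpreter channel (with a shortcut for FastpathCopy).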
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) {
    try {
        // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
        //     PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
        //     return nullptr;
        // }
        if (nargs < 2) {
            PyErr_SetString(PyExc_TypeError,
                            "py_apply expects one Op and at least one tensor "
                            "as argument");
            return nullptr;
        }

        auto* op = args[0];

        PyTypeObject* pytype = args[1]->ob_type;

        // If pytype is Parameter (or any other subclass derived from the Python
        // Tensor), fall back to its tp_base, i.e. the Python Tensor type itself.
        if (TensorWrapper::wrap_t::type().same_pytype(pytype->tp_base->tp_base)) {
            pytype = pytype->tp_base;
        }
        ++args;
        --nargs;

        ApplyContext ctx;
        ctx.flags = 0;
        ctx.op = py::handle(op).cast<std::shared_ptr<OpDef>>();
        SmallVector<Tensor*, 64> tensors(nargs);
        ctx.args = &tensors[0];
        ctx.nargs = nargs;
        ctx.pytype = pytype;

        if (py::isinstance<PySymbolVar>(py::handle(args[0]))) {
            SmallVector<cg::VarNode*> vinputs(nargs);
            for (size_t i = 0; i < nargs; ++i) {
                vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
            }
            auto op = ctx.op.get();
            auto rst = OpDef::apply_on_var_node(*op, vinputs);
            auto ret = pybind11::tuple(rst.size());
            auto typeobj = py::handle(args[0]).get_type();
            for (size_t i = 0; i < rst.size(); ++i) {
                ret[i] = typeobj(pybind11::cast(rst[i], pybind11::return_value_policy::automatic));
            }
            return ret.release().ptr();
        }

        for (size_t i = 0; i < nargs; ++i) {
            if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                auto* t = tensors[i] = tw->m_tensor.get();
                ctx.flags |= t->m_flags;
            } else {
                PyErr_SetString(PyExc_TypeError,
                                ssprintf("op %s expect type Tensor as inputs, got %s actually",
                                         ctx.op->make_name().c_str(), Py_TYPE(args[i])->tp_name).c_str());
                return nullptr;
            }
        }

        auto outputs = apply(ctx);
        size_t nout = outputs.size();
        auto ret = py::tuple(nout);
        for (size_t i = 0; i < nout; ++i) {
            ret[i] = TensorWrapper::make(pytype, std::move(outputs[i]));
        }
        return ret.release().ptr();
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}
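
// py_apply is the single entry point the Python layer uses to run an OpDef:
// it is registered at the bottom of this file via MGE_PY_INTERFACE(apply, py_apply),
// so from Python it is called as apply(op, *inputs) and returns a tuple of outputs
// (Tensors, or SymbolVars when the inputs are SymbolVars).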
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
    if (kwargs && PyDict_Size(kwargs)) {
        throw py::type_error("keyword argument not allowed");
    }
    auto nargs = PyTuple_Size(args);
    auto tup = py::reinterpret_borrow<py::tuple>(args);
    if (nargs == 0) {
        throw py::type_error("too few arguments");
    }
    if (auto* t = try_cast(tup[0].ptr())) {
        if (nargs > 1) {
            throw py::type_error("expect 1 argument");
        }
        m_tensor = t->m_tensor;
    } else {
        if (nargs == 1) {
            auto arg0 = PyTuple_GetItem(args, 0);
            // for lazy_eval_tensor
            if (strstr(arg0->ob_type->tp_name, "VarNode")) {
                if (PyObject_HasAttrString(arg0, "_node")) {
                    arg0 = PyObject_GetAttrString(arg0, "_node");
                }
                m_tensor = std::make_shared<Tensor>(py::handle(arg0).cast<cg::VarNode*>());
            } else {
                // for DeviceTensorND
                if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
                    auto dv = py::handle(arg0).cast<DeviceTensorND>();
                    interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv);
                    m_tensor = std::make_shared<Tensor>(handle);
                } else {
                    throw py::type_error("single argument is not tensor, varnode or devicetensor");
                }
            }
        } else {
            py::detail::loader_life_support life_sup;  // FIXME!!! required to cast DType
            if (nargs != 5 && nargs != 6) {
                throw py::type_error("expect 5 or 6 arguments");
            }
            auto data = tup[0].cast<py::array>();
            DType dtype = tup[1].cast<DType>();
            CompNode cn = tup[2].cast<CompNode>();
            bool is_const = tup[3].cast<bool>();
            bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false;
            std::string name;
            if (tup[nargs - 1].ptr() != Py_None) name = tup[nargs - 1].cast<std::string>();

            // const op
            if (is_const && (ApplyContext::global_enable == Tensor::Flags::TRACE)) {
                auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr);
                if (!py_ret) throw py::error_already_set();
                auto py_list = py::reinterpret_steal<py::list>(py_ret);
                if (auto* t = try_cast(py_list[0].ptr())) {
                    m_tensor = t->m_tensor;
                }
                return;
            }

            interpreter::Interpreter::Handle handle;
            {
                HostTensorND ret(cn);
                handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype), no_cache);
            }

            m_tensor = std::make_shared<Tensor>(handle);
            m_tensor->user_custom_name = name;

            if (data.ndim() == 0) {
                m_tensor->m_flags |= Tensor::Flags::SCALAR;
            }
        }
    }
}
#define REGISTE_TENSORWRAPPER_FUNC(type, member)                        \
    PyObject* TensorWrapper::member() {                                 \
        return py::cast(m_tensor->m_trace_info.member).release().ptr(); \
    }                                                                   \
    void TensorWrapper::set_##member(PyObject* dest) {                  \
        auto py_dest = py::reinterpret_borrow<py::object>(dest);        \
        type real_dest = py_dest.cast<type>();                          \
        m_tensor->m_trace_info.member = real_dest;                      \
    }

REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle)
REGISTE_TENSORWRAPPER_FUNC(bool, recording)

#undef REGISTE_TENSORWRAPPER_FUNC

#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member)    \
    PyObject* TensorWrapper::member() {                \
        if (m_tensor->m_trace_info.member) {           \
            return m_tensor->m_trace_info.member;      \
        } else {                                       \
            Py_RETURN_NONE;                            \
        }                                              \
    }                                                  \
    void TensorWrapper::set_##member(PyObject* dest) { \
        if (dest == Py_None) {                         \
            Py_XDECREF(m_tensor->m_trace_info.member); \
            m_tensor->m_trace_info.member = nullptr;   \
        } else {                                       \
            Py_INCREF(dest);                           \
            m_tensor->m_trace_info.member = dest;      \
        }                                              \
    }

REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info)
REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info)

#undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC
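
// The two macro families above generate the getter/setter pairs backing the
// "_mixin_handle" / "_recording" and "_compiled_info" / "_trace_mixin_info"
// properties registered on the Tensor type in init_tensor(). The PyObject
// variant takes an owned reference on set and maps an unset member to None.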
#define SET_GET_NAME(member)                                     \
    PyObject* TensorWrapper::member() {                          \
        return py::cast(m_tensor->member).release().ptr();       \
    }                                                            \
    void TensorWrapper::set_##member(PyObject* dest) {           \
        auto py_dest = py::reinterpret_borrow<py::object>(dest); \
        m_tensor->member = py_dest.cast<std::string>();          \
    }

SET_GET_NAME(user_custom_name)
SET_GET_NAME(automatic_name)

#undef SET_GET_NAME

PyObject* TensorWrapper::handle() {
    return py::cast(m_tensor->m_handle).release().ptr();
}

void TensorWrapper::set_handle(PyObject* dest) {
    auto py_dest = py::reinterpret_borrow<py::object>(dest);
    SharedHandle real_dest = py_dest.cast<SharedHandle>();
    m_tensor->m_handle = std::move(real_dest);
}
PyObject* TensorWrapper::shape() {
    // In trace-compiled mode, read the shape from compiled_info.
    if (m_tensor->m_trace_info.compiled_info != nullptr) {
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            return PyTuple_New(0);
        }
        PyObject* shp = PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
        if (shp == Py_None) {
            throw TraceReadError("shape of this tensor is not read in trace");
        }
        return shp;
    }

    // Inside a trace, if the tensor shape is used by other operations, record
    // that fact by setting shape_read = true.
    if (m_tensor->m_trace_info.recording && !skip_tracing) {
        PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr());
    }

    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        return PyTuple_New(0);
    }

    TensorShape shape;
    if (m_tensor->m_var) {  // get shape from m_var
        auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
        auto&& type = mgr.get_infer_type(m_tensor->m_var);
        using InferType = cg::static_infer::InferType;
        if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) {
            Py_RETURN_NONE;
        }
        auto* tshp = mgr.infer_shape_fallible(m_tensor->m_var);
        if (!tshp) {
            Py_RETURN_NONE;
        }
        shape = *tshp;
    } else {
        py::gil_scoped_release _;
        shape = m_tensor->shape();
    }

    if (!shape.ndim) {
        Py_RETURN_NONE;
    }
    py::tuple ret(shape.ndim);
    for (size_t i = 0; i < shape.ndim; ++i) {
        ret[i] = shape[i];
    }
    return ret.release().ptr();
}
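
// Shape resolution order for a wrapped tensor: compiled_info (trace-compiled
// mode) -> empty tuple for scalars -> static shape inference on m_var
// (lazy/symbolic case) -> the interpreter's shape for an eager tensor; None is
// returned whenever the shape cannot be determined yet.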
PyObject* TensorWrapper::dtype() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var->dtype()).release().ptr();
    }
    return py::cast(m_tensor->dtype()).release().ptr();
}

PyObject* TensorWrapper::device() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var->comp_node()).release().ptr();
    }
    return py::cast(m_tensor->comp_node()).release().ptr();
}

PyObject* TensorWrapper::numpy() {
    if (m_tensor->m_trace_info.compiled_info != nullptr) {
        PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr);
        if (!np_val) throw py::error_already_set();
        if (np_val == Py_None) {
            throw TraceReadError("value of this tensor is not read in trace");
        }
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            PyObject* np_scalar = PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val));
            Py_DECREF(np_val);
            return np_scalar;
        }
        return np_val;
    }

    if (m_tensor->m_trace_info.recording && !skip_tracing) {
        PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "value_read", py::cast(true).release().ptr());
    }

    if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) {
        auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
        auto&& type = mgr.get_infer_type(m_tensor->m_var);
        using InferType = cg::static_infer::InferType;
        if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
            PyErr_SetString(PyExc_ValueError, "tensor invalid");
            return nullptr;
        }
        auto* val = mgr.infer_value_fallible(m_tensor->m_var);
        if (!val) {
            PyErr_SetString(PyExc_ValueError, "tensor invalid");
            return nullptr;
        }
        auto np_val = py::cast(*val).attr("numpy")();
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val.release().ptr()));
        }
        return np_val.release().ptr();
    }

    auto&& hv = [&]() {
        py::gil_scoped_release _;
        return interpreter_for_py->get_value(m_tensor->m_handle.get());
    }();
    auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
    if (!arr) {
        PyErr_SetString(PyExc_ValueError, "tensor invalid");
        return nullptr;
    }

    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        mgb_assert(PyArray_Check(arr.ptr()));
        return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
    }
    return arr.release().ptr();
}
PyObject* TensorWrapper::varnode() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var).release().ptr();
    }
    Py_RETURN_NONE;
}

void TensorWrapper::reset(PyObject* tensor) {
    TensorWrapper* t = TensorWrapper::try_cast(tensor);
    if (!t) {
        throw py::type_error("expect Tensor");
    }
    std::string user_custom_name = m_tensor->user_custom_name;
    std::string automatic_name = m_tensor->automatic_name;
    m_tensor = t->m_tensor;
    m_tensor->user_custom_name = user_custom_name;
    m_tensor->automatic_name = automatic_name;
}

void TensorWrapper::reset_varnode() {
    m_tensor->m_var = nullptr;
}

PyObject* TensorWrapper::detach() {
    PyObject* self = wrap_t::pycast(this);
    PyTypeObject* pytype = self->ob_type;

    std::shared_ptr<Tensor> new_tensor;
    if (m_tensor->m_handle.get()) {
        new_tensor = std::make_shared<Tensor>(m_tensor->m_handle);
    } else {
        new_tensor = std::make_shared<Tensor>(m_tensor->m_var);
    }
    new_tensor->m_trace_info = m_tensor->m_trace_info;
    new_tensor->m_flags = m_tensor->m_flags;
    auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
    return ret.release().ptr();
}
PyObject* TensorWrapper::_dev_tensor() {
    if (m_tensor->m_trace_info.compiled_info != nullptr) {
        auto* dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr);
        if (!dev_tensor) throw py::error_already_set();
        if (dev_tensor == Py_None) {
            throw TraceReadError("raw data of this tensor is not read in trace");
        }

        // set m_handle to make it a real tensor
        auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
        auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
        m_tensor->m_handle = std::move(SharedHandle(sh));

        // compiled info is useless after m_handle is set
        Py_DECREF(m_tensor->m_trace_info.compiled_info);
        m_tensor->m_trace_info.compiled_info = nullptr;

        return dev_tensor;
    }
    if (m_tensor->m_trace_info.recording && !skip_tracing) {
        PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr());
    }
    auto dev_tensor = [&]() {
        py::gil_scoped_release _;
        return interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
    }();
    return py::cast(dev_tensor).release().ptr();
}

void TensorWrapper::_swap_out() {
    interpreter_for_py->swap_out(m_tensor->m_handle.get());
}

void TensorWrapper::_swap_in() {
    interpreter_for_py->swap_in(m_tensor->m_handle.get());
}

void TensorWrapper::_drop() {
    interpreter_for_py->drop(m_tensor->m_handle.get());
}

PyObject* TensorWrapper::isscalar() {
    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        Py_RETURN_TRUE;
    } else {
        Py_RETURN_FALSE;
    }
}

void TensorWrapper::setscalar() {
    m_tensor->m_flags |= Tensor::Flags::SCALAR;
}

void TensorWrapper::unsetscalar() {
    m_tensor->m_flags &= ~Tensor::Flags::SCALAR;
}
struct TensorWeakRef {
    std::weak_ptr<Tensor> wptr;

    TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {}

    py::object operator()() {
        if (auto p = wptr.lock()) {
            return TensorWrapper::make(p);
        }
        return py::none();
    }
    int _use_cnt() { return wptr.use_count(); }
};
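
// TensorWeakRef is exported to Python in init_tensor() as "TensorWeakRef":
// calling the ref object returns the wrapped Tensor if it is still alive and
// None otherwise, much like weakref.ref on the Python side.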
/* ============== convert inputs ============== */

// map numpy.dtype.kind to priority
inline uint8_t category_priority(char c) {
    switch (c) {
        case 'f': return 3;  // floating-point
        case 'i': return 2;  // signed integer
        case 'u': return 2;  // unsigned integer
        case 'b': return 1;  // boolean
        default: return 0;
    }
}

// Returns the maximum value of the priority of each type in the list `types`.
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
    if (types.size() == 0) {
        return 0;
    } else {
        uint8_t max_p = 0;
        for (auto&& desc : types) {
            max_p = std::max(max_p, category_priority(desc->kind));
        }
        return max_p;
    }
}

// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> used_types;
    for (auto&& desc : types) {
        auto&& v = category_priority(desc->kind);
        if (v == cat) {
            used_types.emplace_back(desc);
        }
    }
    mgb_assert(used_types.size() > 0, "size of used_types is 0");
    PyArray_Descr* res = used_types[0];
    Py_INCREF(res);
    for (size_t i = 1; i < used_types.size(); ++i) {
        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
        Py_DECREF(res);
        res = tmp;
    }
    return res;
}
PyArray_Descr* scalar2dtype(PyObject* arg) {
    // Return value: New reference
    if (PyBool_Check(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_BOOL);
        return descr;
    }
    if (PyLong_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_INT32);
        return descr;
    }
    if (PyFloat_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
        return descr;
    }
    return nullptr;
}
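
// Python scalars are therefore mapped as: bool -> np.bool_, int -> np.int32,
// float -> np.float32; any other Python object yields nullptr and is ignored
// by the dtype promotion below.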
PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> tensors;
    SmallVector<PyArray_Descr*> scalars;

    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }

    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        if (handle == Py_None) continue;
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (tw) {
            mgb::DType type = tw->m_tensor->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
            Py_INCREF(descr.get());
            tensors.emplace_back(descr.get());
        } else {
            if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
                auto&& descr = PyArray_DescrFromObject(handle, nullptr);
                tensors.emplace_back(descr);
                continue;
            }

            if (py::isinstance<PySymbolVar>(py::handle(handle))) {
                auto var = py::handle(handle).cast<PySymbolVar*>();
                mgb::DType type = var->m_node->dtype();
                auto&& descr = npy::dtype_mgb2np_descr(type);
                Py_INCREF(descr.get());
                tensors.emplace_back(descr.get());
                continue;
            }

            PyArray_Descr* descr = scalar2dtype(handle);
            if (descr) {
                scalars.emplace_back(descr);
                continue;
            }
        }
    }

    auto max_pri_scalars = max_priority(scalars);
    auto max_pri_tensors = max_priority(tensors);

    if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
        throw py::value_error("invalid input, no dtype available");
    }

    PyArray_Descr* res;
    if (max_pri_scalars > max_pri_tensors) {
        res = promote_types(scalars, max_pri_scalars);
    } else {
        res = promote_types(tensors, max_pri_tensors);
    }
    for (auto* p : tensors) { Py_DECREF(p); }
    for (auto* p : scalars) { Py_DECREF(p); }
    Py_XDECREF(tuple);
    return res;
}
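
// Promotion rule implemented above: tensor-like inputs (Tensor, ndarray,
// SymbolVar) and Python scalars are collected separately, and the scalars only
// decide the result when their dtype category outranks every tensor category.
// For example, an int32 tensor combined with a Python float scalar promotes to
// float32, while a float32 tensor combined with a Python int scalar stays float32.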
CompNode _get_device(PyObject*const* args, size_t nargs) {
    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }
    bool valid = false;
    CompNode cn;
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        TensorWrapper* tw = TensorWrapper::try_cast(handle);

        bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
        if (tw || is_symvar) {
            if (!valid) {
                cn = tw ? tw->m_tensor->comp_node()
                        : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
                valid = true;
            } else {
                CompNode cn1 = tw ? tw->m_tensor->comp_node()
                                  : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
                if (cn1 != cn) {
                    throw py::value_error(ssprintf("ambiguous device: %s vs %s",
                                                   cn.to_string().c_str(),
                                                   cn1.to_string().c_str()));
                }
            }
        }
    }
    // release the temporary tuple before any return path so it is not leaked
    Py_XDECREF(tuple);
    if (!valid) {
        return CompNode::load(get_default_device());
    }
    return cn;
}
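
// _get_device picks the computing node shared by all Tensor / SymbolVar inputs,
// raises "ambiguous device" if two inputs disagree, and falls back to
// get_default_device() when no device-carrying input is present.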
// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject*const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        PyArray_Descr* res = _dtype_promotion(args, nargs);
        return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}

PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        CompNode cn = _get_device(args, nargs);
        return py::cast(cn).release().ptr();
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}
#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
#else
#define WRAP_FUNC_PY35(FUNC)                                \
    PyObject* py35_##FUNC(PyObject* self, PyObject* args) { \
        auto* arr = &PyTuple_GET_ITEM(args, 0);             \
        auto size = PyTuple_GET_SIZE(args);                 \
        return FUNC(self, arr, size);                       \
    }
WRAP_FUNC_PY35(py_apply);
WRAP_FUNC_PY35(dtype_promotion);
WRAP_FUNC_PY35(get_device);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif
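
// MGE_PY_INTERFACE builds a PyMethodDef entry. When the interpreter supports
// METH_FASTCALL, the functions above are registered directly with the fastcall
// convention (PyObject* const* args, size_t nargs); otherwise a py35_* shim is
// generated that unpacks a METH_VARARGS tuple into the same argument layout.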
void init_tensor(py::module m) {
    imperative::Tensor::static_initialize();
    static auto sl_interpreter_for_py = interpreter::Interpreter::inst().create_channel();
    interpreter_for_py = sl_interpreter_for_py.get();

    auto* tensor_type = TensorWrapper::wrap_t::type()
        .def<&TensorWrapper::numpy>("numpy")
        .def_getset<&TensorWrapper::shape>("shape")
        .def_getset<&TensorWrapper::dtype>("dtype")
        .def_getset<&TensorWrapper::device>("device")
        .def<&TensorWrapper::reset>("_reset")
        .def<&TensorWrapper::isscalar>("_isscalar")
        .def<&TensorWrapper::setscalar>("_setscalar")
        .def<&TensorWrapper::unsetscalar>("_unsetscalar")
        .def<&TensorWrapper::detach>("detach")
        .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
        .def<&TensorWrapper::_swap_out>("_swap_out")
        .def<&TensorWrapper::_swap_in>("_swap_in")
        .def<&TensorWrapper::_drop>("_drop")
        .def<&TensorWrapper::reset_varnode>("_reset_varnode")
        .def<&TensorWrapper::_use_cnt>("_use_cnt")
        .def_getset<&TensorWrapper::varnode>("_varnode")
        .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle")
        .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording")
        .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
        .def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info")
        .def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
        .def_getset<&TensorWrapper::user_custom_name, &TensorWrapper::set_user_custom_name>("c_name")
        .def_getset<&TensorWrapper::automatic_name, &TensorWrapper::set_automatic_name>("_name")
        .finalize();
    if (!tensor_type) throw py::error_already_set();
    py::setattr(m, "Tensor", tensor_type);

    py::class_<TensorWeakRef>(m, "TensorWeakRef")
        .def(py::init<const TensorWrapper&>())
        .def("__call__", &TensorWeakRef::operator())
        .def("_use_cnt", &TensorWeakRef::_use_cnt);

    py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
        .def_property_readonly(
                "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
        .def_property("var", [](PySymbolVar* v) { return v->m_node; },
                      [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
        .def_property_readonly(
                "device",
                [](PySymbolVar* v) { return v->m_node->comp_node(); })
        .def_property_readonly(
                "graph",
                [](PySymbolVar* v) { return v->m_node->owner_graph(); })
        .def_property_readonly(
                "shape",
                [](PySymbolVar* v) -> const TensorShape* {
                    auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
                    return mgr.infer_shape_fallible(v->m_node);
                })
        .def("numpy",
             [](PySymbolVar* v) {
                 auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
                 auto&& type = mgr.get_infer_type(v->m_node);
                 using InferType = cg::static_infer::InferType;
                 if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
                     throw py::value_error("value invalid!");
                 }
                 auto* val = mgr.infer_value_fallible(v->m_node);
                 if (!val) {
                     throw py::value_error("value invalid!");
                 }
                 auto np_val = py::cast(*val).attr("numpy")();
                 if (v->is_scalar) {
                     return py::object(py::array(np_val).squeeze());
                 }
                 return np_val;
             })
        .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
        .def("_setscalar",
             [](PySymbolVar* v) { return v->is_scalar = true; })
        .def(py::init([](cg::VarNode* node) {
                 return std::make_shared<PySymbolVar>(node);
             }),
             py::arg() = nullptr);
    static PyMethodDef method_defs[] = {
            MGE_PY_INTERFACE(apply, py_apply),
            MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
            MGE_PY_INTERFACE(get_device, get_device),
            {nullptr, nullptr, 0, nullptr}};
    for (auto&& def : method_defs) {
        if (def.ml_meth != nullptr) {
            auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
            if (!func) throw py::error_already_set();
            py::setattr(m, def.ml_name, func);
        }
    }

    static constexpr auto sync_py_task_q = [] {
        py_task_q.wait_all_task_finish();
    };

    m.def("set_option",
          [](std::string name, size_t value) { interpreter_for_py->set_option(name, value); });
    m.def("get_option",
          [](std::string name) { return interpreter_for_py->get_option(name); });
    m.def("_set_swap_flag",
          [](bool flag) { interpreter_for_py->set_option("enable_swap", flag); });
    m.def("_set_drop_flag",
          [](bool flag) { interpreter_for_py->set_option("enable_drop", flag); });
    m.def("config_async_level",
          [](int level) {
              mgb_assert(level >= 0 and level <= 2, "async_level should be 0, 1 or 2");
              interpreter_for_py->set_option("async_level", level);
          });
    m.def("get_async_level",
          []() { return interpreter_for_py->get_option("async_level"); });
    m.def("set_buffer_length",
          [](int length) {
              mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)");
              interpreter_for_py->set_option("buffer_length", length);
          });
    m.def("push_scope",
          [](std::string name) { interpreter_for_py->push_scope(name); });
    m.def("pop_scope",
          [](std::string name) { interpreter_for_py->pop_scope(name); });
    m.def("start_profile",
          [](imperative::Profiler::options_t options) {
              interpreter_for_py->sync();
              imperative::Profiler::load_options(std::move(options));
              imperative::Profiler::start_profile();
              interpreter_for_py->start_profile();
          },
          py::call_guard<py::gil_scoped_release>());
    m.def("stop_profile",
          []() -> std::function<void(std::string, std::string)> {
              interpreter_for_py->stop_profile();
              interpreter_for_py->sync();
              imperative::Profiler::stop_profile();
              auto results = imperative::Profiler::collect();
              auto options = imperative::Profiler::get_options();
              return [results = std::move(results), options = std::move(options)](
                             std::string basename, std::string format) {
                  imperative::Profiler::dump_profile(basename, format, results, options);
              };
          },
          py::call_guard<py::gil_scoped_release>());
    m.def("sync",
          []() {
              interpreter_for_py->sync();
              sync_py_task_q();
          },
          py::call_guard<py::gil_scoped_release>());
    m.def("full_sync",
          []() {
              interpreter_for_py->sync();
              CompNode::sync_all();
              sync_py_task_q();
          },
          py::call_guard<py::gil_scoped_release>());
    m.def("close",
          []() {
              interpreter_for_py->close();
              sync_py_task_q();
          },
          py::call_guard<py::gil_scoped_release>());

    py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
        .def<&GradKeyWrapper::attach>("attach")
        .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
        .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name")
        .def_getset<&GradKeyWrapper::get_priority, &GradKeyWrapper::set_priority>("priority")
        .finalize();
    if (!grad_key_type) throw py::error_already_set();
    py::setattr(m, "GradKey", grad_key_type);
    m.def("backward", &GradKeyWrapper::backward);

    m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
    m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
    m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode);
    m.attr("skip_tracing") = &skip_tracing;

    py::class_<SharedHandle>(m, "SharedHandle")
        .def(py::init<const SharedHandle&>())
        .def("__eq__", [](SharedHandle& thish, SharedHandle& thath) {
            return (thish.get() == thath.get());
        })
        .def("__hash__", [](SharedHandle& sh) {
            return reinterpret_cast<int64_t>(sh.get());
        });

    m.def("set_tracing", &set_tracing);
    m.def("unset_tracing", &unset_tracing);

    m.def("set_allow_higher_order_directive",
          [](bool value) { GradKey::allow_higher_order_directive = value; });
}

#undef MGE_PY_INTERFACE

}  // namespace mgb::imperative::python

The MegEngine installation package bundles the CUDA environment required to run code on GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU and the corresponding driver installed. If you would like to try deep-learning development on cloud GPU compute, you are welcome to visit the MegStudio platform.