
tensor.cpp (36 kB)

/**
 * \file imperative/python/src/tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
  11. #include "megbrain/dtype.h"
  12. #include "megbrain/common.h"
  13. #include "megbrain/imperative/ops/utility.h"
  14. #include "megbrain/imperative/ops/backward_graph.h"
  15. #include "./tensor.h"
  16. #include "./grad.h"
  17. #include "./trace.h"
  18. #include "./common.h"
  19. #include "./numpy_dtypes.h"
  20. #include "./graph_rt.h"
  21. #include "./helper.h"
  22. #include <pybind11/numpy.h>
  23. #include <pybind11/operators.h>
  24. #include <range/v3/all.hpp>
  25. #include <string>
  26. #include <unordered_map>
  27. namespace py = pybind11;
  28. namespace views = ranges::views;
  29. namespace mgb::imperative::python {
  30. interpreter::Interpreter::Channel* interpreter_for_py;
  31. PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing,
  32. *cpp_apply_compiled_mode, *cpp_apply_const_compiled_mode;
  33. PyObject *cpp_apply_backward_varnode;
  34. #define REGISTE_APPLY_FUNC(mode) \
  35. void set_##mode(py::object pyf) { \
  36. mode = pyf.ptr(); \
  37. }
  38. REGISTE_APPLY_FUNC(cpp_apply_with_tracing)
  39. REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing)
  40. REGISTE_APPLY_FUNC(cpp_apply_compiled_mode)
  41. REGISTE_APPLY_FUNC(cpp_apply_const_compiled_mode)
  42. REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)
  43. #undef REGISTE_APPLY_FUNC
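
// For reference, each REGISTE_APPLY_FUNC(mode) invocation above expands to a
// tiny setter that stashes the Python callable as a raw PyObject*, e.g.:
//
//     void set_cpp_apply_with_tracing(py::object pyf) {
//         cpp_apply_with_tracing = pyf.ptr();
//     }
//
// These setters are exported to Python near the end of init_tensor() via
// m.def("set_cpp_apply_with_tracing", ...) and the other set_cpp_apply_* defs.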
bool is_tracing = false;
bool is_compiled = false;

#define SET_UNSET_PROP(mode)  \
    void set_##mode() {       \
        is_##mode = true;     \
    }                         \
    void unset_##mode() {     \
        is_##mode = false;    \
    }                         \

SET_UNSET_PROP(tracing)
SET_UNSET_PROP(compiled)

#undef SET_UNSET_PROP

bool skip_tracing = false;

Tensor::flags_t ApplyContext::global_disable = 0;

apply_result_t apply(ApplyContext& ctx) {
    // emulating scalar should be put to specific op's apply, e.g.,
    // elementwise, reduce, typecvt. Currently it's still handled at the python
    // side. It could be moved to the C++ side if it has an impact on performance
    auto flags = ctx.flags & ~ApplyContext::global_disable;

    if (flags & Tensor::Flags::SCALAR) {
        // TODO: emulate scalar
    }

    if (flags & Tensor::Flags::GRAD) {
        return apply_grad(ctx);
    }

    if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) {
        py::tuple pyin(ctx.nargs);
        for (size_t i = 0; i < ctx.nargs; ++i) {
            pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
        }
        auto f = py::getattr(op->obj, "_default_rule");
        auto pyout = py::reinterpret_steal<py::object>(PyObject_Call(f.ptr(), pyin.ptr(), nullptr));
        if (!pyout) throw py::error_already_set();
        if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) {
            return {tw->m_tensor};
        }
        apply_result_t ret;
        ret.reserve(py::len(pyout));
        for (auto&& i : pyout) {
            auto* tw = TensorWrapper::try_cast(i.ptr());
            mgb_assert(tw);
            ret.push_back(tw->m_tensor);
        }
        return ret;
    }

    if (flags & Tensor::Flags::TRACE) {
        return apply_trace(ctx);
    } else {
        SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs);
        for (size_t i = 0; i < ctx.nargs; ++i) {
            handles[i] = ctx.args[i]->m_handle.get();
        }

        auto output_handles = interpreter_for_py->apply_op(ctx.op, handles);

        apply_result_t outputs;
        outputs.reserve(output_handles.size());
        for (auto h : output_handles) {
            outputs.emplace_back(std::make_shared<Tensor>(h));
        }
        return outputs;
    }

    mgb_assert(0);
}
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) {
    try {
        // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
        //     PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
        //     return nullptr;
        // }
        if (nargs < 2) {
            PyErr_SetString(PyExc_TypeError,
                            "py_apply expects one Op and at least one tensor "
                            "as argument");
            return nullptr;
        }

        auto* op = args[0];

        PyTypeObject* pytype = args[1]->ob_type;
        ++args;
        --nargs;

        ApplyContext ctx;
        ctx.flags = 0;
        ctx.op = py::handle(op).cast<std::shared_ptr<OpDef>>();
        SmallVector<Tensor*, 64> tensors(nargs);
        ctx.args = &tensors[0];
        ctx.nargs = nargs;
        ctx.pytype = pytype;
        if (ctx.op->same_type<BackwardGraph>()) {
            ctx.backward = true;
        }

        if (py::isinstance<PySymbolVar>(py::handle(args[0]))) {
            SmallVector<cg::VarNode*> vinputs(nargs);
            for (size_t i = 0; i < nargs; ++i) {
                vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
            }
            auto op = ctx.op.get();
            auto rst = OpDef::apply_on_var_node(*op, vinputs);
            auto ret = pybind11::tuple(rst.size());
            auto typeobj = py::handle(args[0]).get_type();
            for (size_t i = 0; i < rst.size(); ++i) {
                ret[i] = typeobj(pybind11::cast(rst[i], pybind11::return_value_policy::automatic));
            }
            return ret.release().ptr();
        }

        for (size_t i = 0; i < nargs; ++i) {
            if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                auto* t = tensors[i] = tw->m_tensor.get();
                ctx.flags |= t->m_flags;
            } else {
                PyErr_SetString(PyExc_TypeError,
                                ssprintf("op %s expect type Tensor as inputs, got %s actually",
                                         ctx.op->make_name().c_str(), Py_TYPE(args[i])->tp_name).c_str());
                return nullptr;
            }
        }

        if (is_tracing) {
            ctx.flags |= Tensor::Flags::TRACE;
        }

        auto outputs = apply(ctx);
        size_t nout = outputs.size();
        auto ret = py::tuple(nout);
        for (size_t i = 0; i < nout; ++i) {
            ret[i] = TensorWrapper::make(pytype, std::move(outputs[i]));
        }
        return ret.release().ptr();
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
    if (kwargs && PyDict_Size(kwargs)) {
        throw py::type_error("keyword argument not allowed");
    }
    auto nargs = PyTuple_Size(args);
    auto tup = py::reinterpret_borrow<py::tuple>(args);
    if (nargs == 0) {
        throw py::type_error("too few arguments");
    }
    if (auto* t = try_cast(tup[0].ptr())) {
        if (nargs > 1) {
            throw py::type_error("expect 1 argument");
        }
        m_tensor = t->m_tensor;
    } else {
        if (nargs == 1) {
            auto arg0 = PyTuple_GetItem(args, 0);
            // for lazy_eval_tensor
            if (strstr(arg0->ob_type->tp_name, "VarNode")) {
                if (PyObject_HasAttrString(arg0, "_node")) {
                    arg0 = PyObject_GetAttrString(arg0, "_node");
                }
                m_tensor = std::make_shared<Tensor>(py::handle(arg0).cast<cg::VarNode*>());
            } else {
                // for DeviceTensorND
                if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
                    auto dv = py::handle(arg0).cast<DeviceTensorND>();
                    interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv);
                    m_tensor = std::make_shared<Tensor>(handle);
                } else {
                    throw py::type_error("single argument is not tensor, varnode or devicetensor");
                }
            }
        } else {
            py::detail::loader_life_support life_sup; // FIXME!!! required to cast DType
            if (nargs != 5 && nargs != 6) {
                throw py::type_error("expect 5 or 6 arguments");
            }
            auto data = tup[0].cast<py::array>();
            DType dtype = tup[1].cast<DType>();
            CompNode cn = tup[2].cast<CompNode>();
            bool is_const = tup[3].cast<bool>();
            bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false;
            std::string name;
            if (tup[nargs - 1].ptr() != Py_None) name = tup[nargs - 1].cast<std::string>();

            // const op
            if (is_const && is_tracing) {
                PyObject* pyf;
                if (is_compiled) {
                    pyf = cpp_apply_const_compiled_mode;
                } else {
                    pyf = cpp_apply_const_with_tracing;
                }

                auto py_ret = PyObject_Call(pyf, tup.ptr(), nullptr);
                if (!py_ret) throw py::error_already_set();
                auto py_list = py::reinterpret_steal<py::list>(py_ret);
                if (auto* t = try_cast(py_list[0].ptr())) {
                    m_tensor = t->m_tensor;
                }
                return;
            }

            interpreter::Interpreter::Handle handle;
            constexpr auto size_threshhold = TensorShape::MAX_NDIM;
            if (data.size() > size_threshhold) {
                handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype), no_cache);
            } else {
                HostTensorND ret(cn);
                handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype), no_cache);
            }

            m_tensor = std::make_shared<Tensor>(handle);
            m_tensor->user_custom_name = name;

            if (data.ndim() == 0) {
                m_tensor->m_flags |= Tensor::Flags::SCALAR;
            }
        }
    }
}
#define REGISTE_TENSORWRAPPER_FUNC(type, member)                        \
    PyObject* TensorWrapper::member() {                                 \
        return py::cast(m_tensor->m_trace_info.member).release().ptr(); \
    }                                                                   \
    void TensorWrapper::set_##member(PyObject* dest) {                  \
        auto py_dest = py::reinterpret_borrow<py::object>(dest);        \
        type real_dest = py_dest.cast<type>();                          \
        m_tensor->m_trace_info.member = real_dest;                      \
    }

REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle)
REGISTE_TENSORWRAPPER_FUNC(bool, recording)

#undef REGISTE_TENSORWRAPPER_FUNC

PyObject* TensorWrapper::copied() {
    return py::cast(m_tensor->m_trace_info.copied).release().ptr();
}

#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member)    \
    PyObject* TensorWrapper::member() {                \
        if (m_tensor->m_trace_info.member) {           \
            return m_tensor->m_trace_info.member;      \
        } else {                                       \
            Py_RETURN_NONE;                            \
        }                                              \
    }                                                  \
    void TensorWrapper::set_##member(PyObject* dest) { \
        if (dest == Py_None) {                         \
            Py_XDECREF(m_tensor->m_trace_info.member); \
            m_tensor->m_trace_info.member = nullptr;   \
        } else {                                       \
            Py_INCREF(dest);                           \
            m_tensor->m_trace_info.member = dest;      \
        }                                              \
    }

REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info)
REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info)

#undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC

#define SET_GET_NAME(member)                                     \
    PyObject* TensorWrapper::member() {                          \
        return py::cast(m_tensor->member).release().ptr();       \
    }                                                            \
    void TensorWrapper::set_##member(PyObject* dest) {           \
        auto py_dest = py::reinterpret_borrow<py::object>(dest); \
        m_tensor->member = py_dest.cast<std::string>();          \
    }

SET_GET_NAME(user_custom_name)
SET_GET_NAME(automatic_name)

#undef SET_GET_NAME

PyObject* TensorWrapper::handle() {
    return py::cast(m_tensor->m_handle).release().ptr();
}

void TensorWrapper::set_handle(PyObject* dest) {
    auto py_dest = py::reinterpret_borrow<py::object>(dest);
    SharedHandle real_dest = py_dest.cast<SharedHandle>();
    m_tensor->m_handle = std::move(real_dest);
}
PyObject* TensorWrapper::shape() {
    // if it's tracing compiled mode, get value from compiled_info
    if (m_tensor->m_trace_info.compiled_info != nullptr) {
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            return PyTuple_New(0);
        }
        PyObject* shp = PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
        if (shp == Py_None) {
            throw TraceReadError("shape of this tensor is not read in trace");
        }
        return shp;
    }

    // inside trace, if tensor shape is useful for other operations, set shape_read = true
    if (m_tensor->m_trace_info.recording && !skip_tracing) {
        PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr());
    }

    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        return PyTuple_New(0);
    }

    TensorShape shape;
    if (m_tensor->m_var) { // get shape from m_var
        auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
        auto* tshp = mgr.infer_shape_fallible(m_tensor->m_var);
        if (!tshp) {
            Py_RETURN_NONE;
        }
        shape = *tshp;
    } else {
        shape = m_tensor->shape();
    }

    if (!shape.ndim) {
        Py_RETURN_NONE;
    }

    py::tuple ret(shape.ndim);
    for (size_t i = 0; i < shape.ndim; ++i) {
        ret[i] = shape[i];
    }
    return ret.release().ptr();
}
PyObject* TensorWrapper::dtype() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var->dtype()).release().ptr();
    }
    return py::cast(m_tensor->dtype()).release().ptr();
}

PyObject* TensorWrapper::device() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var->comp_node()).release().ptr();
    }
    return py::cast(m_tensor->comp_node()).release().ptr();
}

PyObject* TensorWrapper::numpy() {
    if (m_tensor->m_trace_info.compiled_info != nullptr) {
        PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr);
        if (!np_val) throw py::error_already_set();
        if (np_val == Py_None) {
            throw TraceReadError("value of this tensor is not read in trace");
        }
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            PyObject* np_scalar = PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val));
            Py_DECREF(np_val);
            return np_scalar;
        }
        return np_val;
    }

    if (m_tensor->m_trace_info.recording && !skip_tracing) {
        PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "value_read", py::cast(true).release().ptr());
    }

    if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) {
        auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
        auto&& type = mgr.get_infer_type(m_tensor->m_var);
        using InferType = cg::static_infer::InferType;
        if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
            PyErr_SetString(PyExc_ValueError, "tensor invalid");
            return nullptr;
        }
        auto* val = mgr.infer_value_fallible(m_tensor->m_var);
        if (!val) {
            PyErr_SetString(PyExc_ValueError, "tensor invalid");
            return nullptr;
        }
        auto np_val = py::cast(*val).attr("numpy")();
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val.release().ptr()));
        }
        return np_val.release().ptr();
    }

    auto&& hv = [&]() {
        py::gil_scoped_release _;
        return interpreter_for_py->get_value(m_tensor->m_handle.get());
    }();
    auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
    if (!arr) {
        PyErr_SetString(PyExc_ValueError, "tensor invalid");
        return nullptr;
    }

    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        mgb_assert(PyArray_Check(arr.ptr()));
        return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
    }
    return arr.release().ptr();
}

PyObject* TensorWrapper::varnode() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var).release().ptr();
    }
    Py_RETURN_NONE;
}
void TensorWrapper::reset(PyObject* tensor) {
    TensorWrapper* t = TensorWrapper::try_cast(tensor);
    if (!t) {
        throw py::type_error("expect Tensor");
    }
    std::string user_custom_name = m_tensor->user_custom_name;
    std::string automatic_name = m_tensor->automatic_name;
    m_tensor = t->m_tensor;
    m_tensor->user_custom_name = user_custom_name;
    m_tensor->automatic_name = automatic_name;
}

void TensorWrapper::reset_varnode() {
    m_tensor->m_var = nullptr;
}

PyObject* TensorWrapper::detach() {
    PyObject* self = wrap_t::pycast(this);
    PyTypeObject* pytype = self->ob_type;

    std::shared_ptr<Tensor> new_tensor;
    if (m_tensor->m_handle.get()) {
        new_tensor = std::make_shared<Tensor>(m_tensor->m_handle);
    } else {
        new_tensor = std::make_shared<Tensor>(m_tensor->m_var);
    }
    new_tensor->m_trace_info = m_tensor->m_trace_info;
    new_tensor->m_flags = m_tensor->m_flags;
    auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
    return ret.release().ptr();
}

PyObject* TensorWrapper::_dev_tensor() {
    if (m_tensor->m_trace_info.compiled_info != nullptr) {
        auto* dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr);
        if (!dev_tensor) throw py::error_already_set();
        if (dev_tensor == Py_None) {
            throw TraceReadError("raw data of this tensor is not read in trace");
        }

        // set m_handle to make it a real tensor
        auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
        auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
        m_tensor->m_handle = std::move(SharedHandle(sh));

        // compiled info is useless after m_handle is set
        Py_DECREF(m_tensor->m_trace_info.compiled_info);
        m_tensor->m_trace_info.compiled_info = nullptr;

        return dev_tensor;
    }
    if (m_tensor->m_trace_info.recording && !skip_tracing) {
        PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr());
    }
    auto dev_tensor = [&]() {
        py::gil_scoped_release _;
        return interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
    }();
    return py::cast(dev_tensor).release().ptr();
}

void TensorWrapper::_swap_out() {
    interpreter_for_py->swap_out(m_tensor->m_handle.get());
}

void TensorWrapper::_swap_in() {
    interpreter_for_py->swap_in(m_tensor->m_handle.get());
}

void TensorWrapper::_drop() {
    interpreter_for_py->drop(m_tensor->m_handle.get());
}

PyObject* TensorWrapper::isscalar() {
    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        Py_RETURN_TRUE;
    } else {
        Py_RETURN_FALSE;
    }
}

void TensorWrapper::setscalar() {
    m_tensor->m_flags |= Tensor::Flags::SCALAR;
}

void TensorWrapper::unsetscalar() {
    m_tensor->m_flags &= ~Tensor::Flags::SCALAR;
}

struct TensorWeakRef {
    std::weak_ptr<Tensor> wptr;

    TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {}

    py::object operator()() {
        if (auto p = wptr.lock()) {
            return TensorWrapper::make(p);
        }
        return py::none();
    }
    int _use_cnt() { return wptr.use_count(); }
};
/* ============== convert inputs ============== */

// map numpy.dtype.kind to priority
inline uint8_t category_priority(char c) {
    switch (c) {
        case 'f': return 3; // floating-point
        case 'i': return 2; // signed integer
        case 'u': return 2; // unsigned integer
        case 'b': return 1; // boolean
        default: return 0;
    }
}

// Returns the maximum value of the priority of each type in the list `types`.
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
    if (types.size() == 0) {
        return 0;
    } else {
        uint8_t max_p = 0;
        for (auto&& desc : types) {
            max_p = std::max(max_p, category_priority(desc->kind));
        }
        return max_p;
    }
}

// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> used_types;
    for (auto&& desc : types) {
        auto&& v = category_priority(desc->kind);
        if (v == cat) {
            used_types.emplace_back(desc);
        }
    }
    mgb_assert(used_types.size() > 0, "size of used_types is 0");
    PyArray_Descr* res = used_types[0];
    Py_INCREF(res);

    for (size_t i = 1; i < used_types.size(); ++i) {
        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
        Py_DECREF(res);
        res = tmp;
    }
    return res;
}

PyArray_Descr* scalar2dtype(PyObject* arg) {
    // Return value: New reference
    if (PyBool_Check(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_BOOL);
        return descr;
    }
    if (PyLong_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_INT32);
        return descr;
    }
    if (PyFloat_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
        return descr;
    }
    return nullptr;
}
PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> tensors;
    SmallVector<PyArray_Descr*> scalars;

    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }

    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        if (handle == Py_None) continue;
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (tw) {
            mgb::DType type = tw->m_tensor->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
            Py_INCREF(descr.get());
            tensors.emplace_back(descr.get());
        } else {
            if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
                auto&& descr = PyArray_DescrFromObject(handle, nullptr);
                tensors.emplace_back(descr);
                continue;
            }

            if (py::isinstance<PySymbolVar>(py::handle(handle))) {
                auto var = py::handle(handle).cast<PySymbolVar*>();
                mgb::DType type = var->m_node->dtype();
                auto&& descr = npy::dtype_mgb2np_descr(type);
                Py_INCREF(descr.get());
                tensors.emplace_back(descr.get());
                continue;
            }

            PyArray_Descr* descr = scalar2dtype(handle);
            if (descr) {
                scalars.emplace_back(descr);
                continue;
            }
        }
    }

    auto max_pri_scalars = max_priority(scalars);
    auto max_pri_tensors = max_priority(tensors);

    if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
        throw py::value_error("invalid input, no dtype available");
    }

    PyArray_Descr* res;
    if (max_pri_scalars > max_pri_tensors) {
        res = promote_types(scalars, max_pri_scalars);
    } else {
        res = promote_types(tensors, max_pri_tensors);
    }

    for (auto* p : tensors) { Py_DECREF(p); }
    for (auto* p : scalars) { Py_DECREF(p); }
    Py_XDECREF(tuple);
    return res;
}
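
// Worked example of the promotion rule implemented above: for the inputs
// (int32 Tensor, Python float), the tensor side contributes numpy kind 'i'
// (priority 2) while the float scalar maps to NPY_FLOAT32, kind 'f'
// (priority 3). Since the scalar category outranks the tensor category,
// promote_types() runs over the scalar group only and the result is float32.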
CompNode _get_device(PyObject*const* args, size_t nargs) {
    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }
    bool valid = false;
    CompNode cn;
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        TensorWrapper* tw = TensorWrapper::try_cast(handle);

        bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
        if (tw || is_symvar) {
            if (!valid) {
                cn = tw ? tw->m_tensor->comp_node()
                        : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
                valid = true;
            } else {
                CompNode cn1 = tw ? tw->m_tensor->comp_node()
                                  : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
                if (cn1 != cn) {
                    throw py::value_error(ssprintf("ambiguous device: %s vs %s",
                                                   cn.to_string().c_str(),
                                                   cn1.to_string().c_str()));
                }
            }
        }
    }
    if (!valid) {
        mgb_assert(0, "expect at least 1 device");
    }
    Py_XDECREF(tuple);
    return cn;
}
// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject*const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        PyArray_Descr* res = _dtype_promotion(args, nargs);
        return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}

PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        CompNode cn = _get_device(args, nargs);
        return py::cast(cn).release().ptr();
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}

#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
#else
#define WRAP_FUNC_PY35(FUNC)                                \
    PyObject* py35_##FUNC(PyObject* self, PyObject* args) { \
        auto* arr = &PyTuple_GET_ITEM(args, 0);             \
        auto size = PyTuple_GET_SIZE(args);                 \
        return FUNC(self, arr, size);                       \
    }
WRAP_FUNC_PY35(py_apply);
WRAP_FUNC_PY35(dtype_promotion);
WRAP_FUNC_PY35(get_device);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif
void init_tensor(py::module m) {
    imperative::Tensor::static_initialize();
    static auto sl_interpreter_for_py = interpreter::Interpreter::inst().create_channel();
    interpreter_for_py = sl_interpreter_for_py.get();

    auto* tensor_type = TensorWrapper::wrap_t::type()
            .def<&TensorWrapper::numpy>("numpy")
            .def_getset<&TensorWrapper::shape>("shape")
            .def_getset<&TensorWrapper::dtype>("dtype")
            .def_getset<&TensorWrapper::device>("device")
            .def<&TensorWrapper::reset>("_reset")
            .def<&TensorWrapper::isscalar>("_isscalar")
            .def<&TensorWrapper::setscalar>("_setscalar")
            .def<&TensorWrapper::unsetscalar>("_unsetscalar")
            .def<&TensorWrapper::detach>("detach")
            .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
            .def<&TensorWrapper::_swap_out>("_swap_out")
            .def<&TensorWrapper::_swap_in>("_swap_in")
            .def<&TensorWrapper::_drop>("_drop")
            .def<&TensorWrapper::reset_varnode>("_reset_varnode")
            .def<&TensorWrapper::_use_cnt>("_use_cnt")
            .def_getset<&TensorWrapper::varnode>("_varnode")
            .def_getset<&TensorWrapper::copied>("_copied")
            .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle")
            .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording")
            .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
            .def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info")
            .def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
            .def_getset<&TensorWrapper::user_custom_name, &TensorWrapper::set_user_custom_name>("c_name")
            .def_getset<&TensorWrapper::automatic_name, &TensorWrapper::set_automatic_name>("_name")
            .finalize();
    if (!tensor_type) throw py::error_already_set();
    py::setattr(m, "Tensor", tensor_type);

    py::class_<TensorWeakRef>(m, "TensorWeakRef")
            .def(py::init<const TensorWrapper&>())
            .def("__call__", &TensorWeakRef::operator())
            .def("_use_cnt", &TensorWeakRef::_use_cnt);

    py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
            .def_property_readonly(
                    "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
            .def_property("var", [](PySymbolVar* v) { return v->m_node; },
                          [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
            .def_property_readonly(
                    "device",
                    [](PySymbolVar* v) { return v->m_node->comp_node(); })
            .def_property_readonly(
                    "graph",
                    [](PySymbolVar* v) { return v->m_node->owner_graph(); })
            .def_property_readonly(
                    "shape",
                    [](PySymbolVar* v) -> const TensorShape* {
                        auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
                        return mgr.infer_shape_fallible(v->m_node);
                    })
            .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
            .def("_setscalar",
                 [](PySymbolVar* v) { return v->is_scalar = true; })
            .def(py::init([](cg::VarNode* node) {
                     return std::make_shared<PySymbolVar>(node);
                 }),
                 py::arg() = nullptr);

    static PyMethodDef method_defs[] = {
            MGE_PY_INTERFACE(apply, py_apply),
            MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
            MGE_PY_INTERFACE(get_device, get_device),
            {nullptr, nullptr, 0, nullptr}};
    for (auto&& def : method_defs) {
        if (def.ml_meth != nullptr) {
            auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
            if (!func) throw py::error_already_set();
            py::setattr(m, def.ml_name, func);
        }
    }

    m.def("set_option",
          [](std::string name, int value) { interpreter_for_py->set_option(name, value); });
    m.def("get_option",
          [](std::string name) { return interpreter_for_py->get_option(name); });
    m.def("_set_swap_flag",
          [](bool flag) { interpreter_for_py->set_option("enable_swap", flag); });
    m.def("_set_drop_flag",
          [](bool flag) { interpreter_for_py->set_option("enable_drop", flag); });
    m.def("config_async_level",
          [](int level) {
              mgb_assert(level >= 0 and level <= 2, "async_level should be 0, 1 or 2");
              interpreter_for_py->set_option("async_level", level);
          });
    m.def("get_async_level",
          []() { return interpreter_for_py->get_option("async_level"); });
    m.def("set_buffer_length",
          [](int length) {
              mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)");
              interpreter_for_py->set_option("buffer_length", length);
          });
    m.def("push_scope",
          [](std::string name) { interpreter_for_py->push_scope(name); });
    m.def("pop_scope",
          [](std::string name) { interpreter_for_py->pop_scope(name); });
    m.def("start_profile",
          [](std::unordered_map<std::string, int> option) { return interpreter_for_py->start_profile(option); });
    m.def("stop_profile",
          [](std::string basename, std::string format) { interpreter_for_py->stop_profile(basename, format); });
    m.def("sync",
          []() {
              interpreter_for_py->sync();
              py_task_q.wait_all_task_finish();
          },
          py::call_guard<py::gil_scoped_release>());
    m.def("full_sync",
          []() {
              interpreter_for_py->sync();
              CompNode::sync_all();
              py_task_q.wait_all_task_finish();
          },
          py::call_guard<py::gil_scoped_release>());

    py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
            .def<&GradKeyWrapper::attach>("attach")
            .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
            .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name")
            .finalize();
    if (!grad_key_type) throw py::error_already_set();
    py::setattr(m, "GradKey", grad_key_type);
    m.def("backward", &GradKeyWrapper::backward);

    m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
    m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
    m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode);
    m.def("set_cpp_apply_const_compiled_mode", &set_cpp_apply_const_compiled_mode);
    m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode);

    m.attr("skip_tracing") = &skip_tracing;

    py::class_<SharedHandle>(m, "SharedHandle")
            .def(py::init<const SharedHandle&>())
            .def("__eq__", [](SharedHandle& thish, SharedHandle& thath) {
                return (thish.get() == thath.get());
            })
            .def("__hash__", [](SharedHandle& sh) {
                return reinterpret_cast<int64_t>(sh.get());
            });

    m.def("set_tracing", &set_tracing);
    m.def("unset_tracing", &unset_tracing);
    m.def("set_compiled", &set_compiled);
    m.def("unset_compiled", &unset_compiled);
}

#undef MGE_PY_INTERFACE

} // namespace mgb::imperative::python
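
For context, init_tensor(py::module) above is the entry point that a pybind11 extension module calls to register Tensor, TensorWeakRef, SymbolVar, GradKey and the module-level functions (apply, dtype_promotion, get_device, and the set_*/option helpers). A minimal sketch of such a host module is shown below; the module name _example_core is a made-up placeholder for illustration, as the real module that invokes init_tensor() is defined elsewhere in the MegEngine build.

#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace mgb::imperative::python {
// Implemented in tensor.cpp above; declared in ./tensor.h in the real source tree.
void init_tensor(py::module);
}

// Hypothetical module name for illustration only.
PYBIND11_MODULE(_example_core, m) {
    // Registers the Tensor type and the module-level helpers on `m`.
    mgb::imperative::python::init_tensor(m);
}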

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and its driver installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.