You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tensor.cpp 42 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129
  1. /**
  2. * \file imperative/python/src/tensor.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/common.h"
  12. #include "megbrain/dtype.h"
  13. #include "megbrain/imperative/ops/autogen.h"
  14. #include "megbrain/imperative/ops/backward_graph.h"
  15. #include "megbrain/imperative/ops/utility.h"
  16. #include "megbrain/imperative/profiler.h"
  17. #include "megbrain/opr/io.h"
  18. #include "./common.h"
  19. #include "./grad.h"
  20. #include "./graph_rt.h"
  21. #include "./helper.h"
  22. #include "./module_trace.h"
  23. #include "./numpy_dtypes.h"
  24. #include "./tensor.h"
  25. #include "./trace.h"
  26. #include <object.h>
  27. #include <pybind11/numpy.h>
  28. #include <pybind11/operators.h>
  29. #include <pybind11/pytypes.h>
  30. #include <pyerrors.h>
  31. #include <range/v3/all.hpp>
  32. #include <string>
  33. #include <unordered_map>
  34. namespace py = pybind11;
  35. namespace views = ranges::views;
  36. namespace mgb::imperative::python {
// Channel through which all Python-side tensor ops talk to the imperative
// interpreter; initialized once in init_tensor().
interpreter::Interpreter::Channel* interpreter_for_py;

// Python callables injected from the Python side via the set_* functions
// below. Stored as raw pointers without INCREF (see REGISTE_APPLY_FUNC), so
// the registering Python module must keep them alive.
PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing;
PyObject* cpp_apply_backward_varnode;
PyObject* cpp_apply_module_trace;
  41. std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
  42. if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) {
  43. return std::make_shared<Tensor>(
  44. interpreter_for_py->put(value->dev_tensor(), value->get_value()));
  45. }
  46. py::tuple tup(6);
  47. auto data = value->get_value();
  48. tup[0] = py::reinterpret_steal<py::array>(
  49. ndarray_from_tensor(data, npy::ShareType::MUST_SHARE));
  50. tup[1] = value->dtype();
  51. tup[2] = value->comp_node();
  52. tup[3] = true;
  53. tup[4] = false;
  54. tup[5] = py::none{};
  55. auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr);
  56. if (!py_ret)
  57. throw py::error_already_set();
  58. auto py_list = py::reinterpret_steal<py::list>(py_ret);
  59. auto* tensor_wrapper = TensorWrapper::try_cast(py_list[0].ptr());
  60. auto tensor = tensor_wrapper->m_tensor;
  61. return tensor_wrapper->m_tensor;
  62. }
// Generates `void set_<name>(py::object)` which stashes the Python callable's
// raw pointer in the matching global above.
// NOTE(review): the pointer is stored without Py_INCREF, so the registering
// side must keep the callable alive for the process lifetime -- confirm.
#define REGISTE_APPLY_FUNC(mode) \
    void set_##mode(py::object pyf) { mode = pyf.ptr(); }

REGISTE_APPLY_FUNC(cpp_apply_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)
REGISTE_APPLY_FUNC(cpp_apply_module_trace)

#undef REGISTE_APPLY_FUNC
// Global bit masks merged into every ApplyContext's flags in apply():
// `global_disable` bits are stripped, `global_enable` bits are added. This
// lets Python toggle tracing process-wide without touching each tensor.
Tensor::flags_t ApplyContext::global_disable = 0;
Tensor::flags_t ApplyContext::global_enable = 0;

// Enable/disable jit tracing for subsequently applied ops.
void set_tracing() {
    ApplyContext::global_enable |= Tensor::Flags::TRACE;
}
void unset_tracing() {
    ApplyContext::global_enable &= ~Tensor::Flags::TRACE;
}

// Enable/disable module-level tracing (see apply_module_trace).
void set_module_tracing() {
    ApplyContext::global_enable |= Tensor::Flags::MODULE_TRACE;
}
void unset_module_tracing() {
    ApplyContext::global_enable &= ~Tensor::Flags::MODULE_TRACE;
}
bool is_tracing_module() {
    return ApplyContext::global_enable & Tensor::Flags::MODULE_TRACE;
}

// When true, trace bookkeeping (shape_read/value_read/data_read marks on the
// trace mixin) is suppressed; toggled from the Python side.
bool skip_tracing = false;
/* Central dispatch for applying an OpDef to input tensors.
 * Routing order: grad -> GenericPyOp (python-implemented op) ->
 * module trace -> jit trace -> plain interpreter execution. The grad/trace
 * handlers re-enter apply() after consuming their flag.
 */
apply_result_t apply(ApplyContext& ctx) {
    // emulating scalar should be put to specific op's apply, e.g.,
    // elementwise, reduce, typecvt. Currently it's still handled at python
    // side. It could be move to C++ side if it has an impact on performance
    auto flags = ctx.flags & ~ApplyContext::global_disable;
    flags = flags | ApplyContext::global_enable;

    if (flags & Tensor::Flags::SCALAR) {
        // TODO: emulate scalar
    }

    if (flags & Tensor::Flags::GRAD) {
        return apply_grad(ctx);
    }

    // Ops implemented in Python: forward to the op object's _default_rule.
    if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) {
        py::tuple pyin(ctx.nargs);
        for (size_t i = 0; i < ctx.nargs; ++i) {
            pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
        }
        auto f = py::getattr(op->obj, "_default_rule");
        auto pyout = py::reinterpret_steal<py::object>(
                PyObject_Call(f.ptr(), pyin.ptr(), nullptr));
        if (!pyout)
            throw py::error_already_set();
        // _default_rule may return a single tensor or a sequence of tensors.
        if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) {
            return {tw->m_tensor};
        }
        apply_result_t ret;
        ret.reserve(py::len(pyout));
        for (auto&& i : pyout) {
            auto* tw = TensorWrapper::try_cast(i.ptr());
            mgb_assert(tw);
            ret.push_back(tw->m_tensor);
        }
        return ret;
    }

    if (flags & Tensor::Flags::MODULE_TRACE) {
        return apply_module_trace(ctx);
    }

    if (flags & Tensor::Flags::TRACE) {
        return apply_trace(ctx);
    } else {
        // Eager path: hand raw interpreter handles to the channel.
        SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs);
        for (size_t i = 0; i < ctx.nargs; ++i) {
            handles[i] = ctx.args[i]->m_handle.get();
        }

        apply_result_t outputs;

        // fast copy without really applying
        if (ctx.op->same_type<FastpathCopy>()) {
            mgb_assert(ctx.nargs == 1);
            outputs.reserve(ctx.nargs);
            outputs.emplace_back(std::make_shared<Tensor>(ctx.args[0]->m_handle));
            return outputs;
        }

        auto output_handles = interpreter_for_py->apply_op(ctx.op, handles);

        outputs.reserve(output_handles.size());
        for (auto h : output_handles) {
            outputs.emplace_back(std::make_shared<Tensor>(h));
        }
        return outputs;
    }

    mgb_assert(0);  // unreachable: both branches above return
}
/* Python entry point: apply(op, *inputs) -> tuple of result tensors.
 * args[0] is the OpDef; the remaining args are TensorWrappers (imperative
 * path) or PySymbolVars (symbolic graph-building path).
 */
PyObject* py_apply(
        PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) {
    try {
        // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
        //     PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
        //     return nullptr;
        // }
        if (nargs < 2) {
            PyErr_SetString(
                    PyExc_TypeError,
                    "py_apply expects one Op and at least one tensor "
                    "as argument");
            return nullptr;
        }

        auto* op = args[0];

        PyTypeObject* pytype = args[1]->ob_type;

        // check if pytype is Parameter(and all other python Tensor's derived class),
        // if yes, using it's tp_base(python Tensor)
        // NOTE(review): unconditionally dereferences tp_base->tp_base, so the
        // first input type must sit at least two levels below `object` in the
        // MRO -- confirm all callers guarantee this.
        if (TensorWrapper::wrap_t::type().same_pytype(pytype->tp_base->tp_base)) {
            pytype = pytype->tp_base;
        }

        // Shift past the op so args[0..nargs) are the tensor inputs.
        ++args;
        --nargs;

        ApplyContext ctx;
        ctx.flags = 0;
        ctx.op = py::handle(op).cast<std::shared_ptr<OpDef>>();
        SmallVector<Tensor*, 64> tensors(nargs);
        ctx.args = &tensors[0];
        ctx.nargs = nargs;
        ctx.pytype = pytype;

        // Symbolic path: inputs are SymbolVars; apply on var nodes directly
        // and wrap the results back in the caller's SymbolVar type.
        if (py::isinstance<PySymbolVar>(py::handle(args[0]))) {
            SmallVector<cg::VarNode*> vinputs(nargs);
            for (size_t i = 0; i < nargs; ++i) {
                vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
            }
            auto op = ctx.op.get();
            auto rst = OpDef::apply_on_var_node(*op, vinputs);
            auto ret = pybind11::tuple(rst.size());
            auto typeobj = py::handle(args[0]).get_type();
            for (size_t i = 0; i < rst.size(); ++i) {
                ret[i] = typeobj(pybind11::cast(
                        rst[i], pybind11::return_value_policy::automatic));
            }
            return ret.release().ptr();
        }

        // Imperative path: collect raw Tensor pointers and OR together their
        // per-tensor flags for dispatch in apply().
        for (size_t i = 0; i < nargs; ++i) {
            if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                auto* t = tensors[i] = tw->m_tensor.get();
                ctx.flags |= t->m_flags;
            } else {
                PyErr_SetString(
                        PyExc_TypeError,
                        ssprintf(
                                "op %s expect type Tensor as inputs, got %s actually",
                                ctx.op->make_name().c_str(), Py_TYPE(args[i])->tp_name)
                                .c_str());
                return nullptr;
            }
        }

        auto outputs = apply(ctx);
        size_t nout = outputs.size();
        auto ret = py::tuple(nout);
        for (size_t i = 0; i < nout; ++i) {
            ret[i] = TensorWrapper::make(pytype, std::move(outputs[i]));
        }
        return ret.release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
/* Construct a TensorWrapper from Python arguments. Accepted forms:
 *   (Tensor)                  -- share the underlying tensor
 *   (VarNode)                 -- lazy-eval tensor backed by a graph node
 *   (DeviceTensorND)          -- wrap an existing device tensor
 *   (array, dtype, comp_node, is_const, [no_cache,] name)
 *                             -- upload host data through the interpreter
 */
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
    if (kwargs && PyDict_Size(kwargs)) {
        throw py::type_error("keyword argument not allowed");
    }
    auto nargs = PyTuple_Size(args);
    auto tup = py::reinterpret_borrow<py::tuple>(args);
    if (nargs == 0) {
        throw py::type_error("too few arguments");
    }
    if (auto* t = try_cast(tup[0].ptr())) {
        if (nargs > 1) {
            throw py::type_error("expect 1 argument");
        }
        m_tensor = t->m_tensor;
    } else {
        if (nargs == 1) {
            auto arg0 = PyTuple_GetItem(args, 0);
            // for lazy_eval_tensor
            if (strstr(arg0->ob_type->tp_name, "VarNode")) {
                if (PyObject_HasAttrString(arg0, "_node")) {
                    // NOTE(review): PyObject_GetAttrString returns a new
                    // reference that is never released here -- confirm the
                    // leak is acceptable/intended.
                    arg0 = PyObject_GetAttrString(arg0, "_node");
                }
                m_tensor =
                        std::make_shared<Tensor>(py::handle(arg0).cast<cg::VarNode*>());
            } else {
                // for DeviceTensorND
                if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
                    auto dv = py::handle(arg0).cast<DeviceTensorND>();
                    interpreter::Interpreter::Handle handle =
                            interpreter_for_py->put(dv, {});
                    m_tensor = std::make_shared<Tensor>(handle);
                } else {
                    throw py::type_error(
                            "single argument is not tensor, varnode or devicetensor");
                }
            }
        } else {
            py::detail::loader_life_support life_sup;  // FIXME!!!required to cast DType
            if (nargs != 5 && nargs != 6) {
                throw py::type_error("expect 5 or 6 arguments");
            }
            auto data = tup[0].cast<py::array>();
            DType dtype = tup[1].cast<DType>();
            CompNode cn = tup[2].cast<CompNode>();
            bool is_const = tup[3].cast<bool>();
            // 5-arg form omits no_cache; the last slot is always the name.
            bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false;
            std::string name;
            if (tup[nargs - 1].ptr() != Py_None)
                name = tup[nargs - 1].cast<std::string>();

            // const op
            // NOTE(review): this compares global_enable for *equality* with
            // TRACE rather than testing the bit (make_const above uses `&`);
            // confirm this is intended when other enable bits are set.
            if (is_const && (ApplyContext::global_enable == Tensor::Flags::TRACE)) {
                auto py_ret =
                        PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr);
                if (!py_ret)
                    throw py::error_already_set();
                auto py_list = py::reinterpret_steal<py::list>(py_ret);
                if (auto* t = try_cast(py_list[0].ptr())) {
                    m_tensor = t->m_tensor;
                }
                return;
            }

            interpreter::Interpreter::Handle handle;
            {
                // Copy the numpy data into a host tensor bound to `cn`,
                // then hand it to the interpreter.
                HostTensorND ret(cn);
                handle = interpreter_for_py->put(
                        npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype),
                        no_cache);
            }

            m_tensor = std::make_shared<Tensor>(handle);
            m_tensor->user_custom_name = name;

            // A 0-dim numpy input marks the tensor as scalar.
            if (data.ndim() == 0) {
                m_tensor->m_flags |= Tensor::Flags::SCALAR;
            }
        }
    }
}
// Generates a getter/setter pair exposing a value-typed field of
// m_tensor->m_trace_info to Python via py::cast round-trips.
#define REGISTE_TENSORWRAPPER_FUNC(type, member)                        \
    PyObject* TensorWrapper::member() {                                 \
        return py::cast(m_tensor->m_trace_info.member).release().ptr(); \
    }                                                                   \
    void TensorWrapper::set_##member(PyObject* dest) {                  \
        auto py_dest = py::reinterpret_borrow<py::object>(dest);        \
        type real_dest = py_dest.cast<type>();                          \
        m_tensor->m_trace_info.member = real_dest;                      \
    }

REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle)
REGISTE_TENSORWRAPPER_FUNC(bool, recording)

#undef REGISTE_TENSORWRAPPER_FUNC
// Accessor for the module-trace node attached to this tensor. Raises
// AttributeError (matching the Python attribute protocol expected by
// NodeMixin) when no node has been set yet.
PyObject* TensorWrapper::module_trace_info() {
    if (!m_tensor->m_module_trace_info.ptr()) {
        PyErr_SetString(
                PyExc_AttributeError,
                "Has no attribute named \'_NodeMixin__node\', please "
                "set it first");
        return nullptr;
    }
    // Return a new reference.
    return m_tensor->m_module_trace_info.inc_ref().ptr();
}

void TensorWrapper::set_module_trace_info(PyObject* obj) {
    // Borrow semantics: the py::object takes its own strong reference.
    m_tensor->m_module_trace_info = py::reinterpret_borrow<py::object>(obj);
}
// Generates a getter/setter pair exposing a raw PyObject* slot of
// m_tensor->m_trace_info. The setter holds a strong reference (INCREF on
// set, XDECREF on clearing with None).
// NOTE(review): the getter hands back the stored pointer without adding a
// reference, and the setter does not release a previously stored non-null
// value before overwriting -- confirm the binding layer / call sites
// compensate for both.
#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
    PyObject* TensorWrapper::member() {             \
        if (m_tensor->m_trace_info.member) {        \
            return m_tensor->m_trace_info.member;   \
        } else {                                    \
            Py_RETURN_NONE;                         \
        }                                           \
    }                                               \
    void TensorWrapper::set_##member(PyObject* dest) { \
        if (dest == Py_None) {                         \
            Py_XDECREF(m_tensor->m_trace_info.member); \
            m_tensor->m_trace_info.member = nullptr;   \
        } else {                                       \
            Py_INCREF(dest);                           \
            m_tensor->m_trace_info.member = dest;      \
        }                                              \
    }

REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info)
REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info)

#undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC
// Generates a getter/setter pair for std::string name members on Tensor.
#define SET_GET_NAME(member)                                     \
    PyObject* TensorWrapper::member() {                          \
        return py::cast(m_tensor->member).release().ptr();       \
    }                                                            \
    void TensorWrapper::set_##member(PyObject* dest) {           \
        auto py_dest = py::reinterpret_borrow<py::object>(dest); \
        m_tensor->member = py_dest.cast<std::string>();          \
    }

SET_GET_NAME(user_custom_name)
SET_GET_NAME(automatic_name)

#undef SET_GET_NAME
// Expose the shared interpreter handle to Python (new reference).
PyObject* TensorWrapper::handle() {
    return py::cast(m_tensor->m_handle).release().ptr();
}

// Replace this tensor's interpreter handle with one supplied from Python.
void TensorWrapper::set_handle(PyObject* dest) {
    auto py_dest = py::reinterpret_borrow<py::object>(dest);
    SharedHandle real_dest = py_dest.cast<SharedHandle>();
    m_tensor->m_handle = std::move(real_dest);
}
/* Python `.shape` getter. Resolution order:
 *   1. compiled trace info (when replaying a compiled trace)
 *   2. scalar flag -> empty tuple
 *   3. static shape inference on the lazy-eval VarNode
 *   4. blocking query on the interpreter handle
 * Returns None when the shape cannot be determined.
 */
PyObject* TensorWrapper::shape() {
    // if it's tracing compiled mode, get value from compiled_info
    if (m_tensor->m_trace_info.compiled_info != nullptr) {
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            return PyTuple_New(0);
        }
        // NOTE(review): not checked for NULL before the Py_None comparison;
        // on attribute error this returns NULL with the Python error set.
        PyObject* shp =
                PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
        if (shp == Py_None) {
            throw TraceReadError("shape of this tensor is not read in trace");
        }
        return shp;
    }

    // inside trace, if tensor shape is useful for other operations, set shape_read = true
    if (m_tensor->m_trace_info.recording && !skip_tracing) {
        PyObject_SetAttrString(
                m_tensor->m_trace_info.trace_mixin_info, "shape_read",
                py::cast(true).release().ptr());
    }

    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        return PyTuple_New(0);
    }

    TensorShape shape;
    if (m_tensor->m_var) {  // get shape from m_var
        auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
        auto&& type = mgr.get_infer_type(m_tensor->m_var);
        using InferType = cg::static_infer::InferType;
        // Only statically inferable shapes are reported; otherwise None.
        if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) {
            Py_RETURN_NONE;
        }
        auto* tshp = mgr.infer_shape_fallible(m_tensor->m_var);
        if (!tshp) {
            Py_RETURN_NONE;
        }
        shape = *tshp;
    } else {
        // May block on pending async execution; release the GIL meanwhile.
        py::gil_scoped_release _;
        shape = m_tensor->shape();
    }

    if (!shape.ndim) {
        Py_RETURN_NONE;
    }
    py::tuple ret(shape.ndim);
    for (size_t i = 0; i < shape.ndim; ++i) {
        ret[i] = shape[i];
    }
    return ret.release().ptr();
}
  406. PyObject* TensorWrapper::dtype() {
  407. if (m_tensor->m_var) {
  408. return py::cast(m_tensor->m_var->dtype()).release().ptr();
  409. }
  410. return py::cast(m_tensor->dtype()).release().ptr();
  411. }
  412. PyObject* TensorWrapper::device() {
  413. if (m_tensor->m_var) {
  414. return py::cast(m_tensor->m_var->comp_node()).release().ptr();
  415. }
  416. return py::cast(m_tensor->comp_node()).release().ptr();
  417. }
/* Python `.numpy()` -- materialize the tensor value as a numpy array.
 * Resolution order mirrors shape(): compiled trace info, static value
 * inference on the lazy-eval VarNode, then a (blocking) interpreter query.
 * Tensors carrying the SCALAR flag are squeezed to 0-d arrays.
 */
PyObject* TensorWrapper::numpy() {
    if (m_tensor->m_trace_info.compiled_info != nullptr) {
        PyObject* np_val = PyObject_CallMethod(
                m_tensor->m_trace_info.compiled_info, "numpy", nullptr);
        if (!np_val)
            throw py::error_already_set();
        if (np_val == Py_None) {
            throw TraceReadError("value of this tensor is not read in trace");
        }
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            // PyArray_Squeeze returns a new reference; drop ours.
            PyObject* np_scalar =
                    PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val));
            Py_DECREF(np_val);
            return np_scalar;
        }
        return np_val;
    }

    // Inside an active trace, record that the value was read.
    if (m_tensor->m_trace_info.recording && !skip_tracing) {
        PyObject_SetAttrString(
                m_tensor->m_trace_info.trace_mixin_info, "value_read",
                py::cast(true).release().ptr());
    }

    // Lazy-eval tensor with no concrete handle: try static value inference.
    if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) {
        auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
        auto&& type = mgr.get_infer_type(m_tensor->m_var);
        using InferType = cg::static_infer::InferType;
        if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
            PyErr_SetString(PyExc_ValueError, "tensor invalid");
            return nullptr;
        }
        auto* val = mgr.infer_value_fallible(m_tensor->m_var);
        if (!val) {
            PyErr_SetString(PyExc_ValueError, "tensor invalid");
            return nullptr;
        }
        auto np_val = py::cast(*val).attr("numpy")();
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            // NOTE(review): the reference released from np_val is not
            // dropped after Squeeze returns its own new reference --
            // confirm this is not a leak.
            return PyArray_Squeeze(
                    reinterpret_cast<PyArrayObject*>(np_val.release().ptr()));
        }
        return np_val.release().ptr();
    }

    // Concrete tensor: fetch the host value (may block; GIL released).
    auto&& hv = [&]() {
        py::gil_scoped_release _;
        return interpreter_for_py->get_value(m_tensor->m_handle.get());
    }();
    auto arr = py::reinterpret_steal<py::array>(
            npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
    if (!arr) {
        PyErr_SetString(PyExc_ValueError, "tensor invalid");
        return nullptr;
    }

    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        mgb_assert(PyArray_Check(arr.ptr()));
        // Squeeze returns a new reference; `arr` still owns the original
        // and releases it when it goes out of scope.
        return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
    }
    return arr.release().ptr();
}
  476. PyObject* TensorWrapper::varnode() {
  477. if (m_tensor->m_var) {
  478. return py::cast(m_tensor->m_var).release().ptr();
  479. }
  480. Py_RETURN_NONE;
  481. }
  482. void TensorWrapper::reset(PyObject* tensor) {
  483. TensorWrapper* t = TensorWrapper::try_cast(tensor);
  484. if (!t) {
  485. throw py::type_error("expect Tensor");
  486. }
  487. std::string user_custom_name = m_tensor->user_custom_name;
  488. std::string automatic_name = m_tensor->automatic_name;
  489. auto module_trace_info = m_tensor->m_module_trace_info;
  490. m_tensor = t->m_tensor;
  491. m_tensor->m_module_trace_info = module_trace_info;
  492. m_tensor->user_custom_name = user_custom_name;
  493. m_tensor->automatic_name = automatic_name;
  494. }
// Detach this tensor from its lazy-eval VarNode; handle-based state is left
// untouched.
void TensorWrapper::reset_varnode() {
    m_tensor->m_var = nullptr;
}
/* Return a new Tensor sharing this tensor's storage but with gradient
 * bookkeeping severed (m_grad_info_dict cleared). Implemented as a
 * FastpathCopy apply so the trace machinery still observes an op.
 */
PyObject* TensorWrapper::detach() {
    PyObject* self = wrap_t::pycast(this);
    PyTypeObject* pytype = self->ob_type;

    // Shared singleton op, reused across all detach() calls.
    static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy());
    auto new_tensor = python::apply(op, m_tensor)[0];
    new_tensor->m_grad_info_dict = {};
    auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
    return ret.release().ptr();
}
  507. PyObject* TensorWrapper::_dev_tensor() {
  508. if (m_tensor->m_trace_info.compiled_info != nullptr) {
  509. auto* dev_tensor = PyObject_CallMethod(
  510. m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr);
  511. if (!dev_tensor)
  512. throw py::error_already_set();
  513. if (dev_tensor == Py_None) {
  514. throw TraceReadError("raw data of this tensor is not read in trace");
  515. }
  516. // set m_handle to make it a real tensor
  517. auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
  518. auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>(), {});
  519. m_tensor->m_handle = std::move(SharedHandle(sh));
  520. // compiled info is useless after m_handle is set
  521. Py_DECREF(m_tensor->m_trace_info.compiled_info);
  522. m_tensor->m_trace_info.compiled_info = nullptr;
  523. return dev_tensor;
  524. }
  525. if (m_tensor->m_trace_info.recording && !skip_tracing) {
  526. PyObject_SetAttrString(
  527. m_tensor->m_trace_info.trace_mixin_info, "data_read",
  528. py::cast(true).release().ptr());
  529. }
  530. auto dev_tensor = [&]() {
  531. py::gil_scoped_release _;
  532. return interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
  533. }();
  534. return py::cast(dev_tensor).release().ptr();
  535. }
// Ask the interpreter to drop this tensor's device storage (it may be
// recomputed later by the interpreter if needed again).
void TensorWrapper::_drop() {
    interpreter_for_py->drop(m_tensor->m_handle.get());
}
  539. PyObject* TensorWrapper::isscalar() {
  540. if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
  541. Py_RETURN_TRUE;
  542. } else {
  543. Py_RETURN_FALSE;
  544. }
  545. }
  546. void TensorWrapper::setscalar() {
  547. m_tensor->m_flags |= Tensor::Flags::SCALAR;
  548. }
  549. void TensorWrapper::unsetscalar() {
  550. m_tensor->m_flags &= ~Tensor::Flags::SCALAR;
  551. }
/* Weak reference to a Tensor exposed to Python: calling the ref yields a new
 * TensorWrapper while the tensor is alive, or None once it has expired. */
struct TensorWeakRef {
    std::weak_ptr<Tensor> wptr;

    TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {}

    py::object operator()() {
        if (auto p = wptr.lock()) {
            return TensorWrapper::make(p);
        }
        return py::none();
    }
    // Number of strong owners of the underlying tensor (0 if expired).
    int _use_cnt() { return wptr.use_count(); }
};
  563. /* ============== convert inputs ============== */
  564. // map numpy.dtype.kind to priority
  565. inline uint8_t category_priority(char c) {
  566. switch (c) {
  567. case 'f':
  568. return 3; // floating-point
  569. case 'i':
  570. return 2; // signed integer
  571. case 'u':
  572. return 2; // unsigned integer
  573. case 'b':
  574. return 1; // boolean
  575. default:
  576. return 0;
  577. }
  578. }
  579. // Returns the maximum value of the priority of each type in the list `types`.
  580. uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
  581. if (types.size() == 0) {
  582. return 0;
  583. } else {
  584. uint8_t max_p = 0;
  585. for (auto&& desc : types) {
  586. max_p = std::max(max_p, category_priority(desc->kind));
  587. }
  588. return max_p;
  589. }
  590. }
// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> used_types;
    for (auto&& desc : types) {
        auto&& v = category_priority(desc->kind);
        if (v == cat) {
            used_types.emplace_back(desc);
        }
    }
    mgb_assert(used_types.size() > 0, "size of used_types is 0");
    // Fold PyArray_PromoteTypes over the candidates, keeping exactly one
    // strong reference alive in `res` at any time.
    PyArray_Descr* res = used_types[0];
    Py_INCREF(res);
    for (size_t i = 1; i < used_types.size(); ++i) {
        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
        Py_DECREF(res);
        // NOTE(review): PyArray_PromoteTypes can return NULL on error, which
        // would propagate a NULL `res` to the caller -- confirm callers cope.
        res = tmp;
    }
    return res;
}
  612. PyArray_Descr* scalar2dtype(PyObject* arg) {
  613. // Return value: New reference
  614. if (PyBool_Check(arg)) {
  615. auto&& descr = PyArray_DescrFromType(NPY_BOOL);
  616. return descr;
  617. }
  618. if (PyLong_CheckExact(arg)) {
  619. auto&& descr = PyArray_DescrFromType(NPY_INT32);
  620. return descr;
  621. }
  622. if (PyFloat_CheckExact(arg)) {
  623. auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
  624. return descr;
  625. }
  626. return nullptr;
  627. }
  628. PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
  629. // Return value: New reference
  630. SmallVector<PyArray_Descr*> tensors;
  631. SmallVector<PyArray_Descr*> scalars;
  632. bool is_tuple = false;
  633. PyObject* tuple = nullptr;
  634. if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
  635. if (PyList_Check(args[0])) {
  636. tuple = PyList_AsTuple(args[0]);
  637. } else {
  638. tuple = args[0];
  639. Py_INCREF(tuple);
  640. }
  641. nargs = PyTuple_Size(tuple);
  642. is_tuple = true;
  643. }
  644. for (size_t i = 0; i < nargs; ++i) {
  645. PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
  646. if (handle == Py_None)
  647. continue;
  648. TensorWrapper* tw = TensorWrapper::try_cast(handle);
  649. if (tw) {
  650. mgb::DType type = tw->m_tensor->dtype();
  651. auto&& descr = npy::dtype_mgb2np_descr(type);
  652. Py_INCREF(descr.get());
  653. tensors.emplace_back(descr.get());
  654. } else {
  655. if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
  656. auto&& descr = PyArray_DescrFromObject(handle, nullptr);
  657. tensors.emplace_back(descr);
  658. continue;
  659. }
  660. if (py::isinstance<PySymbolVar>(py::handle(handle))) {
  661. auto var = py::handle(handle).cast<PySymbolVar*>();
  662. mgb::DType type = var->m_node->dtype();
  663. auto&& descr = npy::dtype_mgb2np_descr(type);
  664. Py_INCREF(descr.get());
  665. tensors.emplace_back(descr.get());
  666. continue;
  667. }
  668. PyArray_Descr* descr = scalar2dtype(handle);
  669. if (descr) {
  670. scalars.emplace_back(descr);
  671. continue;
  672. }
  673. }
  674. }
  675. auto max_pri_scalars = max_priority(scalars);
  676. auto max_pri_tensors = max_priority(tensors);
  677. if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
  678. throw py::value_error("invalid input, no dtype avaliable");
  679. }
  680. PyArray_Descr* res;
  681. if (max_pri_scalars > max_pri_tensors) {
  682. res = promote_types(scalars, max_pri_scalars);
  683. } else {
  684. res = promote_types(tensors, max_pri_tensors);
  685. }
  686. for (auto* p : tensors) {
  687. Py_DECREF(p);
  688. }
  689. for (auto* p : scalars) {
  690. Py_DECREF(p);
  691. }
  692. Py_XDECREF(tuple);
  693. return res;
  694. }
  695. CompNode _get_device(PyObject* const* args, size_t nargs) {
  696. bool is_tuple = false;
  697. PyObject* tuple = nullptr;
  698. if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
  699. if (PyList_Check(args[0])) {
  700. tuple = PyList_AsTuple(args[0]);
  701. } else {
  702. tuple = args[0];
  703. Py_INCREF(tuple);
  704. }
  705. nargs = PyTuple_Size(tuple);
  706. is_tuple = true;
  707. }
  708. bool valid = false;
  709. CompNode cn;
  710. for (size_t i = 0; i < nargs; ++i) {
  711. PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
  712. TensorWrapper* tw = TensorWrapper::try_cast(handle);
  713. bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
  714. if (tw || is_symvar) {
  715. if (!valid) {
  716. cn = tw ? tw->m_tensor->comp_node()
  717. : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
  718. valid = true;
  719. } else {
  720. CompNode cn1 = tw ? tw->m_tensor->comp_node()
  721. : py::handle(handle)
  722. .cast<PySymbolVar*>()
  723. ->m_node->comp_node();
  724. if (cn1 != cn) {
  725. throw py::value_error(ssprintf(
  726. "ambiguous device: %s vs %s", cn.to_string().c_str(),
  727. cn1.to_string().c_str()));
  728. }
  729. }
  730. }
  731. }
  732. if (!valid) {
  733. return CompNode::load(get_default_device());
  734. }
  735. Py_XDECREF(tuple);
  736. return cn;
  737. }
// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        PyArray_Descr* res = _dtype_promotion(args, nargs);
        // NOTE(review): `res` is a new reference; whether dtype_np2mgb_descr
        // consumes it is not visible here -- verify no descriptor leak.
        return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
  751. PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
  752. if (!nargs) {
  753. PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
  754. return nullptr;
  755. }
  756. try {
  757. CompNode cn = _get_device(args, nargs);
  758. return py::cast(cn).release().ptr();
  759. }
  760. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  761. }
// Builds a python method-table entry for a fastcall-style function taking
// (self, PyObject* const* args, size_t nargs). When METH_FASTCALL is
// available the function is registered directly; on older CPython a
// py35_<FUNC> shim unpacks the classic args tuple into that form.
#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
#else
#define WRAP_FUNC_PY35(FUNC)                                \
    PyObject* py35_##FUNC(PyObject* self, PyObject* args) { \
        auto* arr = &PyTuple_GET_ITEM(args, 0);             \
        auto size = PyTuple_GET_SIZE(args);                 \
        return FUNC(self, arr, size);                       \
    }
WRAP_FUNC_PY35(py_apply);
WRAP_FUNC_PY35(dtype_promotion);
WRAP_FUNC_PY35(get_device);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif
  779. void init_tensor(py::module m) {
  780. imperative::Tensor::static_initialize();
  781. static auto sl_interpreter_for_py =
  782. interpreter::Interpreter::inst().create_channel();
  783. interpreter_for_py = sl_interpreter_for_py.get();
  784. static py::exception<interpreter::AsyncError> py_async_error(
  785. m, "AsyncError", PyExc_RuntimeError);
  786. py::register_exception_translator([](std::exception_ptr p) {
  787. try {
  788. if (p)
  789. std::rethrow_exception(p);
  790. } catch (const interpreter::AsyncError& e) {
  791. pyext17::pybind11_translate_exception(e.nested_ptr());
  792. if (PyErr_Occurred()) {
  793. PyObject *exc, *val, *tb;
  794. PyErr_Fetch(&exc, &val, &tb);
  795. PyErr_NormalizeException(&exc, &val, &tb);
  796. if (tb) {
  797. PyException_SetTraceback(val, tb);
  798. }
  799. auto val2 = py_async_error.py::object::operator()(
  800. "An async error is reported. See above for the actual cause."
  801. " Hint: This is where it is reported, not where it happened."
  802. " You may call `megengine.core.set_option('async_level', 0)` "
  803. "to get better error reporting.");
  804. PyException_SetCause(
  805. val2.ptr(), val); // PyException_SetCause steals reference
  806. Py_XDECREF(exc);
  807. Py_XDECREF(tb);
  808. PyErr_Restore(
  809. py_async_error.inc_ref().ptr(), val2.release().ptr(), nullptr);
  810. } else {
  811. py_async_error("Unkown async error");
  812. }
  813. }
  814. });
  815. auto* tensor_type =
  816. TensorWrapper::wrap_t::type()
  817. .def<&TensorWrapper::numpy>("numpy")
  818. .def_getset<&TensorWrapper::shape>("shape")
  819. .def_getset<&TensorWrapper::dtype>("dtype")
  820. .def_getset<&TensorWrapper::device>("device")
  821. .def<&TensorWrapper::reset>("_reset")
  822. .def<&TensorWrapper::isscalar>("_isscalar")
  823. .def<&TensorWrapper::setscalar>("_setscalar")
  824. .def<&TensorWrapper::unsetscalar>("_unsetscalar")
  825. .def<&TensorWrapper::detach>("detach")
  826. .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
  827. .def<&TensorWrapper::_drop>("_drop")
  828. .def<&TensorWrapper::reset_varnode>("_reset_varnode")
  829. .def<&TensorWrapper::_use_cnt>("_use_cnt")
  830. .def_getset<&TensorWrapper::varnode>("_varnode")
  831. .def_getset<
  832. &TensorWrapper::mixin_handle,
  833. &TensorWrapper::set_mixin_handle>("_mixin_handle")
  834. .def_getset<
  835. &TensorWrapper::recording, &TensorWrapper::set_recording>(
  836. "_recording")
  837. .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>(
  838. "_handle")
  839. .def_getset<
  840. &TensorWrapper::compiled_info,
  841. &TensorWrapper::set_compiled_info>("_compiled_info")
  842. .def_getset<
  843. &TensorWrapper::trace_mixin_info,
  844. &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
  845. .def_getset<
  846. &TensorWrapper::user_custom_name,
  847. &TensorWrapper::set_user_custom_name>("c_name")
  848. .def_getset<
  849. &TensorWrapper::automatic_name,
  850. &TensorWrapper::set_automatic_name>("_name")
  851. .def_getset<
  852. &TensorWrapper::module_trace_info,
  853. &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
  854. .finalize();
  855. if (!tensor_type)
  856. throw py::error_already_set();
  857. py::setattr(m, "Tensor", tensor_type);
  858. py::class_<TensorWeakRef>(m, "TensorWeakRef")
  859. .def(py::init<const TensorWrapper&>())
  860. .def("__call__", &TensorWeakRef::operator())
  861. .def("_use_cnt", &TensorWeakRef::_use_cnt);
  862. py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
  863. .def_property_readonly(
  864. "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
  865. .def_property(
  866. "var", [](PySymbolVar* v) { return v->m_node; },
  867. [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
  868. .def_property_readonly(
  869. "device", [](PySymbolVar* v) { return v->m_node->comp_node(); })
  870. .def_property_readonly(
  871. "graph", [](PySymbolVar* v) { return v->m_node->owner_graph(); })
  872. .def_property_readonly(
  873. "shape",
  874. [](PySymbolVar* v) -> const TensorShape* {
  875. auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
  876. return mgr.infer_shape_fallible(v->m_node);
  877. })
  878. .def("numpy",
  879. [](PySymbolVar* v) {
  880. auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
  881. auto&& type = mgr.get_infer_type(v->m_node);
  882. using InferType = cg::static_infer::InferType;
  883. if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
  884. throw py::value_error("value invalid!");
  885. }
  886. auto* val = mgr.infer_value_fallible(v->m_node);
  887. if (!val) {
  888. throw py::value_error("value invalid!");
  889. }
  890. auto np_val = py::cast(*val).attr("numpy")();
  891. if (v->is_scalar) {
  892. return py::object(py::array(np_val).squeeze());
  893. }
  894. return np_val;
  895. })
  896. .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
  897. .def("_setscalar", [](PySymbolVar* v) { return v->is_scalar = true; })
  898. .def(py::init([](cg::VarNode* node) {
  899. return std::make_shared<PySymbolVar>(node);
  900. }),
  901. py::arg() = nullptr);
  902. static PyMethodDef method_defs[] = {
  903. MGE_PY_INTERFACE(apply, py_apply),
  904. MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
  905. MGE_PY_INTERFACE(get_device, get_device),
  906. {nullptr, nullptr, 0, nullptr}};
  907. for (auto&& def : method_defs) {
  908. if (def.ml_meth != nullptr) {
  909. auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
  910. if (!func)
  911. throw py::error_already_set();
  912. py::setattr(m, def.ml_name, func);
  913. }
  914. }
  915. static constexpr auto sync_py_task_q = [] { py_task_q.wait_all_task_finish(); };
  916. m.def("set_option", [](std::string name, size_t value) {
  917. interpreter_for_py->set_option(name, value);
  918. });
  919. m.def("clear_candidates", []() { interpreter_for_py->clear_candidates(); });
  920. m.def("get_option",
  921. [](std::string name) { return interpreter_for_py->get_option(name); });
  922. m.def("_set_drop_flag",
  923. [](bool flag) { interpreter_for_py->set_option("enable_drop", flag); });
  924. m.def("config_async_level", [](int level) {
  925. mgb_assert(level >= 0 and level <= 2, "async_level should be 0, 1 or 2");
  926. interpreter_for_py->set_option("async_level", level);
  927. });
  928. m.def("get_async_level",
  929. []() { return interpreter_for_py->get_option("async_level"); });
  930. m.def("set_buffer_length", [](int length) {
  931. mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)");
  932. interpreter_for_py->set_option("buffer_length", length);
  933. });
  934. m.def("push_scope", [](std::string name) { interpreter_for_py->push_scope(name); });
  935. m.def("pop_scope", [](std::string name) { interpreter_for_py->pop_scope(name); });
  936. m.def(
  937. "start_profile",
  938. [](imperative::Profiler::options_t options) {
  939. interpreter_for_py->sync();
  940. imperative::Profiler::load_options(std::move(options));
  941. imperative::Profiler::start_profile();
  942. interpreter_for_py->start_profile();
  943. },
  944. py::call_guard<py::gil_scoped_release>());
  945. m.def(
  946. "stop_profile",
  947. []() -> std::function<void(std::string, std::string)> {
  948. interpreter_for_py->stop_profile();
  949. interpreter_for_py->sync();
  950. imperative::Profiler::stop_profile();
  951. auto results = std::make_shared<imperative::Profiler::bundle_t>(
  952. imperative::Profiler::collect());
  953. return [results = results](
  954. std::string basename, std::string format) mutable {
  955. imperative::Profiler::dump_profile(
  956. basename, format, std::move(*results));
  957. results = nullptr;
  958. };
  959. },
  960. py::call_guard<py::gil_scoped_release>());
  961. m.def(
  962. "sync",
  963. []() {
  964. interpreter_for_py->sync();
  965. sync_py_task_q();
  966. },
  967. py::call_guard<py::gil_scoped_release>());
  968. m.def(
  969. "full_sync",
  970. []() {
  971. interpreter_for_py->sync();
  972. CompNode::sync_all();
  973. sync_py_task_q();
  974. },
  975. py::call_guard<py::gil_scoped_release>());
  976. m.def(
  977. "close",
  978. []() {
  979. interpreter_for_py->close();
  980. sync_py_task_q();
  981. },
  982. py::call_guard<py::gil_scoped_release>());
  983. py::handle grad_key_type =
  984. GradKeyWrapper::wrap_t::type()
  985. .def<&GradKeyWrapper::attach>("attach")
  986. .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
  987. .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>(
  988. "name")
  989. .def_getset<
  990. &GradKeyWrapper::get_priority,
  991. &GradKeyWrapper::set_priority>("priority")
  992. .finalize();
  993. if (!grad_key_type)
  994. throw py::error_already_set();
  995. py::setattr(m, "GradKey", grad_key_type);
  996. m.def("backward", &GradKeyWrapper::backward);
  997. m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
  998. m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
  999. m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode);
  1000. m.def("set_cpp_apply_module_trace", &set_cpp_apply_module_trace);
  1001. m.attr("skip_tracing") = &skip_tracing;
  1002. py::class_<SharedHandle>(m, "SharedHandle")
  1003. .def(py::init<const SharedHandle&>())
  1004. .def("__eq__",
  1005. [](SharedHandle& thish, SharedHandle& thath) {
  1006. return (thish.get() == thath.get());
  1007. })
  1008. .def("__hash__",
  1009. [](SharedHandle& sh) { return reinterpret_cast<int64_t>(sh.get()); });
  1010. m.def("set_tracing", &set_tracing);
  1011. m.def("unset_tracing", &unset_tracing);
  1012. m.def("set_allow_higher_order_directive",
  1013. [](bool value) { GradKey::allow_higher_order_directive = value; });
  1014. m.def("set_module_tracing", &set_module_tracing);
  1015. m.def("unset_module_tracing", &unset_module_tracing);
  1016. m.def("is_tracing_module", &is_tracing_module);
  1017. }
  1018. #undef MGE_PY_INTERFACE
  1019. } // namespace mgb::imperative::python

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台