tensor.cpp 44 kB

/**
 * \file imperative/python/src/tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/common.h"
#include "megbrain/dtype.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/transformations/dim_expansion.h"
#include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/opr/io.h"
#include "megbrain/plugin/profiler.h"

#include "./common.h"
#include "./grad.h"
#include "./graph_rt.h"
#include "./helper.h"
#include "./module_trace.h"
#include "./numpy_dtypes.h"
#include "./tensor.h"
#include "./tensor_utils.h"
#include "./transformation.h"

#include <object.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <pybind11/pytypes.h>
#include <pyerrors.h>
#include <range/v3/all.hpp>
#include <string>
#include <unordered_map>

#include "../../src/impl/mgb_cg_impl.h"
namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {
namespace {
WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;

struct SymbolVarContext {
    TransformationContext context;
    std::shared_ptr<SymbolTransformation> symbol_tsf;
    std::shared_ptr<ScalarTransformation> scalar_tsf;
    std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf;
    std::shared_ptr<DimExpansionTransformation> dim_expansion_tsf;

    SymbolVarContext(cg::ComputingGraph* graph) {
        symbol_tsf = std::make_shared<SymbolTransformation>(graph);
        scalar_tsf = std::make_shared<ScalarTransformation>();
        dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>();
        dim_expansion_tsf = std::make_shared<DimExpansionTransformation>();
        Transformation::swap_context(context);
    }

    void init() {
        symbol_tsf->register_at(Transformation::top());
        scalar_tsf->register_at(Transformation::top());
        dtype_promote_tsf->register_at(Transformation::top());
        dim_expansion_tsf->register_at(Transformation::top());
    }

    ValueRef symvar2val(py::handle py_symbol_var) {
        auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
        ValueRef value = symbol_tsf->value_type().make(symbol_var->m_node);
        if (symbol_var->is_scalar) {
            value = scalar_tsf->value_type().make(value);
        }
        return value;
    }

    py::object val2symvar(py::handle typeobj, ValueRef value) {
        bool is_scalar = false;
        if (auto* scalar_value = value.as(scalar_tsf->value_type())) {
            value = scalar_value->value();
            is_scalar = true;
        }
        auto* node = value.cast(symbol_tsf->value_type()).node();
        auto py_symbol_var =
                typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic));
        py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar;
        return py_symbol_var;
    }

    ~SymbolVarContext() { Transformation::swap_context(context); }
};
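
// A minimal usage sketch of SymbolVarContext, mirroring how py_apply and
// reduce_to_scalar below drive it; `graph`, `py_var` and `op` stand in for
// real objects and are assumptions of this sketch:
//
//     SymbolVarContext context(graph);  // ctor swaps in a fresh transformation context
//     context.init();                   // registers symbol/scalar/promote/expand transforms
//     ValueRef in = context.symvar2val(py_var);
//     ValueRef out = imperative::apply(*op, in)[0];
//     py::object res = context.val2symvar(py_var.get_type(), out);
//     // ~SymbolVarContext restores the previous transformation context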
}  // namespace

interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
PyTypeObject* py_tensor_type = nullptr;
PyObject* cpp_use_symbolic_shape;

#define REGISTE_APPLY_FUNC(mode) \
    void set_##mode(py::object pyf) { mode = pyf.ptr(); }

REGISTE_APPLY_FUNC(cpp_use_symbolic_shape)

#undef REGISTE_APPLY_FUNC

PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs);
CompNode _get_device(PyObject* const* args, size_t nargs);
PyObject* py_apply(
        PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) {
    try {
        // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
        //     PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
        //     return nullptr;
        // }
        if (nargs < 2) {
            PyErr_SetString(
                    PyExc_TypeError,
                    "py_apply expects one Op and at least one tensor "
                    "as argument");
            return nullptr;
        }

        auto* py_op = args[0];

        ++args;
        --nargs;

        auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
        SmallVector<ValueRef, 8> tensors(nargs);

        SmallVector<bool, 8> is_symbol_var(nargs, false);
        ComputingGraph* cg = nullptr;
        for (size_t i = 0; i < nargs; ++i) {
            if ((!TensorWrapper::try_cast(args[i])) &&
                py::isinstance<PySymbolVar>(py::handle(args[i]))) {
                is_symbol_var[i] = true;
                ComputingGraph* cur_cg =
                        py::handle(args[i]).cast<PySymbolVar*>()->m_node->owner_graph();
                if (cg == nullptr) {
                    cg = cur_cg;
                } else {
                    mgb_assert(cg == cur_cg);
                }
            }
        }

        mgb::CompNode target_cn;
        mgb::DType target_dtype;

        auto convert_pyinput_to_tensor = [&](size_t i) -> ValueRef {
            if (!target_dtype.valid()) {
                target_dtype = npy::dtype_np2mgb_descr(_dtype_promotion(args, nargs));
                target_cn = _get_device(args, nargs);
            }
            HostTensorND ht(target_cn);
            ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
            if (PyArray_Check(args[i]) || PyList_Check(args[i])) {  // non-scalar
                return imperative::apply(
                        CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
                        HostStorage::make(ht.storage()))[0];
            } else {  // scalar
                return imperative::apply(
                        CreateTensor(CreateTensor::Const, target_cn, target_dtype, {}),
                        HostStorage::make(ht.storage()))[0];
            }
        };

        if (cg != nullptr) {
            // swap to a special context to reuse scalar handle
            size_t symbol_var_idx = 8;
            SymbolVarContext context(cg);
            context.init();
            for (size_t i = 0; i < nargs; ++i) {
                if (is_symbol_var[i]) {
                    symbol_var_idx = i;
                    tensors[i] = context.symvar2val(args[i]);
                } else {
                    tensors[i] = convert_pyinput_to_tensor(i);
                }
            }
            auto outputs = imperative::apply(*op, tensors);
            auto ret = pybind11::tuple(outputs.size());
            auto typeobj = py::handle(args[symbol_var_idx]).get_type();
            for (size_t i = 0; i < outputs.size(); ++i) {
                ret[i] = context.val2symvar(typeobj, outputs[i]);
            }
            return ret.release().ptr();
        }

        for (size_t i = 0; i < nargs; ++i) {
            if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                tensors[i] = tw->m_tensor->data();
            } else if (
                    DTypePromoteCfg::convert_input_enabled &&
                    op->same_type<Elemwise>()) {
                tensors[i] = convert_pyinput_to_tensor(i);
            } else {
                PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
                return nullptr;
            }
        }

        auto outputs = [&] { return imperative::apply(*op, tensors); }();
        size_t nout = outputs.size();
        auto ret = py::tuple(nout);
        for (size_t i = 0; i < nout; ++i) {
            ret[i] = TensorWrapper::make(py_tensor_type, std::move(outputs[i]));
        }
        return ret.release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
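
// Sketch of reaching py_apply through the Python C API (illustrative only,
// not part of the binding; assumes `op_obj` wraps an OpDef and `t0`, `t1`
// are Tensor instances):
//
//     PyObject* argv[] = {op_obj, t0, t1};
//     PyObject* ret = py_apply(nullptr, argv, 3);  // tuple of output Tensors
//     if (!ret) { /* a Python exception has been set */ }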
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
    if (kwargs && PyDict_Size(kwargs)) {
        throw py::type_error("keyword argument not allowed");
    }
    auto nargs = PyTuple_Size(args);
    auto tup = py::reinterpret_borrow<py::tuple>(args);
    if (nargs == 0) {
        throw py::type_error("too few arguments");
    }
    if (auto* t = try_cast(tup[0].ptr())) {
        if (nargs > 1) {
            throw py::type_error("expect 1 argument");
        }
        m_tensor = t->m_tensor->copy();
    } else {
        if (nargs == 1) {
            auto arg0 = PyTuple_GetItem(args, 0);
            // for DeviceTensorND
            if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
                auto dv = py::handle(arg0).cast<DeviceTensorND>();
                m_tensor = std::make_shared<Tensor>(imperative::apply(
                        CreateTensor(CreateTensor::Common, dv.comp_node(), dv.layout()),
                        DeviceStorage::make(dv.storage()))[0]);
            } else {
                throw py::type_error(
                        "single argument is not tensor, varnode or devicetensor");
            }
        } else {
            py::detail::loader_life_support life_sup;  // FIXME!!! required to cast DType
            if (nargs != 5 && nargs != 6) {
                throw py::type_error("expect 5 or 6 arguments");
            }
            auto data = tup[0].cast<py::array>();
            DType dtype = tup[1].cast<DType>();
            CompNode cn = tup[2].cast<CompNode>();
            bool is_const = tup[3].cast<bool>();
            bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false;
            std::string name;
            if (tup[nargs - 1].ptr() != Py_None)
                name = tup[nargs - 1].cast<std::string>();

            // const op
            {
                CreateTensor::Kind kind = is_const ? CreateTensor::Const
                                        : no_cache ? CreateTensor::Unique
                                                   : CreateTensor::Common;
                HostTensorND ret(cn);
                ret = npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype);
                mgb_assert(
                        ret.layout().is_empty() || ret.layout().is_contiguous(),
                        "host value should be contiguous");
                ValueShape shape;
                for (size_t i = 0; i < data.ndim(); ++i) {
                    shape[shape.ndim++] = data.shape(i);
                }
                m_tensor = std::make_shared<Tensor>(imperative::apply(
                        CreateTensor(kind, cn, ret.dtype(), shape),
                        HostStorage::make(ret.storage()))[0]);
            }

            if (!name.empty()) {
                m_tensor->reset(
                        imperative::apply(RenameValue(name), m_tensor->data())[0]);
            }
        }
    }
    mgb_assert(m_tensor->data());
}
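
// For reference, the positional forms this constructor accepts (a summary of
// the branches above):
//
//     Tensor(tensor)                                            -> copy another Tensor
//     Tensor(device_tensor_nd)                                  -> wrap a DeviceTensorND
//     Tensor(data, dtype, comp_node, is_const, name)            -> 5-argument form
//     Tensor(data, dtype, comp_node, is_const, no_cache, name)  -> 6-argument form
//
// In both long forms the trailing `name` may be None.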
PyObject* TensorWrapper::module_trace_info() {
    if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) {
        if (module_trace_info->ptr()) {
            return module_trace_info->inc_ref().ptr();
        }
    }
    PyErr_SetString(
            PyExc_AttributeError,
            "Has no attribute named \'_NodeMixin__node\', please "
            "set it first");
    return nullptr;
}

void TensorWrapper::set_module_trace_info(PyObject* obj) {
    // TODO: erase when obj == nullptr
    module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj);
}

void TensorWrapper::_set_name(PyObject* dest) {
    auto py_dest = py::reinterpret_borrow<py::object>(dest);
    auto name = py_dest.cast<std::string>();
    m_tensor->set_name(name);
}

PyObject* TensorWrapper::_detail() {
    return py::str(m_tensor->data().unwrap().to_string()).release().ptr();
}

void TensorWrapper::_watch() {
    m_tensor->data().watch();
}

PyObject* TensorWrapper::shape() {
    auto shape = m_tensor->shape();
    if (!shape) {
        Py_RETURN_NONE;
    }
    py::tuple ret(shape->ndim);
    for (size_t i = 0; i < shape->ndim; ++i) {
        ret[i] = shape->at(i);
    }
    return ret.release().ptr();
}

PyObject* TensorWrapper::dtype() {
    return py::cast(m_tensor->dtype()).release().ptr();
}

PyObject* TensorWrapper::device() {
    return py::cast(m_tensor->comp_node()).release().ptr();
}

PyObject* TensorWrapper::numpy() {
    auto hv = m_tensor->numpy();
    if (!hv) {
        PyErr_SetString(PyExc_ValueError, "tensor invalid");
        return nullptr;
    }
    auto arr = py::reinterpret_steal<py::array>(
            npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
    if (hv->shape().is_scalar()) {
        mgb_assert(PyArray_Check(arr.ptr()));
        return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
    }
    return arr.release().ptr();
}

void TensorWrapper::reset(PyObject* tensor) {
    TensorWrapper* t = TensorWrapper::try_cast(tensor);
    if (!t) {
        throw py::type_error("expect Tensor");
    }
    m_tensor->reset(t->m_tensor->data());
}

PyObject* TensorWrapper::detach() {
    auto detached = imperative::apply(DetachGrad(), m_tensor->data())[0];
    return TensorWrapper::make(py_tensor_type, detached).release().ptr();
}

PyObject* TensorWrapper::_dev_tensor() {
    auto dv = m_tensor->data().dev_tensor();
    // TODO: handle scalar
    return py::cast(dv->as_nd(true)).release().ptr();
}

void TensorWrapper::_drop() {
    imperative::apply(DTRCommand(DTRCommand::Drop), m_tensor->data());
}

PyObject* TensorWrapper::isscalar() {
    if (m_tensor->is_scalar()) {
        Py_RETURN_TRUE;
    } else {
        Py_RETURN_FALSE;
    }
}

struct TensorWeakRef {
    std::weak_ptr<Tensor> wptr;

    TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {}

    py::object operator()() {
        if (auto p = wptr.lock()) {
            return TensorWrapper::make(py_tensor_type, p);
        }
        return py::none();
    }
    int _use_cnt() { return wptr.use_count(); }
};
#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
#else
#define WRAP_FUNC_PY35(FUNC)                                \
    PyObject* py35_##FUNC(PyObject* self, PyObject* args) { \
        auto* arr = &PyTuple_GET_ITEM(args, 0);             \
        auto size = PyTuple_GET_SIZE(args);                 \
        return FUNC(self, arr, size);                       \
    }
WRAP_FUNC_PY35(py_apply);
WRAP_FUNC_PY35(dtype_promotion);
WRAP_FUNC_PY35(get_device);
WRAP_FUNC_PY35(make_shape_tuple);
WRAP_FUNC_PY35(getitem_cpp);
WRAP_FUNC_PY35(setitem_cpp);
WRAP_FUNC_PY35(split_cpp);
WRAP_FUNC_PY35(expand_dims_cpp);
WRAP_FUNC_PY35(squeeze_cpp);
WRAP_FUNC_PY35(transpose_cpp);
WRAP_FUNC_PY35(broadcast_cpp);
WRAP_FUNC_PY35(reshape_cpp);
WRAP_FUNC_PY35(adaptive_pool2d_cpp);
WRAP_FUNC_PY35(Const);
WRAP_FUNC_PY35(astype_cpp);
WRAP_FUNC_PY35(matmul_cpp);
WRAP_FUNC_PY35(batched_matmul_cpp);
WRAP_FUNC_PY35(convert_single_value_cpp);
WRAP_FUNC_PY35(convert_inputs_cpp);
WRAP_FUNC_PY35(astensor1d_cpp);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif
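
// For reference, a hand-expanded sketch of the two table-entry forms produced
// by MGE_PY_INTERFACE(apply, py_apply), depending on whether METH_FASTCALL is
// available:
//
//     {"apply", (PyCFunction)py_apply,      METH_FASTCALL, nullptr}
//     {"apply", (PyCFunction)py35_py_apply, METH_VARARGS,  nullptr}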
void init_tensor(py::module m) {
    imperative::Tensor::static_initialize();

    static auto& transformations = TransformationManager::get_instance();

    using Segment = TransformationManager::Segment;
    using Channel = interpreter::Interpreter::Channel;

    auto* channel =
            imperative::ResourceManager::create_global<std::unique_ptr<Channel>>(
                    interpreter::Interpreter::inst().create_channel())
                    ->get();
    interpreter_for_py = channel;
    MGB_MARK_USED_VAR(
            transformations
                    .register_at<Segment::Eval>(
                            std::make_shared<InterpreterTransformation>(
                                    std::shared_ptr<Channel>(channel, [](Channel*) {})))
                    .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::Scalar>(
                                      std::make_shared<ScalarTransformation>())
                              .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::DTypePromote>(
                                      std::make_shared<DTypePromoteTransformation>())
                              .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::DimExpansion>(
                                      std::make_shared<DimExpansionTransformation>())
                              .release());

    static py::exception<interpreter::AsyncError> py_async_error(
            m, "AsyncError", PyExc_RuntimeError);
    py::register_exception_translator([](std::exception_ptr p) {
        try {
            if (p)
                std::rethrow_exception(p);
        } catch (const interpreter::AsyncError& e) {
            pyext17::pybind11_translate_exception(e.nested_ptr());
            if (PyErr_Occurred()) {
                PyObject *exc, *val, *tb;
                PyErr_Fetch(&exc, &val, &tb);
                PyErr_NormalizeException(&exc, &val, &tb);
                if (tb) {
                    PyException_SetTraceback(val, tb);
                }
                auto val2 = py_async_error.py::object::operator()(
                        "An async error is reported. See above for the actual cause."
                        " Hint: This is where it is reported, not where it happened."
                        " You may set `megengine.config.async_level = 0` "
                        "to get better error reporting.");
                PyException_SetCause(
                        val2.ptr(), val);  // PyException_SetCause steals reference
                Py_XDECREF(exc);
                Py_XDECREF(tb);
                PyErr_Restore(
                        py_async_error.inc_ref().ptr(), val2.release().ptr(), nullptr);
            } else {
                py_async_error("Unknown async error");
            }
        }
    });
    auto* tensor_type =
            TensorWrapper::wrap_t::type()
                    .def<&TensorWrapper::numpy>("numpy")
                    .def_getset<&TensorWrapper::shape>("shape")
                    .def_getset<&TensorWrapper::dtype>("dtype")
                    .def_getset<&TensorWrapper::device>("device")
                    .def<&TensorWrapper::reset>("_reset")
                    .def<&TensorWrapper::isscalar>("_isscalar")
                    .def<&TensorWrapper::detach>("detach")
                    // TODO: remove this
                    .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
                    .def<&TensorWrapper::_drop>("_drop")
                    .def<&TensorWrapper::_use_cnt>("_use_cnt")
                    .def<&TensorWrapper::_detail>("_detail")
                    .def<&TensorWrapper::_set_name>("_set_name")
                    .def<&TensorWrapper::_watch>("_watch")
                    .def_getset<
                            &TensorWrapper::module_trace_info,
                            &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
                    .finalize();
    if (!tensor_type)
        throw py::error_already_set();
    py::setattr(m, "Tensor", tensor_type);

    py::class_<TensorWeakRef>(m, "TensorWeakRef")
            .def(py::init<const TensorWrapper&>())
            .def("__call__", &TensorWeakRef::operator())
            .def("_use_cnt", &TensorWeakRef::_use_cnt);

    py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
            .def_property_readonly(
                    "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
            .def_property(
                    "var", [](PySymbolVar* v) { return v->m_node; },
                    [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
            .def_property_readonly(
                    "device", [](PySymbolVar* v) { return v->m_node->comp_node(); })
            .def_property_readonly(
                    "graph", [](PySymbolVar* v) { return v->m_node->owner_graph(); })
            .def_property_readonly(
                    "shape",
                    [](PySymbolVar* v) -> const TensorShape* {
                        auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
                        return mgr.infer_shape_fallible(v->m_node);
                    })
            .def("numpy",
                 [](PySymbolVar* v) {
                     auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
                     auto&& type = mgr.get_infer_type(v->m_node);
                     using InferType = cg::static_infer::InferType;
                     if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
                         throw py::value_error("value invalid!");
                     }
                     auto* val = mgr.infer_value_fallible(v->m_node);
                     if (!val) {
                         throw py::value_error("value invalid!");
                     }
                     auto np_val = py::cast(*val).attr("numpy")();
                     return np_val;
                 })
            .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
            .def(py::init([](cg::VarNode* node) {
                     return std::make_shared<PySymbolVar>(node);
                 }),
                 py::arg() = nullptr);
    static PyMethodDef method_defs[] = {
            MGE_PY_INTERFACE(apply, py_apply),
            MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
            MGE_PY_INTERFACE(get_device, get_device),
            MGE_PY_INTERFACE(make_shape_tuple, make_shape_tuple),
            MGE_PY_INTERFACE(getitem_cpp, getitem_cpp),
            MGE_PY_INTERFACE(setitem_cpp, setitem_cpp),
            MGE_PY_INTERFACE(split_cpp, split_cpp),
            MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
            MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp),
            MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
            MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
            MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
            MGE_PY_INTERFACE(adaptive_pool2d_cpp, adaptive_pool2d_cpp),
            MGE_PY_INTERFACE(Const, Const),
            MGE_PY_INTERFACE(astype_cpp, astype_cpp),
            MGE_PY_INTERFACE(matmul_cpp, matmul_cpp),
            MGE_PY_INTERFACE(batched_matmul_cpp, batched_matmul_cpp),
            MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
            MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
            MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp),
            {nullptr, nullptr, 0, nullptr}};
    for (auto&& def : method_defs) {
        if (def.ml_meth != nullptr) {
            auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
            if (!func)
                throw py::error_already_set();
            py::setattr(m, def.ml_name, func);
        }
    }
    static constexpr auto sync_py_task_q = [] {
        py::gil_scoped_release _;
        py_task_q.wait_all_task_finish();
    };

    m.def("clear_candidates", [channel]() { channel->clear_candidates(); });
    m.def("set_option", [channel](std::string name, size_t value) {
        channel->set_option(name, value);
    });
    m.def("get_option",
          [channel](std::string name) { return channel->get_option(name); });
    m.def("push_scope", [channel](std::string name) {
        Transformation::push_scope(name);
        channel->push_scope(name);
    });
    m.def("pop_scope", [channel](std::string name) {
        channel->pop_scope(name);
        Transformation::pop_scope(name);
    });
    m.def("start_profile", [channel](imperative::Profiler::options_t options) {
        channel->sync();
        imperative::Profiler::load_options(std::move(options));
        imperative::Profiler::start_profile();
        channel->start_profile();
    });
    m.def("stop_profile", [channel]() -> std::function<void(std::string, std::string)> {
        channel->stop_profile();
        channel->sync();
        imperative::Profiler::stop_profile();
        auto results = std::make_shared<imperative::Profiler::bundle_t>(
                imperative::Profiler::collect());
        return [results = results](std::string basename, std::string format) mutable {
            imperative::Profiler::dump_profile(basename, format, std::move(*results));
            results = nullptr;
        };
    });
    m.def("sync", [channel]() {
        if (channel->check_available()) {
            channel->sync();
        }
        sync_py_task_q();
    });
    m.def("full_sync", [channel]() {
        if (channel->check_available()) {
            channel->sync();
        }
        CompNode::sync_all();
        CompNode::foreach ([](CompNode cn) {
            auto err = cn.check_async_error();
            mgb_assert(!err, "%s", err->what());
        });
        sync_py_task_q();
    });
    m.def("close", [channel]() {
        channel->close();
        sync_py_task_q();
    });
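
    // Profiling round-trip implied by the start_profile/stop_profile bindings
    // above (a sketch; argument values are illustrative):
    //
    //     start_profile(options)    # sync, load options, begin profiling
    //     ... run workload ...
    //     dump = stop_profile()     # returns a one-shot callback over the results
    //     dump(basename, format)    # writes the collected profile, then drops it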
    py::handle grad_key_type =
            GradKeyWrapper::wrap_t::type()
                    .def<&GradKeyWrapper::attach>("attach")
                    .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
                    .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>(
                            "name")
                    .def<&GradKeyWrapper::enter>("enter")
                    .def<&GradKeyWrapper::exit>("exit")
                    .def<&GradKeyWrapper::suppress>("suppress")
                    .def<&GradKeyWrapper::resume>("resume")
                    .finalize();
    if (!grad_key_type)
        throw py::error_already_set();
    py::setattr(m, "GradKey", grad_key_type);
    m.def("backward", &GradKeyWrapper::backward);
    m.def("get_backward_closure", &GradKeyWrapper::get_backward_closure);

    m.def("set_py_tensor_type", [](py::object type_obj) {
        py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
    });

    /**
     * \brief trace proxy
     *
     */
    struct Trace {
        bool symbolic = false;
        bool no_exec = false;
        bool capture_as_const = false;
        bool profile = false;
        bool record_input_shapes = false;
        py::function options_visitor;
        std::shared_ptr<TracingTransformation> tracing;
        std::shared_ptr<CompiledTransformation> compiled;
        std::shared_ptr<LazyEvalTransformation> lazy_eval;
        std::pair<size_t, std::shared_ptr<GraphProfiler>> profiler;
        std::optional<TraceResult> trace_result;
        std::function<bool(py::object, py::object)> array_comparator;
        std::unique_ptr<CleanupGuard<>> tracing_guard;
        std::unique_ptr<CleanupGuard<>> compiled_guard;
        std::unique_ptr<CleanupGuard<>> lazy_eval_guard;

        bool compare_value(ValueRef lhs, ValueRef rhs) {
            auto lvalue = lhs.cast_ref<HostValue>();
            auto rvalue = rhs.cast_ref<HostValue>();
            if (lvalue->shape() != rvalue->shape()) {
                return false;
            }
            if (lvalue->shape().total_nr_elems() == 1) {
                return lvalue->item() == rvalue->item();
            }
            HostTensorND lnd = lvalue->as_nd(true);
            HostTensorND rnd = rvalue->as_nd(true);
            auto larr = py::reinterpret_steal<py::array>(
                    npy::ndarray_from_tensor(lnd, npy::ShareType::TRY_SHARE));
            auto rarr = py::reinterpret_steal<py::array>(
                    npy::ndarray_from_tensor(rnd, npy::ShareType::TRY_SHARE));
            return array_comparator(larr, rarr);
        }

        void enter() {
            auto& self = *this;
            if (!self.trace_result) {  // untraced
                self.tracing = std::make_shared<TracingTransformation>(
                        self.capture_as_const, self.record_input_shapes);
                if (self.symbolic) {
                    self.lazy_eval =
                            std::make_shared<LazyEvalTransformation>(self.no_exec);
                    self.options_visitor(py::cast(&self.lazy_eval->options()));
                }
            } else if (!self.compiled) {  // traced but not compiled
                using namespace std::placeholders;
                self.compiled = std::make_shared<CompiledTransformation>(
                        *self.trace_result, self.record_input_shapes);
                self.compiled->set_value_comparator(
                        std::bind(&Trace::compare_value, this, _1, _2));
                self.options_visitor(py::cast(&self.compiled->options()));
                self.compiled->compile();
            }
            // register transformations
            if (self.compiled) {
                if (self.profile) {
                    auto& current_graph = self.compiled->graph();
                    if (self.profiler.first != self.compiled->graph().id()) {
                        // graph changed
                        self.profiler = std::make_pair(
                                current_graph.id(),
                                std::make_shared<GraphProfiler>(&current_graph));
                    }
                }
                compiled_guard =
                        transformations.register_at<Segment::Trace>(self.compiled);
                // start executing, because InputCallback depends on it
                self.compiled->execute();
            } else if (self.tracing) {
                tracing_guard =
                        transformations.register_at<Segment::Trace>(self.tracing);
                if (self.lazy_eval) {
                    lazy_eval_guard =
                            transformations.register_at<Segment::Eval>(self.lazy_eval);
                }
            } else {
                mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
            }
        }

        void exit() {
            auto& self = *this;
            if (self.tracing) {
                tracing_guard.reset();
                self.trace_result = self.tracing->get_result();
                self.tracing.reset();
                if (self.lazy_eval) {
                    auto lazy_eval = std::move(self.lazy_eval);
                    lazy_eval_guard.reset();
                    lazy_eval->check_exception();
                }
            } else if (self.compiled) {
                compiled_guard.reset();
                self.compiled->wait();
            } else {
                mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
            }
        }

        VarNodeArray dump(
                std::shared_ptr<ComputingGraph> graph,
                std::vector<std::tuple<std::string, std::string, TensorShape>> inputs,
                std::vector<std::pair<std::string, std::string>> outputs,
                bool prefer_input_names) {
            auto& self = *this;
            mgb_assert(self.trace_result);
            // marks look like "arg_0", "kwarg_xxx", "output_0", ...
            std::unordered_map<std::string, size_t> mark2var;
            for (size_t i = 0; i < self.trace_result->vars.size(); ++i) {
                auto& name = self.trace_result->vars[i].mark;
                if (!name.empty()) {
                    mark2var[name] = i;
                }
            }
            std::vector<std::tuple<size_t, std::string, TensorShape>> input_vars;
            std::vector<std::pair<size_t, std::string>> output_vars;
            for (auto&& [input_mark, input_name, input_shape] : inputs) {
                mgb_assert(input_shape.ndim, "input shape invalid");
                input_vars.push_back(
                        {mark2var.at(input_mark), input_name, input_shape});
            }
            for (auto&& [output_name, repr] : outputs) {
                output_vars.push_back({mark2var.at(output_name), repr});
            }
            self.options_visitor(py::cast(&graph->options()));
            auto vars = self.trace_result->dump(
                    *graph, input_vars, output_vars, prefer_input_names);
            return vars;
        }
    };
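
    // Lifecycle sketch of the Trace proxy above (hedged; this is how the
    // Python-side tracing wrapper is expected to drive it, though the exact
    // caller is not shown in this file):
    //
    //     enter(); run fn; exit()   // 1st pass: record, fills trace_result
    //     enter(); run fn; exit()   // later passes: compile once, then replay
    //     dump(graph, inputs, outputs, prefer_input_names)  // serialize traced vars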
    py::class_<Trace>(m, "Trace")
            .def(py::init<>())
            .def_readwrite("record_input_shapes", &Trace::record_input_shapes)
            .def_readwrite("array_comparator", &Trace::array_comparator)
            .def_readwrite("profile", &Trace::profile)
            .def_property_readonly(
                    "options",
                    [](Trace& self) {
                        if (self.compiled) {
                            return &self.compiled->options();
                        } else {
                            return (ComputingGraph::Options*)nullptr;
                        }
                    })
            .def("get_profile",
                 [](Trace& self) -> py::object {
                     if (self.profiler.second && self.compiled) {
                         auto json = self.profiler.second->to_json_full(
                                 self.compiled->graph().current_comp_seq());
                         return py::str(json->to_string());
                     } else {
                         return py::none();
                     }
                 })
            .def_readwrite("symbolic", &Trace::symbolic)
            .def_readwrite("capture_as_const", &Trace::capture_as_const)
            .def_readwrite("no_exec", &Trace::no_exec)
            .def_readwrite("options_visitor", &Trace::options_visitor)
            .def("enter", &Trace::enter)
            .def("exit", &Trace::exit)
            .def("dump", &Trace::dump)
            .def("begin_excluded_region",
                 [](Trace& self) {
                     mgb_assert(bool(self.tracing) ^ bool(self.compiled));
                     if (self.tracing) {
                         self.tracing_guard.reset();
                     } else if (self.compiled) {
                         self.compiled_guard.reset();
                     }
                 })
            .def("end_excluded_region", [](Trace& self) {
                mgb_assert(bool(self.tracing) ^ bool(self.compiled));
                if (self.tracing) {
                    self.tracing_guard =
                            transformations.register_at<Segment::Trace>(self.tracing);
                } else if (self.compiled) {
                    self.compiled_guard =
                            transformations.register_at<Segment::Trace>(self.compiled);
                }
            });
    m.def("reduce_to_scalar", [](py::object op, py::object tensor) -> py::object {
        auto reduce_to_scalar = [](const OpDef& op, const ValueRef& input) {
            auto make_scalar_shape = [&](CompNode device) {
                return imperative::apply(
                        CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
                        HostStorage::make(device))[0];
            };
            return imperative::apply(op, input, make_scalar_shape(*input.device()))[0];
        };
        if (py::isinstance<PySymbolVar>(tensor)) {
            auto* graph = tensor.cast<PySymbolVar*>()->m_node->owner_graph();
            SymbolVarContext context(graph);
            context.init();
            auto output = reduce_to_scalar(
                    *op.cast<std::shared_ptr<OpDef>>(), context.symvar2val(tensor));
            auto typeobj = tensor.get_type();
            return context.val2symvar(typeobj, output);
        } else {
            auto* tw = TensorWrapper::try_cast(tensor.ptr());
            auto output = reduce_to_scalar(
                    *op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data());
            return TensorWrapper::make(py_tensor_type, output);
        }
    });

    m.def("name_tensor", [](std::string name, py::object tensor) {
        auto* tw = TensorWrapper::try_cast(tensor.ptr());
        auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
        tw->m_tensor->reset(output);
    });

    m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool {
        SmallVector<ValueRef> values(tensors.size());
        for (size_t i = 0; i < tensors.size(); ++i) {
            values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
        }
        auto outputs = imperative::apply(GetGradKey(), values);
        if (outputs[0].is<GradKeyValue>()) {
            return true;
        } else {
            return false;
        }
    });

    m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object {
        SmallVector<ValueRef> values(tensors.size());
        for (size_t i = 0; i < tensors.size(); ++i) {
            values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
        }
        auto output = imperative::apply(GetGradKey(), values)[0];
        if (!output) {
            return py::none();
        }
        return py::reinterpret_borrow<py::object>(GradKeyWrapper::wrap_t::pycast(
                GradKeyWrapper::get(output.cast<GradKeyValue>())));
    });
    m.def("set_grad", [](py::function backward_fn, std::vector<py::object> inputs,
                         std::vector<py::object> outputs) {
        GenericFunction generic_backward_fn =
                [backward_fn](Span<ValueRef> output_grads) -> ValueRefList {
            py::list output_grad_tws;
            for (auto&& output_grad : output_grads) {
                if (output_grad) {
                    output_grad_tws.append(
                            TensorWrapper::make(py_tensor_type, output_grad));
                } else {
                    output_grad_tws.append(py::none());
                }
            }
            py::tuple input_grad_tws = backward_fn(*output_grad_tws);
            ValueRefList input_grads(input_grad_tws.size());
            for (size_t i = 0; i < input_grad_tws.size(); ++i) {
                auto input_grad_tw = input_grad_tws[i];
                if (!input_grad_tw.is_none()) {
                    input_grads[i] =
                            py::cast<TensorWrapper>(input_grad_tw).m_tensor->data();
                } else {
                    input_grads[i] = {};
                }
            }
            return input_grads;
        };
        SmallVector<ValueRef> values(inputs.size() + outputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data();
        }
        for (size_t i = 0; i < outputs.size(); ++i) {
            values[i + inputs.size()] =
                    outputs[i].cast<TensorWrapper>().m_tensor->data();
        }
        auto wrapped_output_values =
                imperative::apply(SetGrad(generic_backward_fn, inputs.size()), values);
        std::vector<py::object> wrapped_outputs;
        mgb_assert(wrapped_output_values.size() == outputs.size());
        for (auto&& output_value : wrapped_output_values) {
            wrapped_outputs.push_back(
                    TensorWrapper::make(py_tensor_type, output_value));
        }
        return wrapped_outputs;
    });
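
    // Contract of backward_fn, summarized from generic_backward_fn above:
    //   - it is called with one positional argument per output grad, with
    //     missing grads passed as None;
    //   - it must return one entry per input; a None entry means "no grad
    //     for this input".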
    static py::function module_trace_hook;

    static auto get_module_trace = [] {
        static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation;
        if (!module_trace_transformation) {
            mgb_assert(module_trace_hook);
            module_trace_transformation =
                    std::make_shared<ModuleTraceTransformation>(module_trace_hook);
            MGB_MARK_USED_VAR(transformations
                                      .register_at<Segment::ModuleTrace>(
                                              module_trace_transformation)
                                      .release());
        }
        return module_trace_transformation;
    };

    m.def("set_cpp_use_symbolic_shape", &set_cpp_use_symbolic_shape);

    m.def("set_module_tracing", [=] { get_module_trace()->enable(); });

    m.def("unset_module_tracing", [=] { get_module_trace()->disable(); });

    m.def("is_tracing_module", [=] { return get_module_trace()->enabled(); });

    m.def("set_module_trace_hook", [](py::function function) {
        module_trace_hook = function;
        module_trace_hook.inc_ref();
    });

    auto atexit = py::module::import("atexit");
    atexit.attr("register")(py::cpp_function([]() { module_trace_hook = {}; }));

    m.def("begin_record_values", [] { Value::begin_record_values(); });

    m.def("end_record_values", [] {
        std::vector<std::pair<size_t, std::string>> reprs;
        auto values = Value::end_record_values();
        for (auto&& value : values) {
            reprs.push_back({value.id(), value.to_string()});
        }
        return reprs;
    });
    m.def("print_stats", [] { imperative::Stats::print(); });

    m.def("reset_stats", [] { imperative::Stats::reset(); });

    m.def("_get_convert_inputs",
          []() -> bool { return DTypePromoteCfg::convert_input_enabled; });

    m.def("_set_convert_inputs", [](bool flag) -> bool {
        bool ret = DTypePromoteCfg::convert_input_enabled;
        DTypePromoteCfg::convert_input_enabled = flag;
        return ret;
    });

    m.def("_get_amp_dtype_autocast",
          []() -> bool { return DTypePromoteCfg::amp_dtype_autocast_enabled; });

    m.def("_set_amp_dtype_autocast", [](bool flag) -> bool {
        bool ret = DTypePromoteCfg::amp_dtype_autocast_enabled;
        DTypePromoteCfg::amp_dtype_autocast_enabled = flag;
        return ret;
    });

    static auto get_amp_prec_dtype = [](bool is_high) -> std::string {
        DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
                                : DTypePromoteCfg::amp_low_prec_dtype;
        mgb_assert(target.category() == DTypeCategory::FLOAT);
        std::string ret = target.name();
        transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
        return ret;
    };

    static auto set_amp_prec_dtype = [](bool is_high,
                                        std::string dtype_name) -> std::string {
        DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
                                : DTypePromoteCfg::amp_low_prec_dtype;
        std::string ret = target.name();
        if (dtype_name == "float32") {
            target = dtype::Float32();
        } else if (dtype_name == "float16") {
            target = dtype::Float16();
        } else if (dtype_name == "bfloat16") {
            target = dtype::BFloat16();
        } else {
            mgb_assert(
                    false, "amp precision dtype must be a float type, got %s\n",
                    dtype_name.c_str());
        }
        transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
        return ret;
    };

    m.def("_get_amp_high_prec_dtype",
          []() -> std::string { return get_amp_prec_dtype(true); });

    m.def("_set_amp_high_prec_dtype", [](std::string dtype_name) -> std::string {
        return set_amp_prec_dtype(true, dtype_name);
    });

    m.def("_get_amp_low_prec_dtype",
          []() -> std::string { return get_amp_prec_dtype(false); });

    m.def("_set_amp_low_prec_dtype", [](std::string dtype_name) -> std::string {
        return set_amp_prec_dtype(false, dtype_name);
    });
    py::register_exception<TraceError>(m, "TraceError");
}

#undef MGE_PY_INTERFACE

}  // namespace mgb::imperative::python