
tensor.cpp 42 kB

/**
 * \file imperative/python/src/tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/common.h"
#include "megbrain/dtype.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"
#include "megbrain/opr/io.h"
#include "megbrain/plugin/profiler.h"

#include "./common.h"
#include "./grad.h"
#include "./graph_rt.h"
#include "./helper.h"
#include "./module_trace.h"
#include "./numpy_dtypes.h"
#include "./tensor.h"
#include "./transformation.h"

#include <object.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <pybind11/pytypes.h>
#include <pyerrors.h>
#include <range/v3/all.hpp>
#include <string>
#include <unordered_map>

#include "../../src/impl/mgb_cg_impl.h"

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

namespace {
WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
}

interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
PyTypeObject* py_tensor_type = nullptr;
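
// py_apply(op, *tensors): apply an OpDef to one or more inputs.
// The first positional argument is the Op; the remaining arguments are either
// TensorWrapper instances (eager path) or PySymbolVar handles (graph path,
// dispatched through a temporary SymbolTransformation/ScalarTransformation
// context). Returns a tuple holding one wrapped output per op result.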
PyObject* py_apply(
        PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) {
    try {
        // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
        //     PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
        //     return nullptr;
        // }
        if (nargs < 2) {
            PyErr_SetString(
                    PyExc_TypeError,
                    "py_apply expects one Op and at least one tensor "
                    "as argument");
            return nullptr;
        }

        auto* py_op = args[0];

        ++args;
        --nargs;

        auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
        SmallVector<ValueRef, 64> tensors(nargs);

        if (py::isinstance<PySymbolVar>(py::handle(args[0]))) {
            // swap to a special context to reuse scalar handle
            TransformationContext symbol_var_context;
            Transformation::swap_context(symbol_var_context);
            CleanupGuard _{[&] { Transformation::swap_context(symbol_var_context); }};
            auto* graph =
                    py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph();
            std::make_shared<SymbolTransformation>(graph)->register_at(
                    Transformation::top());
            std::make_shared<ScalarTransformation>()->register_at(
                    Transformation::top());
            SmallVector<ValueRef> inputs(nargs);
            for (size_t i = 0; i < nargs; ++i) {
                auto* py_input = py::handle(args[i]).cast<PySymbolVar*>();
                ValueRef input = SymbolValue::make(py_input->m_node);
                if (py_input->is_scalar) {
                    input = ScalarValue::make(input);
                }
                inputs[i] = input;
            }
            auto outputs = imperative::apply(*op, inputs);
            auto ret = pybind11::tuple(outputs.size());
            auto typeobj = py::handle(args[0]).get_type();
            for (size_t i = 0; i < outputs.size(); ++i) {
                bool is_scalar = false;
                if (auto* scalar_value = outputs[i].as<ScalarValue>()) {
                    outputs[i] = scalar_value->value();
                    is_scalar = true;
                }
                auto* node = outputs[i].cast<SymbolValue>().node();
                ret[i] = typeobj(
                        pybind11::cast(node, pybind11::return_value_policy::automatic));
                py::handle(ret[i]).cast<PySymbolVar*>()->is_scalar = is_scalar;
            }
            return ret.release().ptr();
        }

        for (size_t i = 0; i < nargs; ++i) {
            if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                tensors[i] = tw->m_tensor->data();
            } else {
                PyErr_SetString(
                        PyExc_TypeError,
                        ssprintf(
                                "op %s expects type Tensor as inputs, got %s actually",
                                op->make_name().c_str(), Py_TYPE(args[i])->tp_name)
                                .c_str());
                return nullptr;
            }
        }

        auto outputs = imperative::apply(ApplyOp(*op), {tensors.data(), nargs});
        size_t nout = outputs.size();
        auto ret = py::tuple(nout);
        for (size_t i = 0; i < nout; ++i) {
            ret[i] = TensorWrapper::make(py_tensor_type, std::move(outputs[i]));
        }
        return ret.release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
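
// TensorWrapper(args, kwargs): Python-side Tensor constructor. Accepts either
// another TensorWrapper (copy), a DeviceTensorND, or the 5/6-argument form
// (numpy array, dtype, comp_node, is_const[, no_cache], name) used to create
// a tensor from host data.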
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
    if (kwargs && PyDict_Size(kwargs)) {
        throw py::type_error("keyword argument not allowed");
    }
    auto nargs = PyTuple_Size(args);
    auto tup = py::reinterpret_borrow<py::tuple>(args);
    if (nargs == 0) {
        throw py::type_error("too few arguments");
    }
    if (auto* t = try_cast(tup[0].ptr())) {
        if (nargs > 1) {
            throw py::type_error("expect 1 argument");
        }
        m_tensor = t->m_tensor->copy();
    } else {
        if (nargs == 1) {
            auto arg0 = PyTuple_GetItem(args, 0);
            // for DeviceTensorND
            if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
                auto dv = py::handle(arg0).cast<DeviceTensorND>();
                m_tensor = std::make_shared<Tensor>(imperative::apply(
                        CreateTensor(CreateTensor::Common, dv.comp_node(), dv.layout()),
                        DeviceStorage::make(dv.storage()))[0]);
            } else {
                throw py::type_error(
                        "single argument is not tensor, varnode or devicetensor");
            }
        } else {
            py::detail::loader_life_support life_sup;  // FIXME!!! required to cast DType
            if (nargs != 5 && nargs != 6) {
                throw py::type_error("expect 5 or 6 arguments");
            }
            auto data = tup[0].cast<py::array>();
            DType dtype = tup[1].cast<DType>();
            CompNode cn = tup[2].cast<CompNode>();
            bool is_const = tup[3].cast<bool>();
            bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false;
            std::string name;
            if (tup[nargs - 1].ptr() != Py_None)
                name = tup[nargs - 1].cast<std::string>();

            // const op
            {
                CreateTensor::Kind kind = is_const ? CreateTensor::Const
                                        : no_cache ? CreateTensor::Unique
                                                   : CreateTensor::Common;
                HostTensorND ret(cn);
                ret = npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype);
                mgb_assert(
                        ret.layout().is_empty() || ret.layout().is_contiguous(),
                        "host value should be contiguous");
                ValueShape shape;
                for (size_t i = 0; i < data.ndim(); ++i) {
                    shape[shape.ndim++] = data.shape(i);
                }
                m_tensor = std::make_shared<Tensor>(imperative::apply(
                        CreateTensor(kind, cn, ret.dtype(), shape),
                        HostStorage::make(ret.storage()))[0]);
            }

            if (!name.empty()) {
                m_tensor->reset(
                        imperative::apply(RenameValue(name), m_tensor->data())[0]);
                mgb_assert(
                        ((std::string&)*m_tensor->data().name()) == name,
                        "result name incorrect");
            }
            if (data.ndim() == 0) {
                mgb_assert(m_tensor->is_scalar(), "result should be scalar");
            }
        }
    }
}
PyObject* TensorWrapper::module_trace_info() {
    if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) {
        return module_trace_info->inc_ref().ptr();
    } else {
        PyErr_SetString(
                PyExc_AttributeError,
                "Has no attribute named \'_NodeMixin__node\', please "
                "set it first");
        return nullptr;
    }
}

void TensorWrapper::set_module_trace_info(PyObject* obj) {
    module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj);
}

void TensorWrapper::_set_name(PyObject* dest) {
    auto py_dest = py::reinterpret_borrow<py::object>(dest);
    auto name = py_dest.cast<std::string>();
    m_tensor->set_name(name);
}

PyObject* TensorWrapper::_detail() {
    return py::str(m_tensor->data().unwrap().to_string()).release().ptr();
}

void TensorWrapper::_watch() {
    m_tensor->data().watch();
}

PyObject* TensorWrapper::shape() {
    auto shape = m_tensor->shape();
    if (!shape) {
        Py_RETURN_NONE;
    }
    py::tuple ret(shape->ndim);
    for (size_t i = 0; i < shape->ndim; ++i) {
        ret[i] = shape->at(i);
    }
    return ret.release().ptr();
}

PyObject* TensorWrapper::dtype() {
    return py::cast(m_tensor->dtype()).release().ptr();
}

PyObject* TensorWrapper::device() {
    return py::cast(m_tensor->comp_node()).release().ptr();
}

PyObject* TensorWrapper::numpy() {
    auto hv = m_tensor->numpy();
    // if (!hv) {
    //     PyErr_SetString(PyExc_ValueError, "tensor invalid");
    //     return nullptr;
    // }
    auto arr = py::reinterpret_steal<py::array>(
            npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
    if (!arr) {
        PyErr_SetString(PyExc_ValueError, "tensor invalid");
        return nullptr;
    }
    if (hv->shape().is_scalar()) {
        mgb_assert(PyArray_Check(arr.ptr()));
        return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
    }
    return arr.release().ptr();
}

void TensorWrapper::reset(PyObject* tensor) {
    TensorWrapper* t = TensorWrapper::try_cast(tensor);
    if (!t) {
        throw py::type_error("expect Tensor");
    }
    m_tensor->reset(t->m_tensor->data());
}

PyObject* TensorWrapper::detach() {
    auto detached = imperative::apply(DetachGrad(), m_tensor->data())[0];
    return TensorWrapper::make(py_tensor_type, detached).release().ptr();
}

PyObject* TensorWrapper::_dev_tensor() {
    auto dv = m_tensor->data().dev_tensor();
    // TODO: handle scalar
    return py::cast(dv->as_nd(true)).release().ptr();
}

void TensorWrapper::_drop() {
    imperative::apply(DTRCommand(DTRCommand::Drop), m_tensor->data());
}

PyObject* TensorWrapper::isscalar() {
    if (m_tensor->is_scalar()) {
        Py_RETURN_TRUE;
    } else {
        Py_RETURN_FALSE;
    }
}

struct TensorWeakRef {
    std::weak_ptr<Tensor> wptr;

    TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {}

    py::object operator()() {
        if (auto p = wptr.lock()) {
            return TensorWrapper::make(py_tensor_type, p);
        }
        return py::none();
    }
    int _use_cnt() { return wptr.use_count(); }
};
/* ============== convert inputs ============== */

// map numpy.dtype.kind to priority
inline uint8_t category_priority(char c) {
    switch (c) {
        case 'f':
            return 3;  // floating-point
        case 'i':
            return 2;  // signed integer
        case 'u':
            return 2;  // unsigned integer
        case 'b':
            return 1;  // boolean
        default:
            return 0;
    }
}

// Returns the maximum value of the priority of each type in the list `types`.
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
    if (types.size() == 0) {
        return 0;
    } else {
        uint8_t max_p = 0;
        for (auto&& desc : types) {
            max_p = std::max(max_p, category_priority(desc->kind));
        }
        return max_p;
    }
}

// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> used_types;
    for (auto&& desc : types) {
        auto&& v = category_priority(desc->kind);
        if (v == cat) {
            used_types.emplace_back(desc);
        }
    }
    mgb_assert(used_types.size() > 0, "size of used_types is 0");
    PyArray_Descr* res = used_types[0];
    Py_INCREF(res);
    for (size_t i = 1; i < used_types.size(); ++i) {
        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
        Py_DECREF(res);
        res = tmp;
    }
    return res;
}

PyArray_Descr* scalar2dtype(PyObject* arg) {
    // Return value: New reference
    if (PyBool_Check(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_BOOL);
        return descr;
    }
    if (PyLong_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_INT32);
        return descr;
    }
    if (PyFloat_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
        return descr;
    }
    return nullptr;
}
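
// _dtype_promotion: collect the numpy descriptors of all tensor-like
// arguments (TensorWrapper, ndarray, numpy scalar, SymbolVar) and of plain
// Python bool/int/float scalars, optionally unpacking a single tuple/list
// argument. Promotion happens within the highest-priority category; scalars
// only decide the result when their category priority is strictly greater
// than that of the tensors.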
PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> tensors;
    SmallVector<PyArray_Descr*> scalars;

    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }

    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        if (handle == Py_None)
            continue;
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (tw) {
            mgb::DType type = tw->m_tensor->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
            Py_INCREF(descr.get());
            tensors.emplace_back(descr.get());
        } else {
            if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
                auto&& descr = PyArray_DescrFromObject(handle, nullptr);
                tensors.emplace_back(descr);
                continue;
            }

            if (py::isinstance<PySymbolVar>(py::handle(handle))) {
                auto var = py::handle(handle).cast<PySymbolVar*>();
                mgb::DType type = var->m_node->dtype();
                auto&& descr = npy::dtype_mgb2np_descr(type);
                Py_INCREF(descr.get());
                tensors.emplace_back(descr.get());
                continue;
            }

            PyArray_Descr* descr = scalar2dtype(handle);
            if (descr) {
                scalars.emplace_back(descr);
                continue;
            }
        }
    }

    auto max_pri_scalars = max_priority(scalars);
    auto max_pri_tensors = max_priority(tensors);

    if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
        throw py::value_error("invalid input, no dtype available");
    }

    PyArray_Descr* res;
    if (max_pri_scalars > max_pri_tensors) {
        res = promote_types(scalars, max_pri_scalars);
    } else {
        res = promote_types(tensors, max_pri_tensors);
    }
    for (auto* p : tensors) {
        Py_DECREF(p);
    }
    for (auto* p : scalars) {
        Py_DECREF(p);
    }
    Py_XDECREF(tuple);
    return res;
}
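
// _get_device: infer the single computing node shared by all tensor/SymbolVar
// arguments (optionally packed in one tuple/list). Raises if two arguments
// live on different devices, and falls back to the default device when no
// tensor-like argument is present.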
CompNode _get_device(PyObject* const* args, size_t nargs) {
    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }
    bool valid = false;
    CompNode cn;
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        TensorWrapper* tw = TensorWrapper::try_cast(handle);

        bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
        if (tw || is_symvar) {
            if (!valid) {
                cn = tw ? tw->m_tensor->comp_node()
                        : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
                valid = true;
            } else {
                CompNode cn1 = tw ? tw->m_tensor->comp_node()
                                  : py::handle(handle)
                                            .cast<PySymbolVar*>()
                                            ->m_node->comp_node();
                if (cn1 != cn) {
                    throw py::value_error(ssprintf(
                            "ambiguous device: %s vs %s", cn.to_string().c_str(),
                            cn1.to_string().c_str()));
                }
            }
        }
    }
    // release the temporary tuple before returning on either path
    Py_XDECREF(tuple);
    if (!valid) {
        return CompNode::load(get_default_device());
    }
    return cn;
}
// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        PyArray_Descr* res = _dtype_promotion(args, nargs);
        return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        CompNode cn = _get_device(args, nargs);
        return py::cast(cn).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
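
// Module-level entry points are exported with METH_FASTCALL when the Python
// headers define it; otherwise each function is wrapped so that the
// METH_VARARGS tuple is unpacked into the (args, nargs) fastcall convention.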
#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
#else
#define WRAP_FUNC_PY35(FUNC)                                \
    PyObject* py35_##FUNC(PyObject* self, PyObject* args) { \
        auto* arr = &PyTuple_GET_ITEM(args, 0);             \
        auto size = PyTuple_GET_SIZE(args);                 \
        return FUNC(self, arr, size);                       \
    }
WRAP_FUNC_PY35(py_apply);
WRAP_FUNC_PY35(dtype_promotion);
WRAP_FUNC_PY35(get_device);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif
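
// init_tensor: module initialization entry point. Creates the interpreter
// channel, installs the default transformations (Eval, Scalar), registers the
// Tensor, TensorWeakRef, SymbolVar, GradKey and Trace bindings, and exposes
// the module-level helper functions on `m`.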
void init_tensor(py::module m) {
    imperative::Tensor::static_initialize();
    static auto& transformations = TransformationManager::get_instance();

    using Segment = TransformationManager::Segment;

    auto* channel = interpreter::Interpreter::inst().create_channel().release();
    interpreter_for_py = channel;
    transformations.register_at<Segment::Eval>(
            std::make_shared<InterpreterTransformation>(
                    std::unique_ptr<interpreter::Interpreter::Channel>(channel)));
    transformations.register_at<Segment::Scalar>(
            std::make_shared<ScalarTransformation>());

    static py::exception<interpreter::AsyncError> py_async_error(
            m, "AsyncError", PyExc_RuntimeError);
    py::register_exception_translator([](std::exception_ptr p) {
        try {
            if (p)
                std::rethrow_exception(p);
        } catch (const interpreter::AsyncError& e) {
            pyext17::pybind11_translate_exception(e.nested_ptr());
            if (PyErr_Occurred()) {
                PyObject *exc, *val, *tb;
                PyErr_Fetch(&exc, &val, &tb);
                PyErr_NormalizeException(&exc, &val, &tb);
                if (tb) {
                    PyException_SetTraceback(val, tb);
                }
                auto val2 = py_async_error.py::object::operator()(
                        "An async error is reported. See above for the actual cause."
                        " Hint: This is where it is reported, not where it happened."
                        " You may call `megengine.core.set_option('async_level', 0)` "
                        "to get better error reporting.");
                PyException_SetCause(
                        val2.ptr(), val);  // PyException_SetCause steals reference
                Py_XDECREF(exc);
                Py_XDECREF(tb);
                PyErr_Restore(
                        py_async_error.inc_ref().ptr(), val2.release().ptr(), nullptr);
            } else {
                py_async_error("Unknown async error");
            }
        }
    });
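
    // Python-visible Tensor type: thin wrapper methods over TensorWrapper,
    // plus a weak-reference helper and the SymbolVar proxy used when working
    // against a computing graph.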
    auto* tensor_type =
            TensorWrapper::wrap_t::type()
                    .def<&TensorWrapper::numpy>("numpy")
                    .def_getset<&TensorWrapper::shape>("shape")
                    .def_getset<&TensorWrapper::dtype>("dtype")
                    .def_getset<&TensorWrapper::device>("device")
                    .def<&TensorWrapper::reset>("_reset")
                    .def<&TensorWrapper::isscalar>("_isscalar")
                    .def<&TensorWrapper::detach>("detach")
                    // TODO: remove this
                    .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
                    .def<&TensorWrapper::_drop>("_drop")
                    .def<&TensorWrapper::_use_cnt>("_use_cnt")
                    .def<&TensorWrapper::_detail>("_detail")
                    .def<&TensorWrapper::_set_name>("_set_name")
                    .def<&TensorWrapper::_watch>("_watch")
                    .def_getset<
                            &TensorWrapper::module_trace_info,
                            &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
                    .finalize();
    if (!tensor_type)
        throw py::error_already_set();
    py::setattr(m, "Tensor", tensor_type);

    py::class_<TensorWeakRef>(m, "TensorWeakRef")
            .def(py::init<const TensorWrapper&>())
            .def("__call__", &TensorWeakRef::operator())
            .def("_use_cnt", &TensorWeakRef::_use_cnt);

    py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
            .def_property_readonly(
                    "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
            .def_property(
                    "var", [](PySymbolVar* v) { return v->m_node; },
                    [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
            .def_property_readonly(
                    "device", [](PySymbolVar* v) { return v->m_node->comp_node(); })
            .def_property_readonly(
                    "graph", [](PySymbolVar* v) { return v->m_node->owner_graph(); })
            .def_property_readonly(
                    "shape",
                    [](PySymbolVar* v) -> const TensorShape* {
                        auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
                        return mgr.infer_shape_fallible(v->m_node);
                    })
            .def("numpy",
                 [](PySymbolVar* v) {
                     auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
                     auto&& type = mgr.get_infer_type(v->m_node);
                     using InferType = cg::static_infer::InferType;
                     if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
                         throw py::value_error("value invalid!");
                     }
                     auto* val = mgr.infer_value_fallible(v->m_node);
                     if (!val) {
                         throw py::value_error("value invalid!");
                     }
                     auto np_val = py::cast(*val).attr("numpy")();
                     return np_val;
                 })
            .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
            .def(py::init([](cg::VarNode* node) {
                     return std::make_shared<PySymbolVar>(node);
                 }),
                 py::arg() = nullptr);
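
    // Register the module-level functions (apply, dtype_promotion, get_device)
    // and the interpreter-channel option/profiling/synchronization helpers.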
    static PyMethodDef method_defs[] = {
            MGE_PY_INTERFACE(apply, py_apply),
            MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
            MGE_PY_INTERFACE(get_device, get_device),
            {nullptr, nullptr, 0, nullptr}};
    for (auto&& def : method_defs) {
        if (def.ml_meth != nullptr) {
            auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
            if (!func)
                throw py::error_already_set();
            py::setattr(m, def.ml_name, func);
        }
    }

    static constexpr auto sync_py_task_q = [] {
        py::gil_scoped_release _;
        py_task_q.wait_all_task_finish();
    };

    m.def("set_option", [channel](std::string name, size_t value) {
        channel->set_option(name, value);
    });
    m.def("clear_candidates", [channel]() { channel->clear_candidates(); });
    m.def("get_option",
          [channel](std::string name) { return channel->get_option(name); });
    m.def("_set_drop_flag",
          [channel](bool flag) { channel->set_option("enable_drop", flag); });
    m.def("config_async_level", [channel](int level) {
        mgb_assert(level >= 0 and level <= 2, "async_level should be 0, 1 or 2");
        channel->set_option("async_level", level);
    });
    m.def("get_async_level",
          [channel]() { return channel->get_option("async_level"); });
    m.def("set_buffer_length", [channel](int length) {
        mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)");
        channel->set_option("buffer_length", length);
    });
    m.def("push_scope", [channel](std::string name) {
        Transformation::push_scope(name);
        channel->push_scope(name);
    });
    m.def("pop_scope", [channel](std::string name) {
        channel->pop_scope(name);
        Transformation::pop_scope(name);
    });
    m.def("start_profile", [channel](imperative::Profiler::options_t options) {
        channel->sync();
        imperative::Profiler::load_options(std::move(options));
        imperative::Profiler::start_profile();
        channel->start_profile();
    });
    m.def("stop_profile", [channel]() -> std::function<void(std::string, std::string)> {
        channel->stop_profile();
        channel->sync();
        imperative::Profiler::stop_profile();
        auto results = std::make_shared<imperative::Profiler::bundle_t>(
                imperative::Profiler::collect());
        return [results = results](std::string basename, std::string format) mutable {
            imperative::Profiler::dump_profile(basename, format, std::move(*results));
            results = nullptr;
        };
    });
    m.def("sync", [channel]() {
        if (channel->check_available()) {
            channel->sync();
        }
        sync_py_task_q();
    });
    m.def("full_sync", [channel]() {
        if (channel->check_available()) {
            channel->sync();
        }
        CompNode::sync_all();
        CompNode::foreach ([](CompNode cn) {
            auto err = cn.check_async_error();
            mgb_assert(!err, "%s", err->what());
        });
        sync_py_task_q();
    });
    m.def("close", [channel]() {
        channel->close();
        sync_py_task_q();
    });
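
    // GradKey: Python handle used to attach gradients to tensors and drive
    // backward(); also expose the hook that overrides the Python Tensor type
    // used when wrapping C++ outputs.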
    py::handle grad_key_type =
            GradKeyWrapper::wrap_t::type()
                    .def<&GradKeyWrapper::attach>("attach")
                    .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
                    .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>(
                            "name")
                    .def<&GradKeyWrapper::enter>("enter")
                    .def<&GradKeyWrapper::exit>("exit")
                    .def<&GradKeyWrapper::suppress>("suppress")
                    .def<&GradKeyWrapper::resume>("resume")
                    .finalize();
    if (!grad_key_type)
        throw py::error_already_set();
    py::setattr(m, "GradKey", grad_key_type);
    m.def("backward", &GradKeyWrapper::backward);
    m.def("get_backward_closure", &GradKeyWrapper::get_backward_closure);

    m.def("set_py_tensor_type", [](py::object type_obj) {
        py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
    });
    /**
     * \brief trace proxy
     *
     */
    struct Trace {
        bool symbolic = false;
        bool no_exec = false;
        bool capture_as_const = false;
        bool profile = false;
        bool record_input_shapes = false;
        py::function options_visitor;
        std::shared_ptr<TracingTransformation> tracing;
        std::shared_ptr<CompiledTransformation> compiled;
        std::shared_ptr<LazyEvalTransformation> lazy_eval;
        std::pair<size_t, std::shared_ptr<GraphProfiler>> profiler;
        std::optional<TraceResult> trace_result;
        std::function<bool(py::object, py::object)> array_comparator;

        bool compare_value(ValueRef lhs, ValueRef rhs) {
            if (!lhs.shape()->eq(*rhs.shape())) {
                return false;
            }
            HostTensorND lvalue = lhs.numpy()->as_nd(true);
            HostTensorND rvalue = rhs.numpy()->as_nd(true);
            auto larr = py::reinterpret_steal<py::array>(
                    npy::ndarray_from_tensor(lvalue, npy::ShareType::TRY_SHARE));
            auto rarr = py::reinterpret_steal<py::array>(
                    npy::ndarray_from_tensor(rvalue, npy::ShareType::TRY_SHARE));
            return array_comparator(larr, rarr);
        }

        void enter() {
            auto& self = *this;
            if (!self.trace_result) {  // untraced
                self.tracing = std::make_shared<TracingTransformation>(
                        self.capture_as_const, self.record_input_shapes);
                if (self.symbolic) {
                    self.lazy_eval =
                            std::make_shared<LazyEvalTransformation>(self.no_exec);
                    self.options_visitor(py::cast(&self.lazy_eval->options()));
                }
            } else if (!self.compiled) {  // traced but not compiled
                using namespace std::placeholders;
                self.compiled = std::make_shared<CompiledTransformation>(
                        *self.trace_result, self.record_input_shapes);
                self.compiled->set_value_comparator(
                        std::bind(&Trace::compare_value, this, _1, _2));
                self.options_visitor(py::cast(&self.compiled->options()));
                self.compiled->compile();
            }
            // register transformations
            if (self.compiled) {
                if (self.profile) {
                    auto& current_graph = self.compiled->graph();
                    if (self.profiler.first != self.compiled->graph().id()) {
                        // graph changed
                        self.profiler = std::make_pair(
                                current_graph.id(),
                                std::make_shared<GraphProfiler>(&current_graph));
                    }
                }
                transformations.register_at<Segment::Trace>(self.compiled);
                // start executing immediately because InputCallback depends on it
                self.compiled->execute();
            } else if (self.tracing) {
                transformations.register_at<Segment::Trace>(self.tracing);
                if (self.lazy_eval) {
                    transformations.register_at<Segment::Eval>(self.lazy_eval);
                }
            } else {
                mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
            }
        }

        void exit() {
            auto& self = *this;
            if (self.tracing) {
                transformations.unregister<Segment::Trace>(self.tracing);
                self.trace_result = self.tracing->get_result();
                self.tracing.reset();
                if (self.lazy_eval) {
                    auto lazy_eval = std::move(self.lazy_eval);
                    transformations.unregister<Segment::Eval>(lazy_eval);
                    lazy_eval->check_exception();
                }
            } else if (self.compiled) {
                transformations.unregister<Segment::Trace>(self.compiled);
                self.compiled->wait();
            } else {
                mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
            }
        }

        VarNodeArray dump(
                std::shared_ptr<ComputingGraph> graph,
                std::vector<std::tuple<std::string, std::string, TensorShape>> inputs,
                std::vector<std::pair<std::string, std::string>> outputs,
                bool prefer_input_names) {
            auto& self = *this;
            mgb_assert(self.trace_result);
            // mark is like "arg_0", "kwarg_xxx", "output_0" ...
            std::unordered_map<std::string, size_t> mark2var;
            for (size_t i = 0; i < self.trace_result->vars.size(); ++i) {
                auto& name = self.trace_result->vars[i].mark;
                if (!name.empty()) {
                    mark2var[name] = i;
                }
            }
            std::vector<std::tuple<size_t, std::string, TensorShape>> input_vars;
            std::vector<std::pair<size_t, std::string>> output_vars;
            for (auto&& [input_mark, input_name, input_shape] : inputs) {
                mgb_assert(input_shape.ndim, "input shape invalid");
                input_vars.push_back(
                        {mark2var.at(input_mark), input_name, input_shape});
            }
            for (auto&& [output_name, repr] : outputs) {
                output_vars.push_back({mark2var.at(output_name), repr});
            }
            self.options_visitor(py::cast(&graph->options()));
            auto vars = self.trace_result->dump(
                    *graph, input_vars, output_vars, prefer_input_names);
            return vars;
        }
    };
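
    // Python binding of the trace proxy: exposes the tracing knobs (symbolic,
    // capture_as_const, no_exec, profile, ...) plus enter/exit/dump and the
    // excluded-region helpers that temporarily detach the trace transformation.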
    py::class_<Trace>(m, "Trace")
            .def(py::init<>())
            .def_readwrite("record_input_shapes", &Trace::record_input_shapes)
            .def_readwrite("array_comparator", &Trace::array_comparator)
            .def_readwrite("profile", &Trace::profile)
            .def_property_readonly(
                    "options",
                    [](Trace& self) {
                        if (self.compiled) {
                            return &self.compiled->options();
                        } else {
                            return (ComputingGraph::Options*)nullptr;
                        }
                    })
            .def("get_profile",
                 [](Trace& self) -> py::object {
                     if (self.profiler.second && self.compiled) {
                         auto json = self.profiler.second->to_json_full(
                                 self.compiled->graph().current_comp_seq());
                         return py::str(json->to_string());
                     } else {
                         return py::none();
                     }
                 })
            .def_readwrite("symbolic", &Trace::symbolic)
            .def_readwrite("capture_as_const", &Trace::capture_as_const)
            .def_readwrite("no_exec", &Trace::no_exec)
            .def_readwrite("options_visitor", &Trace::options_visitor)
            .def("enter", &Trace::enter)
            .def("exit", &Trace::exit)
            .def("dump", &Trace::dump)
            .def("begin_excluded_region",
                 [](Trace& self) {
                     mgb_assert(bool(self.tracing) ^ bool(self.compiled));
                     if (self.tracing) {
                         transformations.unregister<Segment::Trace>(self.tracing);
                     } else if (self.compiled) {
                         transformations.unregister<Segment::Trace>(self.compiled);
                     }
                 })
            .def("end_excluded_region", [](Trace& self) {
                mgb_assert(bool(self.tracing) ^ bool(self.compiled));
                if (self.tracing) {
                    transformations.register_at<Segment::Trace>(self.tracing);
                } else if (self.compiled) {
                    transformations.register_at<Segment::Trace>(self.compiled);
                }
            });
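
    // Gradient helpers: name_tensor marks a value for tracing; is_grad_attached
    // and get_grad_key query the GradKey currently attached to a set of
    // tensors; set_grad installs a custom backward function over inputs and
    // outputs and returns the grad-wrapped outputs.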
  885. m.def("name_tensor", [](std::string name, py::object tensor) {
  886. auto* tw = TensorWrapper::try_cast(tensor.ptr());
  887. auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
  888. tw->m_tensor->reset(output);
  889. });
  890. m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool {
  891. SmallVector<ValueRef> values;
  892. for (auto&& tensor : tensors) {
  893. values.push_back(tensor.cast<TensorWrapper>().m_tensor->data());
  894. }
  895. auto outputs = imperative::apply(GetGradKey(), values);
  896. if (outputs[0].is<GradKeyValue>()) {
  897. return true;
  898. } else {
  899. return false;
  900. }
  901. });
  902. m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object {
  903. SmallVector<ValueRef> values;
  904. for (auto&& tensor : tensors) {
  905. values.push_back(tensor.cast<TensorWrapper>().m_tensor->data());
  906. }
  907. auto outputs = imperative::apply(GetGradKey(), values);
  908. if (auto* grad_key_val = outputs[0].as<GradKeyValue>()) {
  909. return py::reinterpret_borrow<py::object>(
  910. GradKeyWrapper::wrap_t::pycast(GradKeyWrapper::get(*grad_key_val)));
  911. } else {
  912. return py::none();
  913. }
  914. });
  915. m.def("set_grad", [](py::object py_key, py::function backward_fn,
  916. std::vector<py::object> inputs,
  917. std::vector<py::object> outputs) {
  918. mgb_assert(GradKeyWrapper::wrap_t::type().isinstance(py_key.ptr()));
  919. auto* key = reinterpret_cast<GradKeyWrapper::wrap_t*>(py_key.ptr())->inst();
  920. GenericFunction generic_backward_fn =
  921. [backward_fn](Span<ValueRef> output_grads) -> std::vector<ValueRef> {
  922. py::list output_grad_tws;
  923. for (auto&& output_grad : output_grads) {
  924. if (output_grad) {
  925. output_grad_tws.append(
  926. TensorWrapper::make(py_tensor_type, output_grad));
  927. } else {
  928. output_grad_tws.append(py::none());
  929. }
  930. }
  931. py::tuple input_grad_tws = backward_fn(*output_grad_tws);
  932. std::vector<ValueRef> input_grads;
  933. for (auto&& input_grad_tw : input_grad_tws) {
  934. if (!input_grad_tw.is_none()) {
  935. input_grads.push_back(
  936. py::cast<TensorWrapper>(input_grad_tw).m_tensor->data());
  937. } else {
  938. input_grads.push_back({});
  939. }
  940. }
  941. return input_grads;
  942. };
  943. SmallVector<ValueRef> values;
  944. for (auto&& input : inputs) {
  945. values.push_back(input.cast<TensorWrapper>().m_tensor->data());
  946. }
  947. for (auto&& output : outputs) {
  948. values.push_back(output.cast<TensorWrapper>().m_tensor->data());
  949. }
  950. auto wrapped_output_values = imperative::apply(
  951. SetGrad(key->m_key, generic_backward_fn, inputs.size()), values);
  952. std::vector<py::object> wrapped_outputs;
  953. mgb_assert(wrapped_output_values.size() == outputs.size());
  954. for (auto&& output_value : wrapped_output_values) {
  955. wrapped_outputs.push_back(
  956. TensorWrapper::make(py_tensor_type, output_value));
  957. }
  958. return wrapped_outputs;
  959. });
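
    // Module tracing: a reference-counted switch that installs a single
    // ModuleTraceTransformation (built from the Python hook registered via
    // set_module_trace_hook) while any module-level tracing is active, plus
    // helpers to record and dump the values observed during tracing.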
    static py::function module_trace_hook;
    static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation;
    static int module_tracing = 0;

    m.def("set_module_tracing", [=] {
        if (!module_trace_transformation) {
            mgb_assert(module_trace_hook);
            module_trace_transformation =
                    std::make_shared<ModuleTraceTransformation>(module_trace_hook);
        }
        if (++module_tracing == 1) {
            transformations.register_at<TransformationManager::ModuleTrace>(
                    module_trace_transformation);
        }
    });

    m.def("unset_module_tracing", [=] {
        if (--module_tracing == 0) {
            transformations.unregister<TransformationManager::ModuleTrace>(
                    module_trace_transformation);
        }
    });

    m.def("is_tracing_module", [=] { return module_tracing > 0; });

    m.def("set_module_trace_hook",
          [](py::function function) { module_trace_hook = function; });

    m.def("begin_record_values", [] { Value::begin_record_values(); });

    m.def("end_record_values", [] {
        std::vector<std::pair<size_t, std::string>> reprs;
        auto values = Value::end_record_values();
        for (auto&& value : values) {
            reprs.push_back({value.id(), value.to_string()});
        }
        return reprs;
    });

    py::register_exception<TraceError>(m, "TraceError");
}

#undef MGE_PY_INTERFACE

}  // namespace mgb::imperative::python