
tensor.cpp (55 kB)

  1. /**
  2. * \file imperative/python/src/tensor.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/common.h"
  12. #include "megbrain/dtype.h"
  13. #include "megbrain/imperative/ops/autogen.h"
  14. #include "megbrain/imperative/ops/backward_graph.h"
  15. #include "megbrain/imperative/ops/utility.h"
  16. #include "megbrain/imperative/profiler.h"
  17. #include "megbrain/imperative/transformations/dim_expansion.h"
  18. #include "megbrain/imperative/transformations/dtype_promote.h"
  19. #include "megbrain/imperative/transformations/eval.h"
  20. #include "megbrain/imperative/transformations/lazy.h"
  21. #include "megbrain/imperative/transformations/scalar.h"
  22. #include "megbrain/imperative/transformations/symbol.h"
  23. #include "megbrain/imperative/transformations/trace.h"
  24. #include "megbrain/imperative/utils/map.h"
  25. #include "megbrain/opr/io.h"
  26. #include "megbrain/plugin/profiler.h"
  27. #include "megbrain/utils/stats.h"
  28. #include "megdnn/algorithm_cache.h"
  29. #include "./common.h"
  30. #include "./grad.h"
  31. #include "./graph_rt.h"
  32. #include "./helper.h"
  33. #include "./module_trace.h"
  34. #include "./numpy_dtypes.h"
  35. #include "./tensor.h"
  36. #include "./tensor_utils.h"
  37. #include "./transformation.h"
  38. #include <object.h>
  39. #include <pybind11/numpy.h>
  40. #include <pybind11/operators.h>
  41. #include <pybind11/pytypes.h>
  42. #include <pyerrors.h>
  43. #include <iterator>
  44. #include <range/v3/all.hpp>
  45. #include <string>
  46. #include <unordered_map>
  47. #include "../../src/impl/mgb_cg_impl.h"
  48. namespace py = pybind11;
  49. namespace views = ranges::views;
  50. namespace mgb::imperative::python {
  51. namespace {
  52. WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
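// SymbolVarContext temporarily swaps in a transformation stack suited for
// graph-mode (SymbolVar) execution: init() registers the symbol, scalar,
// dtype-promote and dim-expansion transformations, symvar2val/val2symvar
// convert between PySymbolVar handles and ValueRefs, and the previous
// transformation context is restored in the destructor.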
  53. struct SymbolVarContext {
  54. TransformationContext context;
  55. std::shared_ptr<SymbolTransformation> symbol_tsf;
  56. std::shared_ptr<ScalarTransformation> scalar_tsf;
  57. std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf;
  58. std::shared_ptr<DimExpansionTransformation> dim_expansion_tsf;
  59. SymbolVarContext(cg::ComputingGraph* graph) {
  60. symbol_tsf = std::make_shared<SymbolTransformation>(graph);
  61. scalar_tsf = std::make_shared<ScalarTransformation>();
  62. dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>();
  63. dim_expansion_tsf = std::make_shared<DimExpansionTransformation>();
  64. Transformation::swap_context(context);
  65. }
  66. void init() {
  67. symbol_tsf->register_at(Transformation::top());
  68. scalar_tsf->register_at(Transformation::top());
  69. dtype_promote_tsf->register_at(Transformation::top());
  70. dim_expansion_tsf->register_at(Transformation::top());
  71. }
  72. ValueRef symvar2val(py::handle py_symbol_var) {
  73. auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
  74. ValueRef value = symbol_tsf->value_type().make(symbol_var->m_node);
  75. if (symbol_var->is_scalar) {
  76. value = scalar_tsf->value_type().make(value);
  77. }
  78. return value;
  79. }
  80. py::object val2symvar(py::handle typeobj, ValueRef value) {
  81. bool is_scalar = false;
  82. if (auto* scalar_value = value.as(scalar_tsf->value_type())) {
  83. value = scalar_value->value();
  84. is_scalar = true;
  85. }
  86. auto* node = value.cast(symbol_tsf->value_type()).node();
  87. auto py_symbol_var =
  88. typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic));
  89. py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar;
  90. return py_symbol_var;
  91. }
  92. ~SymbolVarContext() { Transformation::swap_context(context); }
  93. };
  94. } // namespace
  95. interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
  96. PyTypeObject* py_tensor_type = nullptr;
  97. pybind11::handle py_device_type = nullptr;
  98. PyObject* cpp_use_symbolic_shape;
  99. #define REGISTER_APPLY_FUNC(mode) \
  100. void set_##mode(py::object pyf) { mode = pyf.ptr(); }
  101. REGISTER_APPLY_FUNC(cpp_use_symbolic_shape)
  102. #undef REGISTER_APPLY_FUNC
  103. PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs);
  104. CompNode _get_device(PyObject* const* args, size_t nargs);
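// py_apply: exposed to Python as apply(op, *tensors). It takes an OpDef
// followed by Tensor/SymbolVar arguments, routes them through
// imperative::apply, and returns the outputs as a Python tuple. SymbolVar
// inputs switch execution into the owning ComputingGraph via SymbolVarContext;
// plain Python scalars/sequences are promoted to tensors only when
// DTypePromoteCfg::convert_input_enabled is set and the op is an Elemwise.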
  105. PyObject* py_apply(
  106. PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) {
  107. try {
  108. // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
  109. // PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
  110. // return nullptr;
  111. // }
  112. if (nargs < 2) {
  113. PyErr_SetString(
  114. PyExc_TypeError,
  115. "py_apply expects one Op and at least one tensor "
  116. "as arguments");
  117. return nullptr;
  118. }
  119. auto* py_op = args[0];
  120. ++args;
  121. --nargs;
  122. auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
  123. SmallVector<ValueRef, 8> tensors(nargs);
  124. SmallVector<bool, 8> is_symbol_var(nargs, false);
  125. ComputingGraph* cg = nullptr;
  126. for (size_t i = 0; i < nargs; ++i) {
  127. if ((!TensorWrapper::try_cast(args[i])) &&
  128. py::isinstance<PySymbolVar>(py::handle(args[i]))) {
  129. is_symbol_var[i] = true;
  130. ComputingGraph* cur_cg =
  131. py::handle(args[i]).cast<PySymbolVar*>()->m_node->owner_graph();
  132. if (cg == nullptr) {
  133. cg = cur_cg;
  134. } else {
  135. mgb_assert(cg == cur_cg);
  136. }
  137. }
  138. }
  139. mgb::CompNode target_cn;
  140. mgb::DType target_dtype;
  141. auto convert_pyinput_to_tensor = [&](size_t i) -> ValueRef {
  142. if (!target_dtype.valid()) {
  143. target_dtype = npy::dtype_np2mgb_descr(_dtype_promotion(args, nargs));
  144. target_cn = _get_device(args, nargs);
  145. }
  146. HostTensorND ht(target_cn);
  147. ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
  148. if (PyArray_Check(args[i]) || PyList_Check(args[i])) { // non-scalar
  149. // py_tuple is not allowed here because of tracing
  150. return imperative::apply(
  151. CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
  152. HostStorage::make(ht.storage()))[0];
  153. } else { // scalar
  154. return imperative::apply(
  155. CreateTensor(CreateTensor::Const, target_cn, target_dtype, {}),
  156. HostStorage::make(ht.storage()))[0];
  157. }
  158. };
  159. if (cg != nullptr) {
  160. // swap to a special context to reuse scalar handle
  161. size_t symbol_var_idx = 8;
  162. SymbolVarContext context(cg);
  163. context.init();
  164. for (size_t i = 0; i < nargs; ++i) {
  165. if (is_symbol_var[i]) {
  166. symbol_var_idx = i;
  167. tensors[i] = context.symvar2val(args[i]);
  168. } else if (
  169. DTypePromoteCfg::convert_input_enabled &&
  170. op->same_type<Elemwise>()) {
  171. tensors[i] = convert_pyinput_to_tensor(i);
  172. } else {
  173. PyErr_SetString(
  174. PyExc_TypeError, "py_apply expects tensor as inputs");
  175. return nullptr;
  176. }
  177. }
  178. auto outputs = imperative::apply(*op, tensors);
  179. auto ret = pybind11::tuple(outputs.size());
  180. auto typeobj = py::handle(args[symbol_var_idx]).get_type();
  181. for (size_t i = 0; i < outputs.size(); ++i) {
  182. ret[i] = context.val2symvar(typeobj, outputs[i]);
  183. }
  184. return ret.release().ptr();
  185. }
  186. for (size_t i = 0; i < nargs; ++i) {
  187. if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
  188. tensors[i] = tw->m_tensor->data();
  189. } else if (
  190. DTypePromoteCfg::convert_input_enabled &&
  191. op->same_type<Elemwise>()) {
  192. tensors[i] = convert_pyinput_to_tensor(i);
  193. } else {
  194. PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
  195. return nullptr;
  196. }
  197. }
  198. auto outputs = [&] { return imperative::apply(*op, tensors); }();
  199. size_t nout = outputs.size();
  200. auto ret = py::tuple(nout);
  201. for (size_t i = 0; i < nout; ++i) {
  202. ret[i] = TensorWrapper::make(py_tensor_type, std::move(outputs[i]));
  203. }
  204. return ret.release().ptr();
  205. }
  206. PYEXT17_TRANSLATE_EXC_RET(nullptr)
  207. }
  208. namespace {
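// py_type<T>(): maps a pybind11 wrapper type (py::int_, py::float_, py::tuple,
// py::list) to the corresponding CPython type object, so callers can do exact
// type checks via obj.get_type().is(py_type<T>()) instead of isinstance checks.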
  209. template <typename T>
  210. py::handle py_type() {
  211. if constexpr (std::is_same_v<T, py::int_>) {
  212. return (PyObject*)&PyLong_Type;
  213. } else if constexpr (std::is_same_v<T, py::float_>) {
  214. return (PyObject*)&PyFloat_Type;
  215. } else if constexpr (std::is_same_v<T, py::tuple>) {
  216. return (PyObject*)&PyTuple_Type;
  217. } else if constexpr (std::is_same_v<T, py::list>) {
  218. return (PyObject*)&PyList_Type;
  219. } else {
  220. static_assert(std::is_same_v<T, T>);
  221. }
  222. }
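// scalar2storage(val, cn, dtype): packs a single host scalar into a freshly
// allocated HostStorage of exactly dtype.size() bytes, using DTypeScalar to
// perform the dtype conversion.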
  223. template <typename T>
  224. auto scalar2storage(T val, CompNode cn, DType dtype) {
  225. using max_ctype_t = DTypeScalar::max_ctype;
  226. DTypeScalar scalar(dtype);
  227. scalar.set_retain_dtype(val);
  228. HostTensorStorage storage(cn);
  229. auto* raw_ptr = reinterpret_cast<dt_byte*>(new max_ctype_t());
  230. std::shared_ptr<dt_byte> raw_storage = {
  231. raw_ptr, [](dt_byte* ptr) { delete reinterpret_cast<max_ctype_t*>(ptr); }};
  232. storage.only_reset_raw_storage(cn, dtype.size(), raw_storage, 0);
  233. std::memcpy(storage.ptr(), scalar.storage(), dtype.size());
  234. return HostStorage::make(std::move(storage));
  235. }
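// vec2storage<ctype>(vec, cn, dtype): packs up to MEGDNN_MAX_NDIM DTypeScalar
// items into a contiguous HostStorage of ctype elements; used below by
// pyseq2hval for small tuples/lists.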
  236. template <typename ctype>
  237. auto vec2storage(Span<DTypeScalar> vec, CompNode cn, DType dtype) {
  238. mgb_assert(vec.size() <= MEGDNN_MAX_NDIM);
  239. // TODO: use storage cache and modify ConstTensorCache to return (Host, Device)
  240. auto* raw_ptr = new ctype[MEGDNN_MAX_NDIM];
  241. for (size_t i = 0; i < vec.size(); ++i) {
  242. raw_ptr[i] = vec[i].get_cast<ctype>();
  243. }
  244. mgb_assert(sizeof(ctype) == dtype.size());
  245. std::shared_ptr<dt_byte> raw_storage = {
  246. reinterpret_cast<dt_byte*>(raw_ptr),
  247. [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
  248. HostTensorStorage storage(cn);
  249. storage.only_reset_raw_storage(cn, sizeof(ctype) * vec.size(), raw_storage, 0);
  250. return HostStorage::make(std::move(storage));
  251. }
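// HostTensorArgs bundles the (shape, dtype, storage) triple produced by the
// py*2hval converters below; as_tensor_nd() re-wraps it as a HostTensorND on
// the default CPU comp node.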
  252. struct HostTensorArgs {
  253. ValueShape shape;
  254. DType dtype;
  255. HostStorage::ref_t storage;
  256. HostTensorND as_tensor_nd() const {
  257. HostTensorND ret(CompNode::default_cpu(), shape.as_tensor_shape(), dtype);
  258. ret.only_reset_raw_storage(*storage);
  259. return ret;
  260. }
  261. };
  262. template <typename seq_type, typename ctype>
  263. bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
  264. auto size = obj.size();
  265. if (size > MEGDNN_MAX_NDIM) {
  266. return false;
  267. }
  268. ctype items[size];
  269. for (size_t i = 0; i < size; ++i) {
  270. py::handle item = obj[i];
  271. if (item.get_type().is(py_type<py::int_>())) {
  272. items[i] = (ctype)(dt_int32)item.template cast<py::int_>();
  273. } else if (item.get_type().is(py_type<py::float_>())) {
  274. items[i] = (ctype)(dt_float32)item.template cast<py::float_>();
  275. } else {
  276. return false;
  277. }
  278. }
  279. mgb_assert(sizeof(ctype) == dtype.size());
  280. auto* raw_ptr = new ctype[size];
  281. std::shared_ptr<dt_byte> raw_storage = {
  282. reinterpret_cast<dt_byte*>(raw_ptr),
  283. [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
  284. HostTensorStorage storage(cn);
  285. storage.only_reset_raw_storage(cn, sizeof(ctype) * size, raw_storage, 0);
  286. std::memcpy(storage.ptr(), items, sizeof(ctype) * size);
  287. ret.dtype = dtype;
  288. ret.shape = {size};
  289. ret.storage = HostStorage::make(std::move(storage));
  290. return true;
  291. }
  292. template <typename seq_type>
  293. bool pyseq2hval(seq_type obj, CompNode cn, HostTensorArgs& ret) {
  294. auto size = obj.size();
  295. if (size > MEGDNN_MAX_NDIM) {
  296. return false;
  297. }
  298. DTypeScalar items[size];
  299. DType dtype;
  300. for (size_t i = 0; i < size; ++i) {
  301. auto&& item = obj[i];
  302. if (item.get_type().is(py_type<py::int_>())) {
  303. items[i] = (dt_int32)item.template cast<py::int_>();
  304. if (!dtype.valid()) {
  305. dtype = dtype::Int32();
  306. } else if (dtype != dtype::Int32() && dtype != dtype::Float32()) {
  307. return false;
  308. }
  309. } else if (item.get_type().is(py_type<py::float_>())) {
  310. items[i] = (dt_float32)item.template cast<py::float_>();
  311. if (!dtype.valid()) {
  312. dtype = dtype::Float32();
  313. } else if (dtype == dtype::Int32()) {
  314. dtype = dtype::Float32();
  315. } else if (dtype != dtype::Float32()) {
  316. return false;
  317. }
  318. } else {
  319. return false;
  320. }
  321. }
  322. if (!dtype.valid()) {
  323. dtype = dtype::Float32();
  324. }
  325. ret.dtype = dtype;
  326. ret.shape = {size};
  327. if (dtype == dtype::Int32()) {
  328. ret.storage = vec2storage<dt_int32>({items, size}, cn, dtype);
  329. } else if (dtype == dtype::Float32()) {
  330. ret.storage = vec2storage<dt_float32>({items, size}, cn, dtype);
  331. } else {
  332. mgb_assert(false);
  333. }
  334. return true;
  335. }
  336. template <typename seq_type>
  337. bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
  338. if (dtype == dtype::Int32()) {
  339. return pyseq2hval<seq_type, dt_int32>(obj, cn, dtype, ret);
  340. } else if (dtype == dtype::Float32()) {
  341. return pyseq2hval<seq_type, dt_float32>(obj, cn, dtype, ret);
  342. } else if (!dtype.valid()) {
  343. return pyseq2hval<seq_type>(obj, cn, ret);
  344. } else {
  345. return false;
  346. }
  347. }
  348. bool pyarr2hval(py::array obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
  349. auto data = obj.cast<py::array>();
  350. auto strides = data.strides();
  351. bool need_squeeze = false;
  352. for (size_t i = 0; i < data.ndim(); ++i) {
  353. if (strides[i] == 0) {
  354. need_squeeze = true;
  355. break;
  356. }
  357. }
  358. if (need_squeeze) {
  359. std::vector<size_t> shape;
  360. for (size_t i = 0; i < data.ndim(); ++i) {
  361. shape.push_back(data.shape(i));
  362. }
  363. data = data.squeeze();
  364. data.resize(shape);
  365. }
  366. HostTensorND retnd(cn);
  367. retnd = npy::np2tensor(data.ptr(), npy::Meth::copy_into(&retnd), dtype);
  368. if (!dtype.valid()) {
  369. dtype = retnd.dtype();
  370. }
  371. mgb_assert(
  372. retnd.layout().is_empty() || retnd.layout().is_contiguous(),
  373. "host value should be contiguous");
  374. for (size_t i = 0; i < data.ndim(); ++i) {
  375. ret.shape[ret.shape.ndim++] = data.shape(i);
  376. }
  377. ret.dtype = dtype;
  378. ret.storage = HostStorage::make(retnd.storage());
  379. return true;
  380. }
  381. bool pyint2hval(py::int_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
  382. if (!dtype.valid()) {
  383. dtype = dtype::Int32();
  384. }
  385. ret.dtype = dtype;
  386. ret.storage = scalar2storage((dt_int32)obj, cn, dtype);
  387. return true;
  388. }
  389. bool pyfloat2hval(py::float_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
  390. if (!dtype.valid()) {
  391. dtype = dtype::Float32();
  392. }
  393. ret.dtype = dtype;
  394. ret.storage = scalar2storage((dt_float32)obj, cn, dtype);
  395. return true;
  396. }
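// pyobj2hval: single entry point used by the TensorWrapper constructor to turn
// an arbitrary Python object into HostTensorArgs. Exact-type fast paths handle
// float, int (and bool, which is a subtype of int), tuple and list; everything
// else, including None (treated as an empty list), falls back to pyarr2hval.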
  397. HostTensorArgs pyobj2hval(py::object obj, CompNode cn, DType dtype) {
  398. HostTensorArgs ret;
  399. bool success = false;
  400. // check order: float -> int -> tuple(int -> float) -> list(int -> float)
  401. // only handle `exact` pytype, isinstance also accepts subtype
  402. // for example, isinstance(True, int) == True
  403. if (obj.get_type().is(py_type<py::float_>())) {
  404. success = pyfloat2hval(py::float_(obj), cn, dtype, ret);
  405. } else if (obj.get_type().is(py_type<py::int_>())) { // py::bool_ is py::int_
  406. success = pyint2hval(py::int_(obj), cn, dtype, ret);
  407. } else if (obj.get_type().is(py_type<py::tuple>())) {
  408. success = pyseq2hval<py::tuple>(py::tuple(obj), cn, dtype, ret);
  409. } else if (obj.get_type().is(py_type<py::list>())) {
  410. success = pyseq2hval<py::list>(py::list(obj), cn, dtype, ret);
  411. } else if (obj.is_none()) {
  412. obj = py::list(0);
  413. }
  414. if (!success) {
  415. success = pyarr2hval(obj, cn, dtype, ret);
  416. }
  417. mgb_assert(success);
  418. return ret;
  419. }
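// PyArgDesc/PyArgDescs describe the signature of Tensor.__init__: each item has
// a name and a factory for its default value, plus a name2idx function used to
// place keyword arguments at their positional slot.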
  420. struct PyArgDesc {
  421. const char* name;
  422. py::object (*default_value)();
  423. };
  424. struct PyArgDescs {
  425. std::vector<PyArgDesc> items;
  426. ssize_t (*name2idx)(const char* name);
  427. };
  428. py::tuple parse_args(py::tuple args, const PyArgDescs& descs) {
  429. size_t nr_args = args.size();
  430. size_t nr_items = descs.items.size();
  431. mgb_assert(nr_args <= nr_items, "too many args");
  432. if (nr_args == nr_items) {
  433. return args;
  434. }
  435. py::tuple ret(nr_items);
  436. for (size_t i = 0; i < nr_args; ++i) {
  437. ret[i] = args[i];
  438. }
  439. for (size_t i = nr_args; i < nr_items; ++i) {
  440. ret[i] = descs.items[i].default_value();
  441. }
  442. return ret;
  443. }
  444. py::tuple parse_args_and_kwargs(
  445. py::tuple args, py::dict kwargs, const PyArgDescs& descs) {
  446. size_t nr_args = args.size();
  447. size_t nr_kwargs = kwargs.size();
  448. size_t nr_items = descs.items.size();
  449. mgb_assert(nr_args + nr_kwargs <= nr_items, "too many args");
  450. if (nr_args == nr_items) {
  451. return args;
  452. }
  453. py::tuple ret(nr_items);
  454. for (size_t i = 0; i < nr_args; ++i) {
  455. ret[i] = args[i];
  456. }
  457. bool has_value[nr_items - nr_args];
  458. for (size_t i = nr_args; i < nr_items; ++i) {
  459. has_value[i - nr_args] = false;
  460. }
  461. for (auto&& [k, v] : kwargs) {
  462. auto key = py::str(k).cast<std::string>();
  463. ssize_t index = descs.name2idx(key.c_str());
  464. mgb_assert(index >= nr_args);
  465. ret[index] = v;
  466. has_value[index - nr_args] = true;
  467. }
  468. for (size_t i = nr_args; i < nr_items; ++i) {
  469. if (!has_value[i - nr_args]) {
  470. ret[i] = descs.items[i].default_value();
  471. }
  472. }
  473. return ret;
  474. }
  475. CompNode as_comp_node(const std::string& name) {
  476. thread_local struct {
  477. std::string name;
  478. CompNode cn;
  479. } cached;
  480. if (cached.name != name) {
  481. cached.name = name;
  482. cached.cn = CompNode::load(name);
  483. }
  484. return cached.cn;
  485. }
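// as_comp_node(py_device): resolves a Python device argument (None, a string,
// an instance of the registered device type, or a CompNode) to a CompNode,
// honoring the class-level dmap_callback remapping and falling back to
// get_default_device(); string lookups use the thread-local cache above.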
  486. CompNode as_comp_node(py::object py_device) {
  487. std::optional<std::string> device_name;
  488. if (py_device.is_none() || py::str::check_(py_device)) {
  489. auto cls = py::handle(reinterpret_cast<PyObject*>(py_tensor_type));
  490. auto dmap_callback = cls.attr("dmap_callback");
  491. std::string name;
  492. if (dmap_callback.is_none() && py_device.is_none()) {
  493. name = get_default_device();
  494. } else {
  495. if (py_device.is_none()) {
  496. py_device = py::str(get_default_device());
  497. }
  498. if (!dmap_callback.is_none()) {
  499. py_device = dmap_callback(py_device);
  500. }
  501. name = py::str(py_device).cast<std::string>();
  502. }
  503. return as_comp_node(name);
  504. } else {
  505. if (py::isinstance(py_device, py_device_type)) {
  506. py_device = py_device.attr("_cn");
  507. }
  508. mgb_assert(py::isinstance(py_device, py_comp_node_type));
  509. return py_device.cast<CompNode>();
  510. }
  511. }
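// compare_cstr/name2idx: hand-rolled keyword lookup for Tensor.__init__.
// name2idx maps an argument name to its positional index, e.g.
// name2idx("dtype") == 1 and name2idx("device") == 2.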
  512. template <char... Chars>
  513. bool compare_cstr(const char* cstr) {
  514. return (((*cstr++) == Chars) && ...) && *cstr == '\0';
  515. }
  516. ssize_t name2idx(const char* name) {
  517. const char* ch = name;
  518. // TODO: trie
  519. // clang-format off
  520. switch (*ch++) {
  521. case 'd':
  522. switch (*ch++) {
  523. // data
  524. case 'a': return compare_cstr<'t', 'a'>(ch) ? 0 : -1;
  525. // dtype
  526. case 't': return compare_cstr<'y', 'p', 'e'>(ch) ? 1 : -1;
  527. // device
  528. case 'e': return compare_cstr<'v', 'i', 'c', 'e'>(ch) ? 2 : -1;
  529. }
  530. case 'i':
  531. // is_const
  532. return compare_cstr<'s', '_', 'c', 'o', 'n', 's', 't'>(ch) ? 3 : -1;
  533. case 'n':
  534. switch (*ch++) {
  535. // no_cache
  536. case 'o': return compare_cstr<'_', 'c', 'a', 'c', 'h', 'e'>(ch) ? 4 : -1;
  537. // name
  538. case 'a': return compare_cstr<'m', 'e'>(ch) ? 5 : -1;
  539. }
  540. }
  541. // clang-format on
  542. return -1;
  543. }
  544. } // namespace
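// TensorWrapper(args, kwargs): implements Tensor.__init__ on the C++ side. It
// parses (data, dtype, device, is_const, no_cache, name), either copies an
// existing TensorWrapper or converts the data via pyobj2hval, then issues a
// CreateTensor (Const/Unique/Common) and, if a name was given, a RenameValue.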
  545. TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
  546. static PyArgDescs descs = {
  547. {
  548. {"data", []() -> py::object { return py::none(); }},
  549. {"dtype", []() -> py::object { return py::none(); }},
  550. {"device", []() -> py::object { return py::none(); }},
  551. {"is_const", []() -> py::object { return py::bool_(false); }},
  552. {"no_cache", []() -> py::object { return py::bool_(false); }},
  553. {"name", []() -> py::object { return py::none(); }},
  554. },
  555. name2idx};
  556. py::detail::loader_life_support life_sup; // FIXME!!! required to cast DType
  557. auto tup = py::reinterpret_borrow<py::tuple>(args);
  558. if (kwargs) {
  559. tup = parse_args_and_kwargs(
  560. tup, py::reinterpret_borrow<py::dict>(kwargs), descs);
  561. } else {
  562. tup = parse_args(tup, descs);
  563. }
  564. mgb_assert(tup.size() == 6);
  565. if (auto* t = try_cast(tup[0].ptr())) {
  566. m_tensor = t->m_tensor->copy();
  567. } else {
  568. auto data = tup[0];
  569. DType dtype = tup[1].cast<DType>();
  570. bool is_const = tup[3].cast<bool>();
  571. bool no_cache = tup[4].cast<bool>();
  572. std::string name;
  573. if (!tup[5].is_none()) {
  574. name = tup[5].cast<std::string>();
  575. }
  576. CompNode cn = as_comp_node(tup[2]);
  577. {
  578. CreateTensor::Kind kind = is_const ? CreateTensor::Const
  579. : no_cache ? CreateTensor::Unique
  580. : CreateTensor::Common;
  581. auto&& hval = pyobj2hval(data, cn, dtype);
  582. auto val = imperative::apply(
  583. CreateTensor(kind, cn, hval.dtype, hval.shape), hval.storage)[0];
  584. m_tensor.emplace(val);
  585. }
  586. if (!name.empty()) {
  587. m_tensor->reset(imperative::apply(RenameValue(name), m_tensor->data())[0]);
  588. }
  589. }
  590. mgb_assert(m_tensor->data());
  591. }
  592. PyObject* TensorWrapper::module_trace_info() {
  593. if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) {
  594. if (module_trace_info->ptr()) {
  595. return module_trace_info->inc_ref().ptr();
  596. }
  597. }
  598. PyErr_SetString(
  599. PyExc_AttributeError,
  600. "Has no attribute named \'_NodeMixin__node\', please "
  601. "set it first");
  602. return nullptr;
  603. }
  604. void TensorWrapper::set_module_trace_info(PyObject* obj) {
  605. // TODO: erase when obj == nullptr
  606. module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj);
  607. }
  608. void TensorWrapper::_set_name(PyObject* dest) {
  609. auto py_dest = py::reinterpret_borrow<py::object>(dest);
  610. auto name = py_dest.cast<std::string>();
  611. m_tensor->set_name(name);
  612. }
  613. PyObject* TensorWrapper::_detail() {
  614. return py::str(m_tensor->data().unwrap().to_string()).release().ptr();
  615. }
  616. void TensorWrapper::_watch() {
  617. m_tensor->data().watch();
  618. }
  619. PyObject* TensorWrapper::shape() {
  620. auto shape = m_tensor->shape();
  621. if (!shape) {
  622. Py_RETURN_NONE;
  623. }
  624. py::tuple ret(shape->ndim);
  625. for (size_t i = 0; i < shape->ndim; ++i) {
  626. ret[i] = shape->at(i);
  627. }
  628. return ret.release().ptr();
  629. }
  630. PyObject* TensorWrapper::dtype() {
  631. return py::cast(m_tensor->dtype()).release().ptr();
  632. }
  633. PyObject* TensorWrapper::device() {
  634. return py::cast(m_tensor->comp_node()).release().ptr();
  635. }
  636. PyObject* TensorWrapper::numpy() {
  637. auto hv = m_tensor->numpy();
  638. if (!hv) {
  639. PyErr_SetString(PyExc_ValueError, "tensor invalid");
  640. return nullptr;
  641. }
  642. auto arr = py::reinterpret_steal<py::array>(
  643. npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
  644. if (hv->shape().is_scalar()) {
  645. mgb_assert(PyArray_Check(arr.ptr()));
  646. return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
  647. }
  648. return arr.release().ptr();
  649. }
  650. void TensorWrapper::reset(PyObject* tensor) {
  651. TensorWrapper* t = TensorWrapper::try_cast(tensor);
  652. if (!t) {
  653. throw py::type_error("expect Tensor");
  654. }
  655. m_tensor->reset(t->m_tensor->data());
  656. }
  657. PyObject* TensorWrapper::detach() {
  658. auto detached = imperative::apply(DetachGrad(), m_tensor->data())[0];
  659. return TensorWrapper::make(py_tensor_type, detached).release().ptr();
  660. }
  661. PyObject* TensorWrapper::_dev_tensor() {
  662. auto dv = m_tensor->data().dev_tensor();
  663. // TODO: handle scalar
  664. return py::cast(dv->as_nd(true)).release().ptr();
  665. }
  666. void TensorWrapper::_drop() {
  667. imperative::apply(DTRCommand(DTRCommand::Drop), m_tensor->data());
  668. }
  669. PyObject* TensorWrapper::isscalar() {
  670. if (m_tensor->is_scalar()) {
  671. Py_RETURN_TRUE;
  672. } else {
  673. Py_RETURN_FALSE;
  674. }
  675. }
  676. struct TensorWeakRef {
  677. ValueWeakRef data;
  678. TensorWeakRef(const TensorWrapper& tw) : data(tw.m_tensor->data()) {}
  679. py::object operator()() {
  680. if (auto p = data.lock()) {
  681. return TensorWrapper::make(py_tensor_type, p);
  682. }
  683. return py::none();
  684. }
  685. };
  686. #ifdef METH_FASTCALL
  687. #define MGE_PY_INTERFACE(NAME, FUNC) \
  688. { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
  689. #else
  690. #define WRAP_FUNC_PY35(FUNC) \
  691. PyObject* py35_##FUNC(PyObject* self, PyObject* args) { \
  692. auto* arr = &PyTuple_GET_ITEM(args, 0); \
  693. auto size = PyTuple_GET_SIZE(args); \
  694. return FUNC(self, arr, size); \
  695. }
  696. WRAP_FUNC_PY35(py_apply);
  697. WRAP_FUNC_PY35(dtype_promotion);
  698. WRAP_FUNC_PY35(get_device);
  699. WRAP_FUNC_PY35(make_shape_tuple);
  700. WRAP_FUNC_PY35(getitem_cpp);
  701. WRAP_FUNC_PY35(setitem_cpp);
  702. WRAP_FUNC_PY35(split_cpp);
  703. WRAP_FUNC_PY35(expand_dims_cpp);
  704. WRAP_FUNC_PY35(squeeze_cpp);
  705. WRAP_FUNC_PY35(transpose_cpp);
  706. WRAP_FUNC_PY35(broadcast_cpp);
  707. WRAP_FUNC_PY35(reshape_cpp);
  708. WRAP_FUNC_PY35(adaptive_pool2d_cpp);
  709. WRAP_FUNC_PY35(Const);
  710. WRAP_FUNC_PY35(astype_cpp);
  711. WRAP_FUNC_PY35(matmul_cpp);
  712. WRAP_FUNC_PY35(batched_matmul_cpp);
  713. WRAP_FUNC_PY35(convert_single_value_cpp);
  714. WRAP_FUNC_PY35(convert_inputs_cpp);
  715. WRAP_FUNC_PY35(astensor1d_cpp);
  716. WRAP_FUNC_PY35(pixel_shuffle_cpp);
  717. #undef WRAP_FUNC_PY35
  718. #define MGE_PY_INTERFACE(NAME, FUNC) \
  719. { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
  720. #endif
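// init_tensor: module initialization. It creates the interpreter channel,
// registers the default transformations (Eval/Scalar/DTypePromote/DimExpansion),
// installs the AsyncError exception translator, exposes the Tensor,
// TensorWeakRef, SymbolVar, GradKey and Trace types plus the MGE_PY_INTERFACE
// functions, and defines the tracing-, grad- and amp-related module helpers
// used by the Python layer.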
  721. void init_tensor(py::module m) {
  722. imperative::Tensor::static_initialize();
  723. static auto& transformations = TransformationManager::get_instance();
  724. using Segment = TransformationManager::Segment;
  725. using Channel = interpreter::Interpreter::Channel;
  726. auto* channel =
  727. imperative::ResourceManager::create_global<std::unique_ptr<Channel>>(
  728. interpreter::Interpreter::inst().create_channel())
  729. ->get();
  730. interpreter_for_py = channel;
  731. MGB_MARK_USED_VAR(
  732. transformations
  733. .register_at<Segment::Eval>(
  734. std::make_shared<InterpreterTransformation>(
  735. std::shared_ptr<Channel>(channel, [](Channel*) {})))
  736. .release());
  737. MGB_MARK_USED_VAR(transformations
  738. .register_at<Segment::Scalar>(
  739. std::make_shared<ScalarTransformation>())
  740. .release());
  741. MGB_MARK_USED_VAR(transformations
  742. .register_at<Segment::DTypePromote>(
  743. std::make_shared<DTypePromoteTransformation>())
  744. .release());
  745. MGB_MARK_USED_VAR(transformations
  746. .register_at<Segment::DimExpansion>(
  747. std::make_shared<DimExpansionTransformation>())
  748. .release());
  749. static py::exception<interpreter::AsyncError> py_async_error(
  750. m, "AsyncError", PyExc_RuntimeError);
  751. py::register_exception_translator([](std::exception_ptr p) {
  752. try {
  753. if (p)
  754. std::rethrow_exception(p);
  755. } catch (const interpreter::AsyncError& e) {
  756. pyext17::pybind11_translate_exception(e.nested_ptr());
  757. if (PyErr_Occurred()) {
  758. PyObject *exc, *val, *tb;
  759. PyErr_Fetch(&exc, &val, &tb);
  760. PyErr_NormalizeException(&exc, &val, &tb);
  761. if (tb) {
  762. PyException_SetTraceback(val, tb);
  763. }
  764. auto val2 = py_async_error.py::object::operator()(
  765. "An async error is reported. See above for the actual cause."
  766. " Hint: This is where it is reported, not where it happened."
  767. " You may call `megengine.config.async_level = 0` "
  768. "to get better error reporting.");
  769. PyException_SetCause(
  770. val2.ptr(), val); // PyException_SetCause steals reference
  771. Py_XDECREF(exc);
  772. Py_XDECREF(tb);
  773. PyErr_Restore(
  774. py_async_error.inc_ref().ptr(), val2.release().ptr(), nullptr);
  775. } else {
  776. py_async_error("Unknown async error");
  777. }
  778. }
  779. });
  780. auto* tensor_type =
  781. TensorWrapper::wrap_t::type()
  782. .def<&TensorWrapper::numpy>("numpy")
  783. .def_getset<&TensorWrapper::shape>("shape")
  784. .def_getset<&TensorWrapper::dtype>("dtype")
  785. .def_getset<&TensorWrapper::device>("device")
  786. .def<&TensorWrapper::reset>("_reset")
  787. .def<&TensorWrapper::isscalar>("_isscalar")
  788. .def<&TensorWrapper::detach>("detach")
  789. // TODO: remove this
  790. .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
  791. .def<&TensorWrapper::_drop>("_drop")
  792. .def<&TensorWrapper::_detail>("_detail")
  793. .def<&TensorWrapper::_set_name>("_set_name")
  794. .def<&TensorWrapper::_watch>("_watch")
  795. .def_getset<
  796. &TensorWrapper::module_trace_info,
  797. &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
  798. .finalize();
  799. if (!tensor_type)
  800. throw py::error_already_set();
  801. py::setattr(m, "Tensor", tensor_type);
  802. py::class_<TensorWeakRef>(m, "TensorWeakRef")
  803. .def(py::init<const TensorWrapper&>())
  804. .def("__call__", &TensorWeakRef::operator());
  805. py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
  806. .def_property_readonly(
  807. "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
  808. .def_property(
  809. "var", [](PySymbolVar* v) { return v->m_node; },
  810. [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
  811. .def_property_readonly(
  812. "device", [](PySymbolVar* v) { return v->m_node->comp_node(); })
  813. .def_property_readonly(
  814. "graph", [](PySymbolVar* v) { return v->m_node->owner_graph(); })
  815. .def_property_readonly(
  816. "shape",
  817. [](PySymbolVar* v) -> const TensorShape* {
  818. auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
  819. return mgr.infer_shape_fallible(v->m_node);
  820. })
  821. .def("numpy",
  822. [](PySymbolVar* v) {
  823. auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
  824. auto&& type = mgr.get_infer_type(v->m_node);
  825. using InferType = cg::static_infer::InferType;
  826. if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
  827. throw py::value_error("value invalid!");
  828. }
  829. auto* val = mgr.infer_value_fallible(v->m_node);
  830. if (!val) {
  831. throw py::value_error("value invalid!");
  832. }
  833. auto np_val = py::cast(*val).attr("numpy")();
  834. return np_val;
  835. })
  836. .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
  837. .def(py::init([](cg::VarNode* node) {
  838. return std::make_shared<PySymbolVar>(node);
  839. }),
  840. py::arg() = nullptr);
  841. static PyMethodDef method_defs[] = {
  842. MGE_PY_INTERFACE(apply, py_apply),
  843. MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
  844. MGE_PY_INTERFACE(get_device, get_device),
  845. MGE_PY_INTERFACE(make_shape_tuple, make_shape_tuple),
  846. MGE_PY_INTERFACE(getitem_cpp, getitem_cpp),
  847. MGE_PY_INTERFACE(setitem_cpp, setitem_cpp),
  848. MGE_PY_INTERFACE(split_cpp, split_cpp),
  849. MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
  850. MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp),
  851. MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
  852. MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
  853. MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
  854. MGE_PY_INTERFACE(adaptive_pool2d_cpp, adaptive_pool2d_cpp),
  855. MGE_PY_INTERFACE(Const, Const),
  856. MGE_PY_INTERFACE(astype_cpp, astype_cpp),
  857. MGE_PY_INTERFACE(matmul_cpp, matmul_cpp),
  858. MGE_PY_INTERFACE(batched_matmul_cpp, batched_matmul_cpp),
  859. MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
  860. MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
  861. MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp),
  862. MGE_PY_INTERFACE(pixel_shuffle_cpp, pixel_shuffle_cpp),
  863. {nullptr, nullptr, 0, nullptr}};
  864. for (auto&& def : method_defs) {
  865. if (def.ml_meth != nullptr) {
  866. auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
  867. if (!func)
  868. throw py::error_already_set();
  869. py::setattr(m, def.ml_name, func);
  870. }
  871. }
  872. static constexpr auto sync_py_task_q = [] {
  873. py::gil_scoped_release _;
  874. py_task_q.wait_all_task_finish();
  875. };
  876. m.def("clear_candidates", [channel]() { channel->clear_candidates(); });
  877. m.def("set_option", [channel](std::string name, size_t value) {
  878. channel->set_option(name, value);
  879. });
  880. m.def("get_option",
  881. [channel](std::string name) { return channel->get_option(name); });
  882. m.def("push_scope", [channel](std::string name) {
  883. Transformation::push_scope(name);
  884. channel->push_scope(name);
  885. });
  886. m.def("pop_scope", [channel](std::string name) {
  887. channel->pop_scope(name);
  888. Transformation::pop_scope(name);
  889. });
  890. m.def("start_profile", [channel](imperative::Profiler::options_t options) {
  891. channel->sync();
  892. imperative::Profiler::load_options(std::move(options));
  893. imperative::Profiler::start_profile();
  894. channel->start_profile();
  895. });
  896. m.def("stop_profile", [channel]() -> std::function<void(std::string, std::string)> {
  897. channel->stop_profile();
  898. channel->sync();
  899. imperative::Profiler::stop_profile();
  900. auto results = std::make_shared<imperative::Profiler::bundle_t>(
  901. imperative::Profiler::collect());
  902. return [results = results](std::string basename, std::string format) mutable {
  903. imperative::Profiler::dump_profile(basename, format, std::move(*results));
  904. results = nullptr;
  905. };
  906. });
  907. m.def("sync", [channel]() {
  908. if (channel->check_available()) {
  909. channel->sync();
  910. }
  911. sync_py_task_q();
  912. });
  913. m.def("full_sync", [channel]() {
  914. if (channel->check_available()) {
  915. channel->sync();
  916. }
  917. CompNode::sync_all();
  918. CompNode::foreach ([](CompNode cn) {
  919. auto err = cn.check_async_error();
  920. mgb_assert(!err, "%s", err->what());
  921. });
  922. sync_py_task_q();
  923. });
  924. m.def("close", [channel]() {
  925. channel->close();
  926. sync_py_task_q();
  927. });
  928. py::handle grad_key_type =
  929. GradKeyWrapper::wrap_t::type()
  930. .def<&GradKeyWrapper::attach>("attach")
  931. .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
  932. .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>(
  933. "name")
  934. .def<&GradKeyWrapper::enter>("enter")
  935. .def<&GradKeyWrapper::exit>("exit")
  936. .def<&GradKeyWrapper::suppress>("suppress")
  937. .def<&GradKeyWrapper::resume>("resume")
  938. .finalize();
  939. if (!grad_key_type)
  940. throw py::error_already_set();
  941. py::setattr(m, "GradKey", grad_key_type);
  942. m.def("backward", &GradKeyWrapper::backward);
  943. m.def("get_backward_closure", &GradKeyWrapper::get_backward_closure);
  944. m.def("set_py_tensor_type", [](py::object type_obj) {
  945. py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
  946. });
  947. m.def("set_py_device_type",
  948. [](py::object type_obj) { py_device_type = type_obj.inc_ref(); });
  949. /**
  950. * \brief trace proxy
  951. *
  952. */
  953. struct Trace {
  954. bool symbolic = false;
  955. bool no_exec = false;
  956. bool capture_as_const = false;
  957. bool profile = false;
  958. bool record_input_shapes = false;
  959. py::function options_visitor;
  960. std::shared_ptr<TracingTransformation> tracing;
  961. std::shared_ptr<CompiledTransformation> compiled;
  962. std::shared_ptr<LazyEvalTransformation> lazy_eval;
  963. std::pair<size_t, std::shared_ptr<GraphProfiler>> profiler;
  964. std::optional<TraceResult> trace_result;
  965. std::function<bool(py::object, py::object)> array_comparator;
  966. std::unique_ptr<CleanupGuard<>> tracing_guard;
  967. std::unique_ptr<CleanupGuard<>> compiled_guard;
  968. std::unique_ptr<CleanupGuard<>> lazy_eval_guard;
  969. bool compare_value(ValueRef lhs, ValueRef rhs) {
  970. auto lvalue = lhs.cast_ref<HostValue>();
  971. auto rvalue = rhs.cast_ref<HostValue>();
  972. if (lvalue->shape() != rvalue->shape()) {
  973. return false;
  974. }
  975. if (lvalue->shape().total_nr_elems() == 1) {
  976. return lvalue->item() == rvalue->item();
  977. }
  978. HostTensorND lnd = lvalue->as_nd(true);
  979. HostTensorND rnd = rvalue->as_nd(true);
  980. auto larr = py::reinterpret_steal<py::array>(
  981. npy::ndarray_from_tensor(lnd, npy::ShareType::TRY_SHARE));
  982. auto rarr = py::reinterpret_steal<py::array>(
  983. npy::ndarray_from_tensor(rnd, npy::ShareType::TRY_SHARE));
  984. return array_comparator(larr, rarr);
  985. }
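// enter()/exit() drive the trace life cycle: the first enter records through a
// TracingTransformation (plus a LazyEvalTransformation when symbolic), while
// later enters replay through a CompiledTransformation built from trace_result
// and compiled on demand.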
  986. void enter() {
  987. auto& self = *this;
  988. if (!self.trace_result) { // untraced
  989. self.tracing = std::make_shared<TracingTransformation>(
  990. self.capture_as_const, self.record_input_shapes);
  991. if (self.symbolic) {
  992. self.lazy_eval =
  993. std::make_shared<LazyEvalTransformation>(self.no_exec);
  994. self.options_visitor(py::cast(&self.lazy_eval->options()));
  995. }
  996. } else if (!self.compiled) { // traced but not compiled
  997. using namespace std::placeholders;
  998. self.compiled = std::make_shared<CompiledTransformation>(
  999. *self.trace_result, self.record_input_shapes);
  1000. self.compiled->set_value_comparator(
  1001. std::bind(&Trace::compare_value, this, _1, _2));
  1002. self.options_visitor(py::cast(&self.compiled->options()));
  1003. self.compiled->compile();
  1004. }
  1005. // register transformations
  1006. if (self.compiled) {
  1007. if (self.profile) {
  1008. auto& current_graph = self.compiled->graph();
  1009. if (self.profiler.first != self.compiled->graph().id()) {
  1010. // graph changed
  1011. self.profiler = std::make_pair(
  1012. current_graph.id(),
  1013. std::make_shared<GraphProfiler>(&current_graph));
  1014. }
  1015. }
  1016. compiled_guard =
  1017. transformations.register_at<Segment::Trace>(self.compiled);
  1018. // start execute because InputCallback depends
  1019. self.compiled->execute();
  1020. } else if (self.tracing) {
  1021. tracing_guard =
  1022. transformations.register_at<Segment::Trace>(self.tracing);
  1023. if (self.lazy_eval) {
  1024. lazy_eval_guard =
  1025. transformations.register_at<Segment::Eval>(self.lazy_eval);
  1026. }
  1027. } else {
  1028. mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
  1029. }
  1030. }
  1031. void exit() {
  1032. auto& self = *this;
  1033. if (self.tracing) {
  1034. tracing_guard.reset();
  1035. self.trace_result = self.tracing->get_result();
  1036. self.tracing.reset();
  1037. if (self.lazy_eval) {
  1038. auto lazy_eval = std::move(self.lazy_eval);
  1039. lazy_eval_guard.reset();
  1040. lazy_eval->check_exception();
  1041. }
  1042. } else if (self.compiled) {
  1043. compiled_guard.reset();
  1044. self.compiled->wait();
  1045. } else {
  1046. mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
  1047. }
  1048. }
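// dump(): serializes the traced graph. Inputs and outputs are looked up by
// their recorded marks ("arg_0", "kwarg_xxx", "output_0", ...) and forwarded to
// TraceResult::dump together with the visited graph options.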
  1049. VarNodeArray dump(
  1050. std::shared_ptr<ComputingGraph> graph,
  1051. std::vector<std::tuple<std::string, std::string, TensorShape>> inputs,
  1052. std::vector<std::pair<std::string, std::string>> outputs,
  1053. bool prefer_input_names) {
  1054. auto& self = *this;
  1055. mgb_assert(self.trace_result);
  1056. // mark is like "arg_0", "kwarg_xxx", "output_0" ...
  1057. std::unordered_map<std::string, size_t> mark2var;
  1058. for (size_t i = 0; i < self.trace_result->vars.size(); ++i) {
  1059. auto& name = self.trace_result->vars[i].mark;
  1060. if (!name.empty()) {
  1061. mark2var[name] = i;
  1062. }
  1063. }
  1064. std::vector<std::tuple<size_t, std::string, TensorShape>> input_vars;
  1065. std::vector<std::pair<size_t, std::string>> output_vars;
  1066. for (auto&& [input_mark, input_name, input_shape] : inputs) {
  1067. mgb_assert(input_shape.ndim, "input shape invalid");
  1068. input_vars.push_back(
  1069. {mark2var.at(input_mark), input_name, input_shape});
  1070. }
  1071. for (auto&& [output_name, repr] : outputs) {
  1072. output_vars.push_back({mark2var.at(output_name), repr});
  1073. }
  1074. self.options_visitor(py::cast(&graph->options()));
  1075. auto vars = self.trace_result->dump(
  1076. *graph, input_vars, output_vars, prefer_input_names);
  1077. return vars;
  1078. }
  1079. };
  1080. py::class_<Trace>(m, "Trace")
  1081. .def(py::init<>())
  1082. .def_readwrite("record_input_shapes", &Trace::record_input_shapes)
  1083. .def_readwrite("array_comparator", &Trace::array_comparator)
  1084. .def_readwrite("profile", &Trace::profile)
  1085. .def_property_readonly(
  1086. "options",
  1087. [](Trace& self) {
  1088. if (self.compiled) {
  1089. return &self.compiled->options();
  1090. } else {
  1091. return (ComputingGraph::Options*)nullptr;
  1092. }
  1093. })
  1094. .def("get_profile",
  1095. [](Trace& self) -> py::object {
  1096. if (self.profiler.second && self.compiled) {
  1097. auto json = self.profiler.second->to_json_full(
  1098. self.compiled->graph().current_comp_seq());
  1099. return py::str(json->to_string());
  1100. } else {
  1101. return py::none();
  1102. }
  1103. })
  1104. .def_readwrite("symbolic", &Trace::symbolic)
  1105. .def_readwrite("capture_as_const", &Trace::capture_as_const)
  1106. .def_readwrite("no_exec", &Trace::no_exec)
  1107. .def_readwrite("options_visitor", &Trace::options_visitor)
  1108. .def("enter", &Trace::enter)
  1109. .def("exit", &Trace::exit)
  1110. .def("dump", &Trace::dump)
  1111. .def("begin_excluded_region",
  1112. [](Trace& self) {
  1113. mgb_assert(bool(self.tracing) ^ bool(self.compiled));
  1114. if (self.tracing) {
  1115. self.tracing_guard.reset();
  1116. } else if (self.compiled) {
  1117. self.compiled_guard.reset();
  1118. }
  1119. })
  1120. .def("end_excluded_region", [](Trace& self) {
  1121. mgb_assert(bool(self.tracing) ^ bool(self.compiled));
  1122. if (self.tracing) {
  1123. self.tracing_guard =
  1124. transformations.register_at<Segment::Trace>(self.tracing);
  1125. } else if (self.compiled) {
  1126. self.compiled_guard =
  1127. transformations.register_at<Segment::Trace>(self.compiled);
  1128. }
  1129. });
  1130. m.def("reduce_to_scalar", [](py::object op, py::object tensor) -> py::object {
  1131. auto reduce_to_scalar = [](const OpDef& op, const ValueRef& input) {
  1132. auto make_scalar_shape = [&](CompNode device) {
  1133. return imperative::apply(
  1134. CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
  1135. HostStorage::make(device))[0];
  1136. };
  1137. return imperative::apply(op, input, make_scalar_shape(*input.device()))[0];
  1138. };
  1139. if (py::isinstance<PySymbolVar>(tensor)) {
  1140. auto* graph = tensor.cast<PySymbolVar*>()->m_node->owner_graph();
  1141. SymbolVarContext context(graph);
  1142. context.init();
  1143. auto output = reduce_to_scalar(
  1144. *op.cast<std::shared_ptr<OpDef>>(), context.symvar2val(tensor));
  1145. auto typeobj = tensor.get_type();
  1146. return context.val2symvar(typeobj, output);
  1147. } else {
  1148. auto* tw = TensorWrapper::try_cast(tensor.ptr());
  1149. auto output = reduce_to_scalar(
  1150. *op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data());
  1151. return TensorWrapper::make(py_tensor_type, output);
  1152. }
  1153. });
  1154. m.def("name_tensor", [](std::string name, py::object tensor) {
  1155. auto* tw = TensorWrapper::try_cast(tensor.ptr());
  1156. auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
  1157. tw->m_tensor->reset(output);
  1158. });
  1159. m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool {
  1160. SmallVector<ValueRef> values(tensors.size());
  1161. for (size_t i = 0; i < tensors.size(); ++i) {
  1162. values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
  1163. }
  1164. auto outputs = imperative::apply(GetGradKey(), values);
  1165. if (outputs[0].is<GradKeyValue>()) {
  1166. return true;
  1167. } else {
  1168. return false;
  1169. }
  1170. });
  1171. m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object {
  1172. SmallVector<ValueRef> values(tensors.size());
  1173. for (size_t i = 0; i < tensors.size(); ++i) {
  1174. values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
  1175. }
  1176. auto output = imperative::apply(GetGradKey(), values)[0];
  1177. if (!output) {
  1178. return py::none();
  1179. }
  1180. return py::reinterpret_borrow<py::object>(GradKeyWrapper::wrap_t::pycast(
  1181. GradKeyWrapper::get(output.cast<GradKeyValue>())));
  1182. });
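// set_grad(backward_fn, inputs, outputs): wraps a Python backward function into
// a GenericFunction over ValueRefs (None gradients map to empty ValueRefs) and
// attaches it via the SetGrad transformation, returning the re-wrapped outputs.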
  1183. m.def("set_grad", [](py::function backward_fn, std::vector<py::object> inputs,
  1184. std::vector<py::object> outputs) {
  1185. GenericFunction generic_backward_fn =
  1186. [backward_fn](Span<ValueRef> output_grads) -> ValueRefList {
  1187. py::list output_grad_tws;
  1188. for (auto&& output_grad : output_grads) {
  1189. if (output_grad) {
  1190. output_grad_tws.append(
  1191. TensorWrapper::make(py_tensor_type, output_grad));
  1192. } else {
  1193. output_grad_tws.append(py::none());
  1194. }
  1195. }
  1196. py::tuple input_grad_tws = backward_fn(*output_grad_tws);
  1197. ValueRefList input_grads(input_grad_tws.size());
  1198. for (size_t i = 0; i < input_grad_tws.size(); ++i) {
  1199. auto input_grad_tw = input_grad_tws[i];
  1200. if (!input_grad_tw.is_none()) {
  1201. input_grads[i] =
  1202. py::cast<TensorWrapper>(input_grad_tw).m_tensor->data();
  1203. } else {
  1204. input_grads[i] = {};
  1205. }
  1206. }
  1207. return input_grads;
  1208. };
  1209. SmallVector<ValueRef> values(inputs.size() + outputs.size());
  1210. for (size_t i = 0; i < inputs.size(); ++i) {
  1211. values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data();
  1212. }
  1213. for (size_t i = 0; i < outputs.size(); ++i) {
  1214. values[i + inputs.size()] =
  1215. outputs[i].cast<TensorWrapper>().m_tensor->data();
  1216. }
  1217. auto wrapped_output_values =
  1218. imperative::apply(SetGrad(generic_backward_fn, inputs.size()), values);
  1219. std::vector<py::object> wrapped_outputs;
  1220. mgb_assert(wrapped_output_values.size() == outputs.size());
  1221. for (auto&& output_value : wrapped_output_values) {
  1222. wrapped_outputs.push_back(
  1223. TensorWrapper::make(py_tensor_type, output_value));
  1224. }
  1225. return wrapped_outputs;
  1226. });
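// Module tracing: module_trace_hook holds the Python callback; the
// ModuleTraceTransformation is created lazily on first use and registered at
// Segment::ModuleTrace, and the hook is released at interpreter exit via atexit.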
  1227. static py::function module_trace_hook;
  1228. static auto get_module_trace = [] {
  1229. static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation;
  1230. if (!module_trace_transformation) {
  1231. mgb_assert(module_trace_hook);
  1232. module_trace_transformation =
  1233. std::make_shared<ModuleTraceTransformation>(module_trace_hook);
  1234. MGB_MARK_USED_VAR(transformations
  1235. .register_at<Segment::ModuleTrace>(
  1236. module_trace_transformation)
  1237. .release());
  1238. }
  1239. return module_trace_transformation;
  1240. };
  1241. m.def("set_cpp_use_symbolic_shape", &set_cpp_use_symbolic_shape);
  1242. m.def("set_module_tracing", [=] { get_module_trace()->enable(); });
  1243. m.def("unset_module_tracing", [=] { get_module_trace()->disable(); });
  1244. m.def("is_tracing_module", [=] { return get_module_trace()->enabled(); });
  1245. m.def("set_module_trace_hook", [](py::function function) {
  1246. module_trace_hook = function;
  1247. module_trace_hook.inc_ref();
  1248. });
  1249. auto atexit = py::module::import("atexit");
  1250. atexit.attr("register")(py::cpp_function([]() { module_trace_hook = {}; }));
  1251. m.def("begin_record_values", [] { Value::begin_record_values(); });
  1252. m.def("end_record_values", [] {
  1253. std::vector<std::pair<size_t, std::string>> reprs;
  1254. auto values = Value::end_record_values();
  1255. for (auto&& value : values) {
  1256. reprs.push_back({value.id(), value.to_string()});
  1257. }
  1258. return reprs;
  1259. });
  1260. m.def("print_stats", [] { Stats::print(); });
  1261. m.def("reset_stats", [] { Stats::reset(); });
  1262. m.def("_get_convert_inputs",
  1263. []() -> bool { return DTypePromoteCfg::convert_input_enabled; });
  1264. m.def("_set_convert_inputs", [](bool flag) -> bool {
  1265. bool ret = DTypePromoteCfg::convert_input_enabled;
  1266. DTypePromoteCfg::convert_input_enabled = flag;
  1267. return ret;
  1268. });
  1269. m.def("_get_amp_dtype_autocast",
  1270. []() -> bool { return DTypePromoteCfg::amp_dtype_autocast_enabled; });
  1271. m.def("_set_amp_dtype_autocast", [](bool flag) -> bool {
  1272. bool ret = DTypePromoteCfg::amp_dtype_autocast_enabled;
  1273. DTypePromoteCfg::amp_dtype_autocast_enabled = flag;
  1274. return ret;
  1275. });
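// AMP precision dtypes: getters/setters for the high- and low-precision dtypes
// used by dtype promotion under amp. The setters accept "float32", "float16"
// or "bfloat16" and return the previous dtype name in lower case.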
  1276. static auto get_amp_prec_dtype = [](bool is_high) -> std::string {
  1277. DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
  1278. : DTypePromoteCfg::amp_low_prec_dtype;
  1279. mgb_assert(target.category() == DTypeCategory::FLOAT);
  1280. std::string ret = target.name();
  1281. transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
  1282. return ret;
  1283. };
  1284. static auto set_amp_prec_dtype = [](bool is_high,
  1285. std::string dtype_name) -> std::string {
  1286. DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
  1287. : DTypePromoteCfg::amp_low_prec_dtype;
  1288. std::string ret = target.name();
  1289. if (dtype_name == "float32") {
  1290. target = dtype::Float32();
  1291. } else if (dtype_name == "float16") {
  1292. target = dtype::Float16();
  1293. } else if (dtype_name == "bfloat16") {
  1294. target = dtype::BFloat16();
  1295. } else {
  1296. mgb_assert(
  1297. false, "amp dtype should be a float type, but %s was given\n",
  1298. dtype_name.c_str());
  1299. }
  1300. transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
  1301. return ret;
  1302. };
  1303. m.def("_get_amp_high_prec_dtype",
  1304. []() -> std::string { return get_amp_prec_dtype(true); });
  1305. m.def("_set_amp_high_prec_dtype", [](std::string dtype_name) -> std::string {
  1306. return set_amp_prec_dtype(true, dtype_name);
  1307. });
  1308. m.def("_get_amp_low_prec_dtype",
  1309. []() -> std::string { return get_amp_prec_dtype(false); });
  1310. m.def("_set_amp_low_prec_dtype", [](std::string dtype_name) -> std::string {
  1311. return set_amp_prec_dtype(false, dtype_name);
  1312. });
  1313. m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); });
  1314. py::register_exception<TraceError>(m, "TraceError");
  1315. }
  1316. #undef MGE_PY_INTERFACE
  1317. } // namespace mgb::imperative::python