
// tensor.cpp

#include "megbrain/common.h"
#include "megbrain/dtype.h"
#include "megbrain/imperative/cpp_cupti.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/transformations/dim_expansion.h"
#include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/transformations/group_comm.h"
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"
#include "megbrain/opr/io.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/utils/stats.h"
#include "megdnn/algorithm_cache.h"

#include "./common.h"
#include "./grad.h"
#include "./graph_rt.h"
#include "./helper.h"
#include "./module_trace.h"
#include "./numpy_dtypes.h"
#include "./tensor.h"
#include "./tensor_utils.h"
#include "./transformation.h"

#include <object.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <pybind11/pytypes.h>
#include <pyerrors.h>
#include <iterator>
#include <range/v3/all.hpp>
#include <string>
#include <unordered_map>

#include "../../src/impl/mgb_cg_impl.h"

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
PyTypeObject* py_tensor_type = nullptr;
PyTypeObject* py_varnode_type = nullptr;
pybind11::handle py_device_type = nullptr;
PyObject* cpp_use_symbolic_shape;

#define REGISTE_APPLY_FUNC(mode) \
    void set_##mode(py::object pyf) { mode = pyf.ptr(); }

REGISTE_APPLY_FUNC(cpp_use_symbolic_shape)

#undef REGISTE_APPLY_FUNC

PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs);
CompNode _get_device(PyObject* const* args, size_t nargs);
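
// Entry point of the `apply` function exposed to Python: args[0] is an OpDef,
// the remaining args are tensors (or, for Elemwise-like ops with input
// conversion enabled, Python scalars/arrays converted on the fly). Returns a
// tuple of outputs wrapped in the tensor or varnode Python type.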
PyObject* py_apply(
        PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) {
    try {
        // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
        //     PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
        //     return nullptr;
        // }
        if (nargs < 2) {
            PyErr_SetString(
                    PyExc_TypeError,
                    "py_apply expects one Op and at least one tensor "
                    "as argument");
            return nullptr;
        }

        auto* py_op = args[0];

        ++args;
        --nargs;

        auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
        SmallVector<ValueRef, 8> tensors(nargs);

        mgb::CompNode target_cn;
        mgb::DType target_dtype;

        auto convert_pyinput_to_tensor = [&](size_t i) -> ValueRef {
            if (!target_dtype.valid()) {
                target_dtype = npy::dtype_np2mgb_descr(_dtype_promotion(args, nargs));
                target_cn = _get_device(args, nargs);
            }
            HostTensorND ht(target_cn);
            ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
            if (PyArray_Check(args[i]) || PyList_Check(args[i])) {  // non-scalar
                // py_tuple is not allowed here because of tracing
                return imperative::apply(
                        CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
                        HostStorage::make(ht.storage()))[0];
            } else {  // scalar
                return imperative::apply(
                        CreateTensor(CreateTensor::Const, target_cn, target_dtype, {}),
                        HostStorage::make(ht.storage()))[0];
            }
        };

        bool is_varnode_apply = false;
        for (size_t i = 0; i < nargs; ++i) {
            if (PyObject_TypeCheck(args[i], py_varnode_type)) {
                is_varnode_apply = true;
            }
            if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                tensors[i] = tw->m_tensor->data();
            } else if (
                    DTypePromoteCfg::convert_input_enabled &&
                    (op->same_type<Elemwise>() || op->same_type<ElemwiseMultiType>())) {
                tensors[i] = convert_pyinput_to_tensor(i);
            } else {
                PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
                return nullptr;
            }
        }

        auto outputs = [&] { return imperative::apply(*op, tensors); }();
        size_t nout = outputs.size();
        auto ret = py::tuple(nout);
        PyTypeObject* py_type = is_varnode_apply ? py_varnode_type : py_tensor_type;
        for (size_t i = 0; i < nout; ++i) {
            ret[i] = TensorWrapper::make(py_type, std::move(outputs[i]));
        }
        return ret.release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

namespace {

template <typename T>
py::handle py_type() {
    if constexpr (std::is_same_v<T, py::int_>) {
        return (PyObject*)&PyLong_Type;
    } else if constexpr (std::is_same_v<T, py::float_>) {
        return (PyObject*)&PyFloat_Type;
    } else if constexpr (std::is_same_v<T, py::tuple>) {
        return (PyObject*)&PyTuple_Type;
    } else if constexpr (std::is_same_v<T, py::list>) {
        return (PyObject*)&PyList_Type;
    } else {
        static_assert(std::is_same_v<T, T>);
    }
}
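
// scalar2storage packs a single C++ scalar, and vec2storage packs a small
// vector of DTypeScalar values, into a freshly allocated HostStorage of the
// requested dtype on the given comp node.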
template <typename T>
auto scalar2storage(T val, CompNode cn, DType dtype) {
    using max_ctype_t = DTypeScalar::max_ctype;
    DTypeScalar scalar(dtype);
    scalar.set_retain_dtype(val);
    HostTensorStorage storage(cn);
    auto* raw_ptr = reinterpret_cast<dt_byte*>(new max_ctype_t());
    std::shared_ptr<dt_byte> raw_storage = {
            raw_ptr, [](dt_byte* ptr) { delete reinterpret_cast<max_ctype_t*>(ptr); }};
    storage.only_reset_raw_storage(cn, dtype.size(), raw_storage, 0);
    std::memcpy(storage.ptr(), scalar.storage(), dtype.size());
    return HostStorage::make(std::move(storage));
}

template <typename ctype>
auto vec2storage(Span<DTypeScalar> vec, CompNode cn, DType dtype) {
    mgb_assert(vec.size() <= MEGDNN_MAX_NDIM);
    // TODO: use storage cache and modify ConstTensorCache to return (Host, Device)
    auto* raw_ptr = new ctype[MEGDNN_MAX_NDIM];
    for (size_t i = 0; i < vec.size(); ++i) {
        raw_ptr[i] = vec[i].get_cast<ctype>();
    }
    mgb_assert(sizeof(ctype) == dtype.size());
    std::shared_ptr<dt_byte> raw_storage = {
            reinterpret_cast<dt_byte*>(raw_ptr),
            [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
    HostTensorStorage storage(cn);
    storage.only_reset_raw_storage(cn, sizeof(ctype) * vec.size(), raw_storage, 0);
    return HostStorage::make(std::move(storage));
}

struct HostTensorArgs {
    ValueShape shape;
    DType dtype;
    HostStorage::ref_t storage;

    HostTensorND as_tensor_nd() const {
        HostTensorND ret(CompNode::default_cpu(), shape.as_tensor_shape(), dtype);
        ret.only_reset_raw_storage(*storage);
        return ret;
    }
};
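
// pyseq2hval converts a Python tuple/list of ints and floats into
// HostTensorArgs. The dtype-typed overload fills a buffer of `ctype`; the
// untyped overload first infers Int32 vs Float32 from the elements; the
// dispatcher at the end picks one of them based on the requested dtype.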
template <typename seq_type, typename ctype>
bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    auto size = obj.size();
    if (size > MEGDNN_MAX_NDIM) {
        return false;
    }
    ctype items[size];
    for (size_t i = 0; i < size; ++i) {
        py::handle item = obj[i];
        if (item.get_type().is(py_type<py::int_>())) {
            items[i] = (ctype)(dt_int32)item.template cast<py::int_>();
        } else if (item.get_type().is(py_type<py::float_>())) {
            items[i] = (ctype)(dt_float32)item.template cast<py::float_>();
        } else {
            return false;
        }
    }
    mgb_assert(sizeof(ctype) == dtype.size());
    auto* raw_ptr = new ctype[size];
    std::shared_ptr<dt_byte> raw_storage = {
            reinterpret_cast<dt_byte*>(raw_ptr),
            [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
    HostTensorStorage storage(cn);
    storage.only_reset_raw_storage(cn, sizeof(ctype) * size, raw_storage, 0);
    std::memcpy(storage.ptr(), items, sizeof(ctype) * size);
    ret.dtype = dtype;
    ret.shape = {size};
    ret.storage = HostStorage::make(std::move(storage));
    return true;
}

template <typename seq_type>
bool pyseq2hval(seq_type obj, CompNode cn, HostTensorArgs& ret) {
    auto size = obj.size();
    if (size > MEGDNN_MAX_NDIM) {
        return false;
    }
    DTypeScalar items[size];
    DType dtype;
    for (size_t i = 0; i < size; ++i) {
        auto&& item = obj[i];
        if (item.get_type().is(py_type<py::int_>())) {
            items[i] = (dt_int32)item.template cast<py::int_>();
            if (!dtype.valid()) {
                dtype = dtype::Int32();
            } else if (dtype != dtype::Int32() && dtype != dtype::Float32()) {
                return false;
            }
        } else if (item.get_type().is(py_type<py::float_>())) {
            items[i] = (dt_float32)item.template cast<py::float_>();
            if (!dtype.valid()) {
                dtype = dtype::Float32();
            } else if (dtype == dtype::Int32()) {
                dtype = dtype::Float32();
            } else if (dtype != dtype::Float32()) {
                return false;
            }
        } else {
            return false;
        }
    }
    if (!dtype.valid()) {
        dtype = dtype::Float32();
    }
    ret.dtype = dtype;
    ret.shape = {size};
    if (dtype == dtype::Int32()) {
        ret.storage = vec2storage<dt_int32>({items, size}, cn, dtype);
    } else if (dtype == dtype::Float32()) {
        ret.storage = vec2storage<dt_float32>({items, size}, cn, dtype);
    } else {
        mgb_assert(false);
    }
    return true;
}

template <typename seq_type>
bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    if (dtype == dtype::Int32()) {
        return pyseq2hval<seq_type, dt_int32>(obj, cn, dtype, ret);
    } else if (dtype == dtype::Float32()) {
        return pyseq2hval<seq_type, dt_float32>(obj, cn, dtype, ret);
    } else if (!dtype.valid()) {
        return pyseq2hval<seq_type>(obj, cn, ret);
    } else {
        return false;
    }
}
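
// Fallback conversion: interpret the Python object as a numpy array. Arrays
// with a zero-stride (broadcast) dimension are squeezed and resized first so
// that np2tensor produces a contiguous host tensor.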
bool pyarr2hval(py::array obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    auto data = obj.cast<py::array>();
    auto strides = data.strides();
    bool need_squeeze = false;
    for (size_t i = 0; i < data.ndim(); ++i) {
        if (strides[i] == 0) {
            need_squeeze = true;
            break;
        }
    }
    if (need_squeeze) {
        std::vector<size_t> shape;
        for (size_t i = 0; i < data.ndim(); ++i) {
            shape.push_back(data.shape(i));
        }
        data = data.squeeze();
        data.resize(shape);
    }
    HostTensorND retnd(cn);
    retnd = npy::np2tensor(data.ptr(), npy::Meth::copy_into(&retnd), dtype);
    if (!dtype.valid()) {
        dtype = retnd.dtype();
    }
    mgb_assert(
            retnd.layout().is_empty() || retnd.layout().is_contiguous(),
            "host value should be contiguous");
    for (size_t i = 0; i < data.ndim(); ++i) {
        ret.shape[ret.shape.ndim++] = data.shape(i);
    }
    ret.dtype = dtype;
    ret.storage = HostStorage::make(retnd.storage());
    return true;
}

bool pyint2hval(py::int_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    if (!dtype.valid()) {
        dtype = dtype::Int32();
    }
    ret.dtype = dtype;
    ret.storage = scalar2storage((dt_int32)obj, cn, dtype);
    return true;
}

bool pyfloat2hval(py::float_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    if (!dtype.valid()) {
        dtype = dtype::Float32();
    }
    ret.dtype = dtype;
    ret.storage = scalar2storage((dt_float32)obj, cn, dtype);
    return true;
}
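
// Single entry point used when constructing a tensor from Python data:
// dispatch on the exact Python type (float, int, tuple, list, None, otherwise
// numpy array) and produce shape/dtype/storage for the new tensor.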
HostTensorArgs pyobj2hval(py::object obj, CompNode cn, DType dtype) {
    HostTensorArgs ret;
    bool success = false;
    // check order: float -> int -> tuple(int -> float) -> list(int -> float)
    // only handle `exact` pytype, isinstance also accepts subtype
    // for example, isinstance(True, int) == True
    if (obj.get_type().is(py_type<py::float_>())) {
        success = pyfloat2hval(py::float_(obj), cn, dtype, ret);
    } else if (obj.get_type().is(py_type<py::int_>())) {  // py::bool_ is py::int_
        success = pyint2hval(py::int_(obj), cn, dtype, ret);
    } else if (obj.get_type().is(py_type<py::tuple>())) {
        success = pyseq2hval<py::tuple>(py::tuple(obj), cn, dtype, ret);
    } else if (obj.get_type().is(py_type<py::list>())) {
        success = pyseq2hval<py::list>(py::list(obj), cn, dtype, ret);
    } else if (obj.is_none()) {
        obj = py::list(0);
    }
    if (!success) {
        success = pyarr2hval(obj, cn, dtype, ret);
    }
    mgb_assert(success);
    return ret;
}

struct PyArgDesc {
    const char* name;
    py::object (*default_value)();
};

struct PyArgDescs {
    std::vector<PyArgDesc> items;
    ssize_t (*name2idx)(const char* name);
};
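
// Normalize the positional/keyword arguments of the tensor constructor
// against a fixed argument table (PyArgDescs), filling in defaults for
// anything the caller did not pass.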
py::tuple parse_args(py::tuple args, const PyArgDescs& descs) {
    size_t nr_args = args.size();
    size_t nr_items = descs.items.size();
    mgb_assert(nr_args <= nr_items, "too many args");
    if (nr_args == nr_items) {
        return args;
    }
    py::tuple ret(nr_items);
    for (size_t i = 0; i < nr_args; ++i) {
        ret[i] = args[i];
    }
    for (size_t i = nr_args; i < nr_items; ++i) {
        ret[i] = descs.items[i].default_value();
    }
    return ret;
}

py::tuple parse_args_and_kwargs(
        py::tuple args, py::dict kwargs, const PyArgDescs& descs) {
    size_t nr_args = args.size();
    size_t nr_kwargs = kwargs.size();
    size_t nr_items = descs.items.size();
    mgb_assert(nr_args + nr_kwargs <= nr_items, "too many args");
    if (nr_args == nr_items) {
        return args;
    }
    py::tuple ret(nr_items);
    for (size_t i = 0; i < nr_args; ++i) {
        ret[i] = args[i];
    }
    bool has_value[nr_items - nr_args];
    for (size_t i = nr_args; i < nr_items; ++i) {
        has_value[i - nr_args] = false;
    }
    for (auto&& [k, v] : kwargs) {
        auto key = py::str(k).cast<std::string>();
        ssize_t index = descs.name2idx(key.c_str());
        mgb_assert(index >= nr_args);
        ret[index] = v;
        has_value[index - nr_args] = true;
    }
    for (size_t i = nr_args; i < nr_items; ++i) {
        if (!has_value[i - nr_args]) {
            ret[i] = descs.items[i].default_value();
        }
    }
    return ret;
}

CompNode as_comp_node(const std::string& name) {
    thread_local struct {
        std::string name;
        CompNode cn;
    } cached;
    if (cached.name != name) {
        cached.name = name;
        cached.cn = CompNode::load(name);
    }
    return cached.cn;
}

CompNode as_comp_node(py::object py_device) {
    std::optional<std::string> device_name;
    if (py_device.is_none() || py::str::check_(py_device)) {
        auto cls = py::handle(reinterpret_cast<PyObject*>(py_tensor_type));
        auto dmap_callback = cls.attr("dmap_callback");
        std::string name;
        if (dmap_callback.is_none() && py_device.is_none()) {
            name = get_default_device();
        } else {
            if (py_device.is_none()) {
                py_device = py::str(get_default_device());
            }
            if (!dmap_callback.is_none()) {
                py_device = dmap_callback(py_device);
            }
            name = py::str(py_device).cast<std::string>();
        }
        return as_comp_node(name);
    } else {
        if (py::isinstance(py_device, py_device_type)) {
            py_device = py_device.attr("_cn");
        }
        mgb_assert(py::isinstance(py_device, py_comp_node_type));
        return py_device.cast<CompNode>();
    }
}
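
// Keyword lookup for the tensor constructor: compare_cstr matches the tail of
// a keyword at compile time, and name2idx maps "data"/"dtype"/"device"/
// "is_const"/"no_cache"/"name"/"format" to its positional index (-1 if
// unknown).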
template <char... Chars>
bool compare_cstr(const char* cstr) {
    return (((*cstr++) == Chars) && ...) && *cstr == '\0';
}

ssize_t name2idx(const char* name) {
    const char* ch = name;
    // TODO: trie
    // clang-format off
    switch (*ch++) {
    case 'd':
        switch (*ch++) {
        // data
        case 'a': return compare_cstr<'t', 'a'>(ch) ? 0 : -1;
        // dtype
        case 't': return compare_cstr<'y', 'p', 'e'>(ch) ? 1 : -1;
        // device
        case 'e': return compare_cstr<'v', 'i', 'c', 'e'>(ch) ? 2 : -1;
        }
    case 'i':
        // is_const
        return compare_cstr<'s', '_', 'c', 'o', 'n', 's', 't'>(ch) ? 3 : -1;
    case 'n':
        switch (*ch++) {
        // no_cache
        case 'o': return compare_cstr<'_', 'c', 'a', 'c', 'h', 'e'>(ch) ? 4 : -1;
        // name
        case 'a': return compare_cstr<'m', 'e'>(ch) ? 5 : -1;
        }
    case 'f':
        // format
        return compare_cstr<'o', 'r', 'm', 'a', 't'>(ch) ? 6 : -1;
    }
    // clang-format on
    return -1;
}

}  // namespace
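
// Constructor behind the Python tensor type: either re-wrap an existing
// tensor (first argument is already a tensor, remaining arguments are only
// validated), or build a new value from Python data / a VarNode via
// pyobj2hval and CreateTensor.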
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
    static PyArgDescs descs = {
            {
                    {"data", []() -> py::object { return py::none(); }},
                    {"dtype", []() -> py::object { return py::none(); }},
                    {"device", []() -> py::object { return py::none(); }},
                    {"is_const", []() -> py::object { return py::bool_(false); }},
                    {"no_cache", []() -> py::object { return py::bool_(false); }},
                    {"name", []() -> py::object { return py::none(); }},
                    {"format", []() -> py::object { return py::none(); }},
            },
            name2idx};
    py::detail::loader_life_support life_sup;  // FIXME!!!required to cast DType
    auto tup = py::reinterpret_borrow<py::tuple>(args);
    if (kwargs) {
        tup = parse_args_and_kwargs(
                tup, py::reinterpret_borrow<py::dict>(kwargs), descs);
    } else {
        tup = parse_args(tup, descs);
    }
    mgb_assert(tup.size() == 7);

    if (auto* t = try_cast(tup[0].ptr())) {
        m_tensor = t->m_tensor;
        // TODO: merge two path in arg parse
        if (!tup[1].is_none()) {
            auto dtype = tup[1].cast<DType>();
            mgb_assert(
                    dtype == m_tensor->dtype(), "dtype mismatch: %s vs %s",
                    dtype.name(), m_tensor->dtype().name());
        }
        if (!tup[2].is_none()) {
            auto device = as_comp_node(tup[2]);
            mgb_assert(
                    device == m_tensor->comp_node(), "device mismatch: %s vs %s",
                    device.to_string().c_str(),
                    m_tensor->comp_node().to_string().c_str());
        }
        mgb_assert(!tup[3].cast<bool>(), "expect is_const == False, got True");
        bool no_cache = tup[4].cast<bool>();
        if (no_cache) {
            // always copy because it's hard to tell whether this tensor is cached
            m_tensor = m_tensor->copy();
        }
        // ignore name
        if (!tup[6].is_none()) {
            Format format = tup[6].cast<std::string>();
            mgb_assert(
                    format == m_tensor->format(), "format mismatch: %s vs %s",
                    format.to_string().c_str(), m_tensor->format().to_string().c_str());
        }
    } else {
        auto data = tup[0];
        DType dtype = tup[1].cast<DType>();
        CompNode cn = as_comp_node(tup[2]);
        bool is_const = tup[3].cast<bool>();
        bool no_cache = tup[4].cast<bool>();
        std::string name;
        if (!tup[5].is_none()) {
            name = tup[5].cast<std::string>();
        }
        Format format;
        if (!tup[6].is_none()) {
            format = tup[6].cast<std::string>();
        }
        {
            CreateTensor::Kind kind = is_const ? CreateTensor::Const
                                    : no_cache ? CreateTensor::Unique
                                               : CreateTensor::Common;
            ValueRef val;
            if (py::isinstance(data, Py_Varnode)) {
                cg::VarNode* m_node = py::handle(data).cast<cg::VarNode*>();
                val = imperative::apply(
                        CreateNode(m_node), Span<ValueRef>(nullptr, nullptr))[0];
            } else {
                auto&& hval = pyobj2hval(data, cn, dtype);
                val = imperative::apply(
                        CreateTensor(kind, cn, hval.dtype, hval.shape, format),
                        hval.storage)[0];
            }
            m_tensor.emplace(val);
        }
        if (!name.empty()) {
            m_tensor->reset(imperative::apply(RenameValue(name), m_tensor->data())[0]);
        }
    }
    mgb_assert(m_tensor->data());
}

PyObject* TensorWrapper::module_trace_info() {
    if (auto module_trace_info =
                ModuleTraceTransformation::module_trace_info_map.try_get(
                        m_tensor->data())) {
        if (module_trace_info->ptr()) {
            return module_trace_info->inc_ref().ptr();
        }
    }
    PyErr_SetString(
            PyExc_AttributeError,
            "Has no attribute named \'_NodeMixin__node\', please "
            "set it first");
    return nullptr;
}

void TensorWrapper::set_module_trace_info(PyObject* obj) {
    // TODO: erase when obj == nullptr
    ModuleTraceTransformation::module_trace_info_map[m_tensor->data()] =
            py::reinterpret_borrow<py::object>(obj);
}

void TensorWrapper::_set_format(PyObject* dest) {
    auto py_dest = py::reinterpret_borrow<py::object>(dest);
    auto format = py_dest.cast<std::string>();
    m_tensor->set_format(format);
}

void TensorWrapper::_set_name(PyObject* dest) {
    auto py_dest = py::reinterpret_borrow<py::object>(dest);
    auto name = py_dest.cast<std::string>();
    m_tensor->set_name(name);
}

PyObject* TensorWrapper::_detail() {
    return py::str(m_tensor->data().unwrap().to_string()).release().ptr();
}

void TensorWrapper::_watch() {
    m_tensor->data().watch();
}

PyObject* TensorWrapper::shape() {
    auto shape = m_tensor->shape();
    if (!shape) {
        Py_RETURN_NONE;
    }
    py::tuple ret(shape->ndim);
    for (size_t i = 0; i < shape->ndim; ++i) {
        ret[i] = shape->at(i);
    }
    return ret.release().ptr();
}

PyObject* TensorWrapper::dtype() {
    return py::cast(m_tensor->dtype()).release().ptr();
}

PyObject* TensorWrapper::device() {
    return py::cast(m_tensor->comp_node()).release().ptr();
}

PyObject* TensorWrapper::format() {
    return py::cast(m_tensor->format().to_string()).release().ptr();
}

PyObject* TensorWrapper::numpy() {
    auto hv = m_tensor->numpy();
    if (!hv) {
        PyErr_SetString(PyExc_ValueError, "tensor invalid");
        return nullptr;
    }
    auto arr = py::reinterpret_steal<py::array>(
            npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
    if (hv->shape().is_scalar()) {
        mgb_assert(PyArray_Check(arr.ptr()));
        return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
    }
    return arr.release().ptr();
}

void TensorWrapper::reset(PyObject* tensor) {
    TensorWrapper* t = TensorWrapper::try_cast(tensor);
    if (!t) {
        throw py::type_error("expect Tensor");
    }
    m_tensor->reset(t->m_tensor->data());
}

PyObject* TensorWrapper::detach() {
    auto detached = imperative::apply(DetachGrad(), m_tensor->data())[0];
    return TensorWrapper::make(py_tensor_type, detached).release().ptr();
}

PyObject* TensorWrapper::_dev_tensor() {
    auto dv = m_tensor->data().dev_tensor();
    // TODO: handle scalar
    return py::cast(dv->as_nd(true)).release().ptr();
}

void TensorWrapper::_drop() {
    imperative::apply(DTRCommand(DTRCommand::Drop), m_tensor->data());
}

PyObject* TensorWrapper::isscalar() {
    if (m_tensor->is_scalar()) {
        Py_RETURN_TRUE;
    } else {
        Py_RETURN_FALSE;
    }
}

PyObject* TensorWrapper::_var() {
    TypedValueRef<NodeValue> value =
            imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref<NodeValue>();
    auto* node = value->node();
    return py::cast(node).release().ptr();
}

PyObject* TensorWrapper::_graph() {
    TypedValueRef<NodeValue> value =
            imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref<NodeValue>();
    auto* graph = value->graph();
    return py::cast(graph).release().ptr();
}

struct TensorWeakRef {
    ValueWeakRef data;

    TensorWeakRef(const TensorWrapper& tw) : data(tw.m_tensor->data()) {}

    py::object operator()() {
        if (auto p = data.lock()) {
            return TensorWrapper::make(py_tensor_type, p);
        }
        return py::none();
    }
};

#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
#else
#define WRAP_FUNC_PY35(FUNC)                                \
    PyObject* py35_##FUNC(PyObject* self, PyObject* args) { \
        auto* arr = &PyTuple_GET_ITEM(args, 0);             \
        auto size = PyTuple_GET_SIZE(args);                 \
        return FUNC(self, arr, size);                       \
    }
WRAP_FUNC_PY35(py_apply);
WRAP_FUNC_PY35(dtype_promotion);
WRAP_FUNC_PY35(get_device);
WRAP_FUNC_PY35(make_shape_tuple);
WRAP_FUNC_PY35(getitem_cpp);
WRAP_FUNC_PY35(setitem_cpp);
WRAP_FUNC_PY35(split_cpp);
WRAP_FUNC_PY35(expand_dims_cpp);
WRAP_FUNC_PY35(squeeze_cpp);
WRAP_FUNC_PY35(transpose_cpp);
WRAP_FUNC_PY35(broadcast_cpp);
WRAP_FUNC_PY35(reshape_cpp);
WRAP_FUNC_PY35(adaptive_pool2d_cpp);
WRAP_FUNC_PY35(Const);
WRAP_FUNC_PY35(astype_cpp);
WRAP_FUNC_PY35(matmul_cpp);
WRAP_FUNC_PY35(batched_matmul_cpp);
WRAP_FUNC_PY35(convert_single_value_cpp);
WRAP_FUNC_PY35(convert_inputs_cpp);
WRAP_FUNC_PY35(astensor1d_cpp);
WRAP_FUNC_PY35(pixel_shuffle_cpp);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif
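
// Module initialization: create the interpreter channel, register the default
// transformation stack (eval, scalar, symbol, dtype-promote, dim-expansion,
// format), and expose the Tensor / Trace / GradKey types plus the helper
// functions above to Python.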
void init_tensor(py::module m) {
    imperative::Tensor::static_initialize();

    // Transformations
    static auto& transformations = TransformationManager::get_instance();

    using Segment = TransformationManager::Segment;
    using Channel = interpreter::Interpreter::Channel;

    auto* channel =
            imperative::ResourceManager::create_global<std::unique_ptr<Channel>>(
                    interpreter::Interpreter::inst().create_channel())
                    ->get();
    interpreter_for_py = channel;
    MGB_MARK_USED_VAR(
            transformations
                    .register_at<Segment::Eval>(
                            std::make_shared<InterpreterTransformation>(
                                    std::shared_ptr<Channel>(channel, [](Channel*) {})))
                    .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::Scalar>(
                                      std::make_shared<ScalarTransformation>())
                              .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::Symbol>(
                                      std::make_shared<SymbolTransformation>())
                              .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::DTypePromote>(
                                      std::make_shared<DTypePromoteTransformation>())
                              .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::DimExpansion>(
                                      std::make_shared<DimExpansionTransformation>())
                              .release());
    auto format_trans = std::make_shared<FormatTransformation>();
    MGB_MARK_USED_VAR(
            transformations.register_at<Segment::Format>(format_trans).release());

    static py::exception<interpreter::AsyncError> py_async_error(
            m, "AsyncError", PyExc_RuntimeError);
    py::register_exception_translator([](std::exception_ptr p) {
        try {
            if (p)
                std::rethrow_exception(p);
        } catch (const interpreter::AsyncError& e) {
            pyext17::pybind11_translate_exception(e.nested_ptr());
            if (PyErr_Occurred()) {
                PyObject *exc, *val, *tb;
                PyErr_Fetch(&exc, &val, &tb);
                PyErr_NormalizeException(&exc, &val, &tb);
                if (tb) {
                    PyException_SetTraceback(val, tb);
                }
                auto val2 = py_async_error.py::object::operator()(
                        "An async error is reported. See above for the actual cause."
                        " Hint: This is where it is reported, not where it happened."
                        " You may call `megengine.config.async_level = 0` "
                        "to get better error reporting.");
                PyException_SetCause(
                        val2.ptr(), val);  // PyException_SetCause steals reference
                Py_XDECREF(exc);
                Py_XDECREF(tb);
                PyErr_Restore(
                        py_async_error.inc_ref().ptr(), val2.release().ptr(), nullptr);
            } else {
                py_async_error("Unknown async error");
            }
        }
    });

    // Tensor
    auto* tensor_type =
            TensorWrapper::wrap_t::type()
                    .def<&TensorWrapper::numpy>("numpy")
                    .def_getset<&TensorWrapper::shape>("shape")
                    .def_getset<&TensorWrapper::dtype>("dtype")
                    .def_getset<&TensorWrapper::device>("device")
                    .def<&TensorWrapper::format>("format")
                    .def<&TensorWrapper::reset>("_reset")
                    .def<&TensorWrapper::isscalar>("_isscalar")
                    .def<&TensorWrapper::detach>("detach")
                    // TODO: remove this
                    .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
                    .def<&TensorWrapper::_drop>("_drop")
                    .def<&TensorWrapper::_detail>("_detail")
                    .def<&TensorWrapper::_set_format>("_set_format")
                    .def<&TensorWrapper::_set_name>("_set_name")
                    .def<&TensorWrapper::_watch>("_watch")
                    .def<&TensorWrapper::_var>("var")
                    .def<&TensorWrapper::_graph>("graph")
                    .def_getset<
                            &TensorWrapper::module_trace_info,
                            &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
                    .finalize();
    if (!tensor_type)
        throw py::error_already_set();
    py::setattr(m, "Tensor", tensor_type);

    py::enum_<Format::Type>(m, "FormatType")
            .value("DEFAULT", Format::Type::DEFAULT)
            .value("NCHW", Format::Type::NCHW)
            .value("NHWC", Format::Type::NHWC)
            .export_values();

    py::class_<TensorWeakRef>(m, "TensorWeakRef")
            .def(py::init<const TensorWrapper&>())
            .def("__call__", &TensorWeakRef::operator());

    static PyMethodDef method_defs[] = {
            MGE_PY_INTERFACE(apply, py_apply),
            MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
            MGE_PY_INTERFACE(get_device, get_device),
            MGE_PY_INTERFACE(make_shape_tuple, make_shape_tuple),
            MGE_PY_INTERFACE(getitem_cpp, getitem_cpp),
            MGE_PY_INTERFACE(setitem_cpp, setitem_cpp),
            MGE_PY_INTERFACE(split_cpp, split_cpp),
            MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
            MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp),
            MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
            MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
            MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
            MGE_PY_INTERFACE(adaptive_pool2d_cpp, adaptive_pool2d_cpp),
            MGE_PY_INTERFACE(Const, Const),
            MGE_PY_INTERFACE(astype_cpp, astype_cpp),
            MGE_PY_INTERFACE(matmul_cpp, matmul_cpp),
            MGE_PY_INTERFACE(batched_matmul_cpp, batched_matmul_cpp),
            MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
            MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
            MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp),
            MGE_PY_INTERFACE(pixel_shuffle_cpp, pixel_shuffle_cpp),
            {nullptr, nullptr, 0, nullptr}};
    for (auto&& def : method_defs) {
        if (def.ml_meth != nullptr) {
            auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
            if (!func)
                throw py::error_already_set();
            py::setattr(m, def.ml_name, func);
        }
    }

    static constexpr auto sync_py_task_q = [] {
        py::gil_scoped_release _;
        py_task_q.wait_all_task_finish();
    };

    m.def("clear_candidates", [channel]() { channel->clear_candidates(); });
    m.def("set_option", [channel](std::string name, size_t value) {
        channel->set_option(name, value);
    });
    m.def("get_option",
          [channel](std::string name) { return channel->get_option(name); });
    m.def("push_scope", [channel](std::string name) {
        Transformation::push_scope(name);
        channel->push_scope(name);
    });
    m.def("pop_scope", [channel](std::string name) {
        channel->pop_scope(name);
        Transformation::pop_scope(name);
    });
    m.def("start_profile", [channel](imperative::Profiler::options_t options) {
        channel->sync();
        imperative::Profiler::load_options(std::move(options));
        imperative::Profiler::start_profile();
        channel->start_profile();
    });
    m.def("stop_profile", [channel]() -> std::function<void(std::string, std::string)> {
        channel->stop_profile();
        channel->sync();
        CompNode::sync_all();
        imperative::Profiler::stop_profile();
        auto results = std::make_shared<imperative::Profiler::bundle_t>(
                imperative::Profiler::collect());
        return [results = results](std::string basename, std::string format) mutable {
            imperative::Profiler::dump_profile(basename, format, std::move(*results));
            results = nullptr;
        };
    });
    m.def("enable_cupti", &cupti::enable);
    m.def("disable_cupti", &cupti::disable);
    m.def("cupti_available", &cupti::available);

    static std::unique_ptr<CleanupGuard<>> group_comm_guard;
    m.def("group_start", []() {
        auto commtrans = std::make_shared<GroupCommTransformation>();
        group_comm_guard = transformations.register_at<Segment::GroupComm>(commtrans);
    });
    m.def("group_end", []() { group_comm_guard.reset(); });

    m.def("sync", [channel]() {
        if (channel->check_available()) {
            channel->sync();
        }
        sync_py_task_q();
    });
    m.def("full_sync", [channel]() {
        if (channel->check_available()) {
            channel->sync();
        }
        CompNode::sync_all();
        CompNode::foreach ([](CompNode cn) {
            auto err = cn.check_async_error();
            mgb_assert(!err, "%s", err->what());
        });
        sync_py_task_q();
    });
    m.def("close", [channel]() {
        channel->close();
        sync_py_task_q();
    });

    // GradTransformation
    py::handle grad_key_type =
            GradKeyWrapper::wrap_t::type()
                    .def<&GradKeyWrapper::attach>("attach")
                    .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
                    .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>(
                            "name")
                    .def<&GradKeyWrapper::enter>("enter")
                    .def<&GradKeyWrapper::exit>("exit")
                    .def<&GradKeyWrapper::suppress>("suppress")
                    .def<&GradKeyWrapper::resume>("resume")
                    .finalize();
    if (!grad_key_type)
        throw py::error_already_set();
    py::setattr(m, "GradKey", grad_key_type);
    m.def("backward", &GradKeyWrapper::backward);
    m.def("get_backward_closure", &GradKeyWrapper::get_backward_closure);

    m.def("set_py_tensor_type", [](py::object type_obj) {
        py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
    });
    m.def("set_py_varnode_type", [](py::object type_obj) {
        py_varnode_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
    });
    m.def("set_py_device_type",
          [](py::object type_obj) { py_device_type = type_obj.inc_ref(); });

    /**
     * \brief trace proxy
     *
     */
    struct Trace {
        bool symbolic = false;
        bool no_exec = false;
        bool capture_as_const = false;
        bool profile = false;
        bool record_input_shapes = false;
        py::function options_visitor;
        std::shared_ptr<TracingTransformation> tracing;
        std::shared_ptr<CompiledTransformation> compiled;
        std::shared_ptr<LazyEvalTransformation> lazy_eval;
        std::pair<size_t, std::shared_ptr<GraphProfiler>> profiler;
        std::optional<TraceResult> trace_result;
        std::function<bool(py::object, py::object)> array_comparator;
        std::unique_ptr<CleanupGuard<>> tracing_guard;
        std::unique_ptr<CleanupGuard<>> compiled_guard;
        std::unique_ptr<CleanupGuard<>> lazy_eval_guard;

        bool compare_value(ValueRef lhs, ValueRef rhs) {
            auto lvalue = lhs.cast_ref<HostValue>();
            auto rvalue = rhs.cast_ref<HostValue>();
            if (lvalue->shape() != rvalue->shape()) {
                return false;
            }
            if (lvalue->shape().total_nr_elems() == 1) {
                return lvalue->item() == rvalue->item();
            }
            HostTensorND lnd = lvalue->as_nd(true);
            HostTensorND rnd = rvalue->as_nd(true);
            auto larr = py::reinterpret_steal<py::array>(
                    npy::ndarray_from_tensor(lnd, npy::ShareType::TRY_SHARE));
            auto rarr = py::reinterpret_steal<py::array>(
                    npy::ndarray_from_tensor(rnd, npy::ShareType::TRY_SHARE));
            return array_comparator(larr, rarr);
        }
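
        // enter(): on the first run set up a TracingTransformation (plus
        // LazyEval when symbolic); on later runs compile the recorded trace
        // into a CompiledTransformation and register it on the Trace segment.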
        void enter() {
            auto& self = *this;
            if (!self.trace_result) {  // untraced
                self.tracing = std::make_shared<TracingTransformation>(
                        self.capture_as_const, self.record_input_shapes);
                if (self.symbolic) {
                    self.lazy_eval =
                            std::make_shared<LazyEvalTransformation>(self.no_exec);
                    self.options_visitor(py::cast(&self.lazy_eval->options()));
                }
            } else if (!self.compiled) {  // traced but not compiled
                using namespace std::placeholders;
                self.compiled = std::make_shared<CompiledTransformation>(
                        *self.trace_result, self.record_input_shapes);
                self.compiled->set_value_comparator(
                        std::bind(&Trace::compare_value, this, _1, _2));
                self.options_visitor(py::cast(&self.compiled->options()));
                try {
                    self.compiled->compile();
                } catch (const std::exception& e) {
                    mgb_log_error("error in trace: %s", e.what());
                }
            }
            // register transformations
            if (self.compiled) {
                if (self.profile) {
                    auto& current_graph = self.compiled->graph();
                    if (self.profiler.first != self.compiled->graph().id()) {
                        // graph changed
                        self.profiler = std::make_pair(
                                current_graph.id(),
                                std::make_shared<GraphProfiler>(&current_graph));
                    }
                }
                compiled_guard =
                        transformations.register_at<Segment::Trace>(self.compiled);
                // start execute because InputCallback depends
                self.compiled->execute();
            } else if (self.tracing) {
                tracing_guard =
                        transformations.register_at<Segment::Trace>(self.tracing);
                if (self.lazy_eval) {
                    lazy_eval_guard =
                            transformations.register_at<Segment::Eval>(self.lazy_eval);
                }
            } else {
                mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
            }
        }

        void exit() {
            auto& self = *this;
            if (self.tracing) {
                tracing_guard.reset();
                self.trace_result = self.tracing->get_result();
                self.tracing.reset();
                if (self.lazy_eval) {
                    auto lazy_eval = std::move(self.lazy_eval);
                    lazy_eval_guard.reset();
                    lazy_eval->check_exception();
                }
            } else if (self.compiled) {
                compiled_guard.reset();
                self.compiled->wait();
            } else {
                mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
            }
        }

        VarNodeArray dump(
                std::shared_ptr<ComputingGraph> graph,
                std::vector<std::tuple<std::string, std::string, TensorShape>> inputs,
                std::vector<std::pair<std::string, std::string>> outputs,
                bool prefer_input_names) {
            auto& self = *this;
            mgb_assert(self.trace_result);
            // mark is like "arg_0", "kwarg_xxx", "output_0" ...
            std::unordered_map<std::string, size_t> mark2var;
            for (size_t i = 0; i < self.trace_result->vars.size(); ++i) {
                auto& name = self.trace_result->vars[i].mark;
                if (!name.empty()) {
                    mark2var[name] = i;
                }
            }
            std::vector<std::tuple<size_t, std::string, TensorShape>> input_vars;
            std::vector<std::pair<size_t, std::string>> output_vars;
            for (auto&& [input_mark, input_name, input_shape] : inputs) {
                mgb_assert(input_shape.ndim, "input shape invalid");
                input_vars.push_back(
                        {mark2var.at(input_mark), input_name, input_shape});
            }
            for (auto&& [output_name, repr] : outputs) {
                output_vars.push_back({mark2var.at(output_name), repr});
            }
            self.options_visitor(py::cast(&graph->options()));
            auto vars = self.trace_result->dump(
                    *graph, input_vars, output_vars, prefer_input_names);
            return vars;
        }
    };

    py::class_<Trace>(m, "Trace")
            .def(py::init<>())
            .def_readwrite("record_input_shapes", &Trace::record_input_shapes)
            .def_readwrite("array_comparator", &Trace::array_comparator)
            .def_readwrite("profile", &Trace::profile)
            .def_property_readonly(
                    "options",
                    [](Trace& self) {
                        if (self.compiled) {
                            return &self.compiled->options();
                        } else {
                            return (ComputingGraph::Options*)nullptr;
                        }
                    })
            .def("get_profile",
                 [](Trace& self) -> py::object {
                     if (self.profiler.second && self.compiled) {
                         auto json = self.profiler.second->to_json_full(
                                 self.compiled->graph().current_comp_seq());
                         return py::str(json->to_string());
                     } else {
                         return py::none();
                     }
                 })
            .def_readwrite("symbolic", &Trace::symbolic)
            .def_readwrite("capture_as_const", &Trace::capture_as_const)
            .def_readwrite("no_exec", &Trace::no_exec)
            .def_readwrite("options_visitor", &Trace::options_visitor)
            .def("enter", &Trace::enter)
            .def("exit", &Trace::exit)
            .def("dump", &Trace::dump)
            .def("begin_excluded_region",
                 [](Trace& self) {
                     mgb_assert(bool(self.tracing) ^ bool(self.compiled));
                     if (self.tracing) {
                         self.tracing_guard.reset();
                     } else if (self.compiled) {
                         self.compiled_guard.reset();
                     }
                 })
            .def("end_excluded_region", [](Trace& self) {
                mgb_assert(bool(self.tracing) ^ bool(self.compiled));
                if (self.tracing) {
                    self.tracing_guard =
                            transformations.register_at<Segment::Trace>(self.tracing);
                } else if (self.compiled) {
                    self.compiled_guard =
                            transformations.register_at<Segment::Trace>(self.compiled);
                }
            });
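
    // Tracing / autograd helpers exposed to Python: name_tensor marks a value
    // inside a trace; is_grad_attached / get_grad_key query the GradKey
    // attached to a set of tensors.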
  1109. m.def("name_tensor", [](std::string name, py::object tensor) {
  1110. auto* tw = TensorWrapper::try_cast(tensor.ptr());
  1111. auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
  1112. tw->m_tensor->reset(output);
  1113. });
  1114. m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool {
  1115. SmallVector<ValueRef> values(tensors.size());
  1116. for (size_t i = 0; i < tensors.size(); ++i) {
  1117. values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
  1118. }
  1119. auto outputs = imperative::apply(GetGradKey(), values);
  1120. if (outputs[0].is<GradKeyValue>()) {
  1121. return true;
  1122. } else {
  1123. return false;
  1124. }
  1125. });
  1126. m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object {
  1127. SmallVector<ValueRef> values(tensors.size());
  1128. for (size_t i = 0; i < tensors.size(); ++i) {
  1129. values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
  1130. }
  1131. auto output = imperative::apply(GetGradKey(), values)[0];
  1132. if (!output) {
  1133. return py::none();
  1134. }
  1135. return py::reinterpret_borrow<py::object>(GradKeyWrapper::wrap_t::pycast(
  1136. GradKeyWrapper::get(output.cast<GradKeyValue>())));
  1137. });
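
    // set_grad installs a custom backward function: the Python callable is
    // wrapped into a GenericFunction that converts grads between ValueRef and
    // the Python tensor type in both directions, then applied via SetGrad to
    // the given inputs/outputs.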
  1138. m.def("set_grad", [](py::function backward_fn, std::vector<py::object> inputs,
  1139. std::vector<py::object> outputs) {
  1140. GenericFunction generic_backward_fn =
  1141. [backward_fn](Span<ValueRef> output_grads) -> ValueRefList {
  1142. py::list output_grad_tws;
  1143. for (auto&& output_grad : output_grads) {
  1144. if (output_grad) {
  1145. output_grad_tws.append(
  1146. TensorWrapper::make(py_tensor_type, output_grad));
  1147. } else {
  1148. output_grad_tws.append(py::none());
  1149. }
  1150. }
  1151. py::tuple input_grad_tws = backward_fn(*output_grad_tws);
  1152. ValueRefList input_grads(input_grad_tws.size());
  1153. for (size_t i = 0; i < input_grad_tws.size(); ++i) {
  1154. auto input_grad_tw = input_grad_tws[i];
  1155. if (!input_grad_tw.is_none()) {
  1156. input_grads[i] =
  1157. py::cast<TensorWrapper>(input_grad_tw).m_tensor->data();
  1158. } else {
  1159. input_grads[i] = {};
  1160. }
  1161. }
  1162. return input_grads;
  1163. };
  1164. SmallVector<ValueRef> values(inputs.size() + outputs.size());
  1165. for (size_t i = 0; i < inputs.size(); ++i) {
  1166. values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data();
  1167. }
  1168. for (size_t i = 0; i < outputs.size(); ++i) {
  1169. values[i + inputs.size()] =
  1170. outputs[i].cast<TensorWrapper>().m_tensor->data();
  1171. }
  1172. auto wrapped_output_values =
  1173. imperative::apply(SetGrad(generic_backward_fn, inputs.size()), values);
  1174. std::vector<py::object> wrapped_outputs;
  1175. mgb_assert(wrapped_output_values.size() == outputs.size());
  1176. for (auto&& output_value : wrapped_output_values) {
  1177. wrapped_outputs.push_back(
  1178. TensorWrapper::make(py_tensor_type, output_value));
  1179. }
  1180. return wrapped_outputs;
  1181. });
  1182. // ModuleTraceTransformation
  1183. static py::function module_trace_hook;
  1184. static auto get_module_trace = [] {
  1185. static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation;
  1186. if (!module_trace_transformation) {
  1187. mgb_assert(module_trace_hook);
  1188. module_trace_transformation =
  1189. std::make_shared<ModuleTraceTransformation>(module_trace_hook);
  1190. MGB_MARK_USED_VAR(transformations
  1191. .register_at<Segment::ModuleTrace>(
  1192. module_trace_transformation)
  1193. .release());
  1194. }
  1195. return module_trace_transformation;
  1196. };
  1197. m.def("set_cpp_use_symbolic_shape", &set_cpp_use_symbolic_shape);
  1198. m.def("set_module_tracing", [=] { get_module_trace()->enable(); });
  1199. m.def("unset_module_tracing", [=] { get_module_trace()->disable(); });
  1200. m.def("is_tracing_module", [=] { return get_module_trace()->enabled(); });
  1201. m.def("set_module_trace_hook", [](py::function function) {
  1202. module_trace_hook = function;
  1203. module_trace_hook.inc_ref();
  1204. });
  1205. auto atexit = py::module::import("atexit");
  1206. atexit.attr("register")(py::cpp_function([]() { module_trace_hook = {}; }));
  1207. m.def("begin_record_values", [] { Value::begin_record_values(); });
  1208. m.def("end_record_values", [] {
  1209. std::vector<std::pair<size_t, std::string>> reprs;
  1210. auto values = Value::end_record_values();
  1211. for (auto&& value : values) {
  1212. reprs.push_back({value.id(), value.to_string()});
  1213. }
  1214. return reprs;
  1215. });
  1216. m.def("print_stats", [] { Stats::print(); });
  1217. m.def("reset_stats", [] { Stats::reset(); });
  1218. m.def("_get_convert_inputs",
  1219. []() -> bool { return DTypePromoteCfg::convert_input_enabled; });
  1220. m.def("_set_convert_inputs", [](bool flag) -> bool {
  1221. bool ret = DTypePromoteCfg::convert_input_enabled;
  1222. DTypePromoteCfg::convert_input_enabled = flag;
  1223. return ret;
  1224. });
  1225. m.def("_get_amp_dtype_autocast",
  1226. []() -> bool { return DTypePromoteCfg::amp_dtype_autocast_enabled; });
  1227. m.def("_set_amp_dtype_autocast", [](bool flag) -> bool {
  1228. bool ret = DTypePromoteCfg::amp_dtype_autocast_enabled;
  1229. DTypePromoteCfg::amp_dtype_autocast_enabled = flag;
  1230. return ret;
  1231. });
  1232. static auto get_amp_prec_dtype = [](bool is_high) -> std::string {
  1233. DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
  1234. : DTypePromoteCfg::amp_low_prec_dtype;
  1235. mgb_assert(target.category() == DTypeCategory::FLOAT);
  1236. std::string ret = target.name();
  1237. transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
  1238. return ret;
  1239. };
  1240. static auto set_amp_prec_dtype = [](bool is_high,
  1241. std::string dtype_name) -> std::string {
  1242. DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
  1243. : DTypePromoteCfg::amp_low_prec_dtype;
  1244. std::string ret = target.name();
  1245. if (dtype_name == "float32") {
  1246. target = dtype::Float32();
  1247. } else if (dtype_name == "float16") {
  1248. target = dtype::Float16();
  1249. } else if (dtype_name == "bfloat16") {
  1250. target = dtype::BFloat16();
  1251. } else {
  1252. mgb_assert(
  1253. false, "casted type of amp should be float, but you give %s\n",
  1254. dtype_name.c_str());
  1255. }
  1256. transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
  1257. return ret;
  1258. };
  1259. m.def("_get_amp_high_prec_dtype",
  1260. []() -> std::string { return get_amp_prec_dtype(true); });
  1261. m.def("_set_amp_high_prec_dtype", [](std::string dtype_name) -> std::string {
  1262. return set_amp_prec_dtype(true, dtype_name);
  1263. });
  1264. m.def("_get_amp_low_prec_dtype",
  1265. []() -> std::string { return get_amp_prec_dtype(false); });
  1266. m.def("_set_amp_low_prec_dtype", [](std::string dtype_name) -> std::string {
  1267. return set_amp_prec_dtype(false, dtype_name);
  1268. });
  1269. m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); });
  1270. // FormatTransformation
  1271. m.def("set_auto_format_convert",
  1272. [format_trans](bool enabled) { format_trans->set_auto_convert(enabled); });
  1273. m.def("get_auto_format_convert",
  1274. [format_trans]() { return format_trans->get_auto_convert(); });
  1275. py::register_exception<TraceError>(m, "TraceError");
  1276. }
  1277. #undef MGE_PY_INTERFACE
  1278. } // namespace mgb::imperative::python