You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops.cpp 30 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805
  1. /**
  2. * \file imperative/python/src/ops.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./ops.h"
  12. #include "./helper.h"
  13. #include "./tensor.h"
  14. #include "megbrain/common.h"
  15. #include "megbrain/imperative.h"
  16. #include "megbrain/imperative/graph_builder.h"
  17. #include "megbrain/imperative/ops/autogen.h"
  18. #include "megbrain/imperative/ops/backward_graph.h"
  19. #include "megbrain/imperative/ops/opr_attr.h"
  20. #include "megbrain/imperative/ops/rng.h"
  21. #include "megbrain/imperative/ops/utility.h"
  22. #include <Python.h>
  23. #include <unordered_map>
  24. namespace py = pybind11;
  25. using namespace mgb::imperative;
  26. namespace {
  27. auto normalize_enum(const std::string& in) {
  28. std::string ret;
  29. for (auto&& c : in) {
  30. ret += toupper(c);
  31. }
  32. return ret;
  33. }
  34. } // anonymous namespace
// Exception-translation tail for a preceding `try { ... }` block, used as
// `try { ... } CATCH_ALL(retval)` in functions invoked directly by CPython,
// where a C++ exception must not propagate. Each handler leaves a pending
// Python exception and returns RETVAL from the enclosing function:
//  - py::error_already_set: restore the captured Python error as-is;
//  - py::builtin_exception: pybind11 exception mapped to its builtin type;
//  - std::exception: any other C++ error becomes RuntimeError(e.what()).
// Exceptions not derived from std::exception are not caught here.
#define CATCH_ALL(RETVAL) \
    catch (py::error_already_set & e) { \
        e.restore(); \
        return RETVAL; \
    } \
    catch (py::builtin_exception & e) { \
        e.set_error(); \
        return RETVAL; \
    } \
    catch (std::exception & e) { \
        PyErr_SetString(PyExc_RuntimeError, e.what()); \
        return RETVAL; \
    }
  48. namespace {
// Naming helpers for the Python wrapper of a C++ OpDef type `name`:
//   PyOp(name)     -> the wrapper struct `Py<name>`
//   PyOpType(name) -> that wrapper's static PyTypeObject
// PyOpDefBegin(name) opens `struct Py<name> : PyOpDef`, aliasing the C++ op
// type as `Ty`, adding `inst()` (the held op viewed as `Ty&` via
// cast_final_safe) and declaring the static `py_type`. Everything written
// between PyOpDefBegin and PyOpDefEnd becomes members of that struct;
// PyOpDefEnd(name) closes the struct and defines its `py_type`.
#define PyOp(name) Py##name
#define PyOpType(name) PyOp(name)::py_type
#define PyOpDefBegin(name) \
    struct PyOp(name) : PyOpDef { \
        using Ty = name; \
        Ty& inst() { return op->cast_final_safe<Ty>(); } \
        static PyTypeObject py_type;
#define PyOpDefEnd(name) \
    } \
    ; \
    PyTypeObject PyOpType(name);
// Implements the body of a tp_richcompare slot: compares val1 with val2
// using the operator selected by the CPython opcode `op` (Py_EQ, Py_NE,
// Py_LT, Py_GT, Py_LE, Py_GE) and returns Py_True / Py_False from the
// enclosing function. Any other opcode aborts via Py_FatalError. The operands
// are parenthesized, so arbitrary expressions may be passed.
#define RETURN_RICHCOMPARE(val1, val2, op) \
    do { \
        switch (op) { \
            case Py_EQ: \
                if ((val1) == (val2)) \
                    Py_RETURN_TRUE; \
                Py_RETURN_FALSE; \
            case Py_NE: \
                if ((val1) != (val2)) \
                    Py_RETURN_TRUE; \
                Py_RETURN_FALSE; \
            case Py_LT: \
                if ((val1) < (val2)) \
                    Py_RETURN_TRUE; \
                Py_RETURN_FALSE; \
            case Py_GT: \
                if ((val1) > (val2)) \
                    Py_RETURN_TRUE; \
                Py_RETURN_FALSE; \
            case Py_LE: \
                if ((val1) <= (val2)) \
                    Py_RETURN_TRUE; \
                Py_RETURN_FALSE; \
            case Py_GE: \
                if ((val1) >= (val2)) \
                    Py_RETURN_TRUE; \
                Py_RETURN_FALSE; \
            default: \
                Py_FatalError("Unreachable C code path reached"); \
        } \
    } while (0)
  91. template <typename T>
  92. PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
  93. PyObject* obj = type->tp_alloc(type, 0);
  94. T* self = reinterpret_cast<T*>(obj);
  95. if (self != NULL) {
  96. self->op = T::Ty::make();
  97. }
  98. return obj;
  99. }
  100. template <typename T, typename SNIFAE = void>
  101. struct serialization {
  102. static T load(py::object obj) { return py::cast<T>(obj); }
  103. template <
  104. typename U, typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>>
  105. static py::object dump(U&& t) {
  106. return py::cast(std::forward<U>(t));
  107. }
  108. };
  109. template <typename T>
  110. void py_dealloc_generic(PyObject* obj) {
  111. reinterpret_cast<T*>(obj)->op.reset();
  112. Py_TYPE(obj)->tp_free(obj);
  113. }
  114. template <typename T, typename U, U T::Ty::*attr>
  115. PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
  116. auto& op = reinterpret_cast<T*>(obj)->inst();
  117. return py::cast(op.*attr).release().ptr();
  118. }
  119. #define py_get_generic(name, attr) \
  120. py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
  121. template <typename T, typename U, U T::Ty::*attr>
  122. int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
  123. if (value == NULL) {
  124. PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
  125. return -1;
  126. }
  127. auto& op = reinterpret_cast<T*>(obj)->inst();
  128. try {
  129. // TODO: remove this guard which is used for pybind11 implicit conversion
  130. py::detail::loader_life_support guard{};
  131. op.*attr = py::cast<U>(py::handle(value));
  132. }
  133. CATCH_ALL(-1)
  134. return 0;
  135. }
  136. #define py_set_generic(name, attr) \
  137. py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
  138. struct PyOpDef {
  139. PyObject_HEAD std::shared_ptr<OpDef> op;
  140. static PyTypeObject py_type;
  141. static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
  142. static PyGetSetDef py_getsetters[];
  143. static Py_hash_t tp_hash(PyObject* obj);
  144. static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op);
  145. static PyObject* py_repr(PyObject* self) {
  146. return py::cast(reinterpret_cast<PyOpDef*>(self)->op->make_name())
  147. .release()
  148. .ptr();
  149. }
  150. };
  151. PyTypeObject PyOpType(OpDef);
  152. std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;
  153. PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
  154. return py::cast(reinterpret_cast<PyOp(OpDef)*>(obj)->op->scope()).release().ptr();
  155. }
  156. int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
  157. if (value == NULL) {
  158. PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
  159. return -1;
  160. }
  161. try {
  162. reinterpret_cast<PyOp(OpDef)*>(obj)->op->set_scope(
  163. py::cast<std::string>(py::handle(value)));
  164. }
  165. CATCH_ALL(-1)
  166. return 0;
  167. }
  168. PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
  169. {const_cast<char*>("scope"), py_get_scope, py_set_scope,
  170. const_cast<char*>("scope"), NULL},
  171. {NULL}};
  172. Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) {
  173. return static_cast<Py_hash_t>(reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash());
  174. }
  175. PyObject* PyOp(OpDef)::tp_richcompare(PyObject* self, PyObject* other, int op) {
  176. bool same = reinterpret_cast<PyOp(OpDef)*>(self)->op->is_same(
  177. *reinterpret_cast<PyOp(OpDef)*>(other)->op);
  178. if (op == Py_EQ || op == Py_NE) {
  179. RETURN_RICHCOMPARE(same, true, op);
  180. }
  181. Py_RETURN_NOTIMPLEMENTED;
  182. }
// Per-enum traits. Only declared here; the specializations are not visible in
// this file (presumably they come from the generated opdef code). This file
// reads `EnumTrait<T>::name` and `EnumTrait<T>::max`.
template <typename T>
struct EnumTrait;
// Common head of the Python enum wrapper structs below; must be expanded
// first in the struct. It lays out PyObject_HEAD followed by the wrapped
// enum `value`, then declares the per-enum static tables used by the
// wrappers: the Python type, member names, the normalized-name -> value map,
// and the cached instances. Their definitions are not in this file
// (presumably generated).
#define PyEnumHead \
    static_assert(std::is_enum_v<T>); \
    PyObject_HEAD T value; \
    constexpr static const char* name = EnumTrait<T>::name; \
    static PyTypeObject* type; \
    static const char* members[]; \
    static std::unordered_map<std::string, T> mem2value; \
    static PyObject* pyobj_insts[];
  193. template <typename T>
  194. struct EnumWrapper {
  195. PyEnumHead std::string to_string() const {
  196. return members[static_cast<size_t>(value)];
  197. }
  198. static PyObject* py_repr(PyObject* self) {
  199. return py::cast(
  200. std::string(name) + "." +
  201. reinterpret_cast<EnumWrapper*>(self)->to_string())
  202. .release()
  203. .ptr();
  204. }
  205. static PyObject* py_dump(PyObject* self) {
  206. return py::cast(reinterpret_cast<EnumWrapper*>(self)->to_string())
  207. .release()
  208. .ptr();
  209. }
  210. static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
  211. if (op == Py_EQ || op == Py_NE) {
  212. T lhs, rhs;
  213. if (load(other, rhs) && load(self, lhs)) {
  214. RETURN_RICHCOMPARE(lhs, rhs, op);
  215. } else {
  216. RETURN_RICHCOMPARE(0, 1, op);
  217. }
  218. }
  219. Py_RETURN_NOTIMPLEMENTED;
  220. }
  221. static bool load(py::handle src, T& value) {
  222. PyObject* obj = src.ptr();
  223. if (PyObject_TypeCheck(obj, type)) {
  224. value = reinterpret_cast<EnumWrapper*>(obj)->value;
  225. return true;
  226. }
  227. if (py::isinstance<py::str>(src)) {
  228. auto&& iter = mem2value.find(normalize_enum(py::cast<std::string>(src)));
  229. if (iter != mem2value.end()) {
  230. value = iter->second;
  231. return true;
  232. } else {
  233. return false;
  234. }
  235. }
  236. return false;
  237. }
  238. static PyObject* cast(const T& value) {
  239. auto v = static_cast<std::underlying_type_t<T>>(value);
  240. mgb_assert(v <= EnumTrait<T>::max);
  241. PyObject* obj = pyobj_insts[v];
  242. Py_INCREF(obj);
  243. return obj;
  244. }
  245. };
  246. template <typename T>
  247. struct BitCombinedEnumWrapper {
  248. PyEnumHead std::string to_string() const {
  249. uint32_t value_int = static_cast<uint32_t>(value);
  250. if (value_int == 0) {
  251. return "None";
  252. } else {
  253. std::string ret;
  254. bool first = true;
  255. for (uint32_t i = 0; i < 32; i++) {
  256. if (value_int >> i & 1) {
  257. if (!first) {
  258. ret += " + ";
  259. } else {
  260. first = false;
  261. }
  262. ret += (std::string(name) + "." + members[i]);
  263. }
  264. }
  265. return ret;
  266. }
  267. }
  268. static PyObject* py_new_combined_enum(
  269. PyTypeObject* type, PyObject* args, PyObject*) {
  270. if (!PyTuple_Size(args)) {
  271. PyObject* obj = type->tp_alloc(type, 0);
  272. reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
  273. return obj;
  274. } else {
  275. PyObject* input;
  276. if (!PyArg_ParseTuple(args, "|O", &input)) {
  277. return nullptr;
  278. }
  279. T value;
  280. if (load(input, value)) {
  281. return cast(value);
  282. } else {
  283. PyErr_SetString(
  284. PyExc_RuntimeError,
  285. mgb::ssprintf(
  286. "Cannot convert type %s to type %s\n",
  287. input->ob_type->tp_name, name)
  288. .c_str());
  289. return nullptr;
  290. }
  291. }
  292. }
  293. static PyObject* py_repr(PyObject* self) {
  294. return py::cast(reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string())
  295. .release()
  296. .ptr();
  297. }
  298. static PyObject* py_dump(PyObject* self) {
  299. std::vector<std::string> result;
  300. auto value = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value;
  301. uint32_t value_int = static_cast<uint32_t>(value);
  302. for (uint32_t i = 0; i < 32; i++) {
  303. if (value_int >> i & 1) {
  304. result.push_back(members[i]);
  305. }
  306. }
  307. return py::tuple(py::cast(result)).release().ptr();
  308. }
  309. static PyObject* py_or(PyObject* self, PyObject* other) {
  310. if (!(self->ob_type == other->ob_type)) {
  311. return PyErr_Format(
  312. PyExc_RuntimeError,
  313. "Operand in or operator must be the same type.");
  314. }
  315. T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
  316. rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
  317. return cast(lhs | rhs);
  318. }
  319. static PyObject* py_and(PyObject* self, PyObject* other) {
  320. if (!(self->ob_type == other->ob_type)) {
  321. return PyErr_Format(
  322. PyExc_RuntimeError,
  323. "Operand in and operator must be the same type.");
  324. }
  325. T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
  326. rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
  327. return cast(lhs & rhs);
  328. }
  329. static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
  330. if (op == Py_EQ || op == Py_NE) {
  331. T lhs, rhs;
  332. if (load(other, rhs) && load(self, lhs)) {
  333. RETURN_RICHCOMPARE(lhs, rhs, op);
  334. } else {
  335. RETURN_RICHCOMPARE(0, 1, op);
  336. }
  337. }
  338. Py_RETURN_NOTIMPLEMENTED;
  339. }
  340. static bool load(py::handle src, T& value) {
  341. PyObject* obj = src.ptr();
  342. if (PyObject_TypeCheck(obj, type)) {
  343. value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value;
  344. return true;
  345. }
  346. if (py::isinstance<py::str>(src)) {
  347. auto&& iter = mem2value.find(normalize_enum(py::cast<std::string>(src)));
  348. if (iter != mem2value.end()) {
  349. value = iter->second;
  350. return true;
  351. } else {
  352. return false;
  353. }
  354. }
  355. if (py::isinstance<py::tuple>(src)) {
  356. auto params = py::cast<std::vector<std::string>>(src);
  357. bool first = true;
  358. for (auto s : params) {
  359. auto&& iter = mem2value.find(normalize_enum(s));
  360. if (iter != mem2value.end()) {
  361. if (first) {
  362. value = iter->second;
  363. first = false;
  364. } else {
  365. value |= iter->second;
  366. }
  367. } else {
  368. return false;
  369. }
  370. }
  371. return true;
  372. }
  373. if (py::isinstance<py::int_>(obj)) {
  374. auto v = py::cast<std::underlying_type_t<T>>(src);
  375. if (v > EnumTrait<T>::max) {
  376. return false;
  377. }
  378. value = static_cast<T>(v);
  379. return true;
  380. }
  381. return false;
  382. }
  383. static PyObject* cast(const T& value) {
  384. auto v = static_cast<std::underlying_type_t<T>>(value);
  385. mgb_assert(v <= EnumTrait<T>::max);
  386. if ((!v) || (v & (v - 1))) {
  387. PyObject* obj = type->tp_alloc(type, 0);
  388. reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
  389. return obj;
  390. } else {
  391. PyObject* obj = pyobj_insts[__builtin_ctz(v)];
  392. Py_INCREF(obj);
  393. return obj;
  394. }
  395. }
  396. };
  397. template <typename T>
  398. struct serialization<T, std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> {
  399. static T load(py::object obj) {
  400. auto caster = pybind11::detail::type_caster<T>();
  401. if (caster.load(obj, true)) {
  402. return caster;
  403. } else {
  404. PyErr_SetString(PyExc_RuntimeError, "load faild \n");
  405. return caster;
  406. }
  407. }
  408. static py::object dump(T t) { return py::cast(t).attr("dump")(); }
  409. };
  410. void _init_py_op_def(py::module m) {
  411. using py_op = PyOp(OpDef);
  412. auto& py_type = PyOpType(OpDef);
  413. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  414. py_type.tp_name = "megengine.core._imperative_rt.OpDef";
  415. py_type.tp_basicsize = sizeof(PyOp(OpDef));
  416. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  417. py_type.tp_doc = "OpDef";
  418. py_type.tp_base = &PyBaseObject_Type;
  419. py_type.tp_hash = PyOp(OpDef)::tp_hash;
  420. py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare;
  421. py_type.tp_getset = py_op::py_getsetters;
  422. py_type.tp_repr = py_op::py_repr;
  423. py_type.tp_dealloc = py_dealloc_generic<PyOp(OpDef)>;
  424. mgb_assert(PyType_Ready(&py_type) >= 0);
  425. m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
  426. }
  427. /*********** begin of hand-written opdefs **************/
  428. struct PyOpBase : PyOpDef {
  429. static PyTypeObject py_type;
  430. static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) {
  431. auto* obj = type->tp_alloc(type, 0);
  432. if (obj) {
  433. auto* self = reinterpret_cast<PyOpBase*>(obj);
  434. new (&self->op) decltype(self->op);
  435. }
  436. return obj;
  437. }
  438. };
  439. PyTypeObject PyOpBase::py_type;
  440. void _init_py_op_base(py::module m) {
  441. using py_op = PyOpBase;
  442. auto& py_type = PyOpBase::py_type;
  443. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  444. py_type.tp_name = "megengine.core._imperative_rt.ops.PyOpBase";
  445. py_type.tp_basicsize = sizeof(py_op);
  446. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  447. py_type.tp_doc = "PyOpBase";
  448. py_type.tp_base = &PyOpType(OpDef);
  449. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  450. py_type.tp_new = py_op::tp_new;
  451. mgb_assert(PyType_Ready(&py_type) >= 0);
  452. m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&py_type));
  453. }
  454. /*********** end of hand-written opdefs **************/
  455. // auto generated opdefs
  456. #include "opdef.cpy.inl"
  457. #undef CATCH_ALL
  458. } // anonymous namespace
  459. namespace PYBIND11_NAMESPACE {
  460. namespace detail {
  461. bool type_caster<OpDef>::load(handle src, bool convert) {
  462. PyObject* obj = src.ptr();
  463. if (!PyObject_TypeCheck(obj, &PyOpType(OpDef))) {
  464. return false;
  465. }
  466. value = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
  467. if (!value) {
  468. // opdef only defined in Python
  469. value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src));
  470. }
  471. return true;
  472. }
  473. handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
  474. if (auto* pyop = op.try_cast_final<GenericPyOp>()) {
  475. return object(pyop->obj).release();
  476. }
  477. PyTypeObject* pytype;
  478. auto& c2p = PyOp(OpDef)::ctype2pytype;
  479. auto&& iter = c2p.find(op.dyn_typeinfo());
  480. if (iter != c2p.end()) { // FIXME: should always meet this condition
  481. pytype = iter->second;
  482. } else { // which means unregistered op type, jsut make it as an opaque op type
  483. // currently, only OprAttr goes into this branch
  484. pytype = &PyOpType(OpDef);
  485. }
  486. PyObject* obj = pytype->tp_alloc(pytype, 0);
  487. mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef)));
  488. reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this();
  489. return py::handle(obj);
  490. }
  491. #define ENUM_CASTER_IMPL(T) \
  492. bool type_caster<T>::load(handle src, bool) { \
  493. return EnumWrapper<T>::load(src, value); \
  494. } \
  495. handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
  496. return EnumWrapper<T>::cast(value); \
  497. }
  498. FOR_EACH_ENUM_PARAM(ENUM_CASTER_IMPL)
  499. #define BIT_COMBINED_ENUM_CASTER_IMPL(T) \
  500. bool type_caster<T>::load(handle src, bool) { \
  501. return BitCombinedEnumWrapper<T>::load(src, value); \
  502. } \
  503. handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
  504. return BitCombinedEnumWrapper<T>::cast(value); \
  505. }
  506. FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)
  507. } // namespace detail
  508. } // namespace PYBIND11_NAMESPACE
  509. void init_ops(py::module m) {
  510. _init_py_op_def(m);
  511. _init_py_op_base(m);
  512. INIT_ALL_OP(m)
  513. m.def("new_rng_handle", &rng::new_handle);
  514. m.def(
  515. "delete_rng_handle",
  516. [](size_t handle) {
  517. mgb::CompNode::sync_all();
  518. py_task_q.wait_all_task_finish();
  519. rng::delete_handle(handle);
  520. },
  521. py::call_guard<py::gil_scoped_release>());
  522. m.def("set_global_rng_seed", [](uint64_t seed) -> void {
  523. mgb_assert(
  524. python::interpreter_for_py->check_available(),
  525. "set global random seed failed since imperative interpreter has been "
  526. "destroyed");
  527. python::interpreter_for_py->sync();
  528. mgb::CompNode::sync_all();
  529. rng::set_global_rng_seed(seed);
  530. });
  531. m.def("get_global_rng_seed", &rng::get_global_rng_seed);
  532. m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode);
  533. struct PySubgraphBuilder {
  534. explicit PySubgraphBuilder(std::string name) : name{name} {}
  535. std::string name;
  536. Subgraph graph;
  537. mgb::SmallVector<bool> output_grad_mask;
  538. Subgraph::var_t next_var = 1;
  539. std::shared_ptr<mgb::Hashable> key = nullptr;
  540. std::shared_ptr<OpDef> build() {
  541. if (key == nullptr) {
  542. key = std::make_shared<UniqueKey>();
  543. }
  544. return SubgraphOp::make(
  545. name, std::make_shared<Subgraph>(graph), output_grad_mask, key);
  546. }
  547. };
  548. py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
  549. .def(py::init<std::string>())
  550. .def(py::init<PySubgraphBuilder>())
  551. .def("input",
  552. [](PySubgraphBuilder& self) {
  553. mgb_assert(self.key == nullptr);
  554. auto var = self.next_var++;
  555. self.graph.inputs.push_back(var);
  556. return var;
  557. })
  558. .def("apply",
  559. [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op,
  560. Subgraph::vars_t inputs, size_t nr_outputs) {
  561. mgb_assert(self.key == nullptr);
  562. Subgraph::vars_t outputs;
  563. for (size_t i = 0; i < nr_outputs; ++i) {
  564. outputs.push_back(self.next_var++);
  565. }
  566. self.graph.exprs.push_back({op, inputs, outputs});
  567. return outputs;
  568. })
  569. .def("apply_const",
  570. [](PySubgraphBuilder& self, py::object value, mgb::DType dtype,
  571. mgb::CompNode cn) {
  572. mgb_assert(self.key == nullptr);
  573. auto var = self.next_var++;
  574. mgb::HostTensorND hvalue(cn);
  575. npy::np2tensor(
  576. value.cast<py::array>().ptr(),
  577. npy::Meth::copy_into(&hvalue), dtype);
  578. self.graph.constants.push_back({var, Tensor::make(hvalue)});
  579. return var;
  580. })
  581. .def("outputs",
  582. [](PySubgraphBuilder& self, Subgraph::vars_t outputs) {
  583. mgb_assert(self.key == nullptr);
  584. self.graph.outputs = outputs;
  585. self.output_grad_mask.resize(outputs.size(), true);
  586. })
  587. .def("outputs_has_grad",
  588. [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad) {
  589. mgb_assert(self.key == nullptr);
  590. mgb_assert(
  591. self.graph.outputs.size() == self.output_grad_mask.size());
  592. self.output_grad_mask = outputs_has_grad;
  593. })
  594. .def("get",
  595. [](PySubgraphBuilder& self) {
  596. return (std::shared_ptr<OpDef>)self.build();
  597. })
  598. .def("compile",
  599. [](PySubgraphBuilder& self, int gopt_level) {
  600. return (std::shared_ptr<OpDef>)CompiledOp::make(
  601. self.build(), gopt_level);
  602. })
  603. .def("jit_fuse", [](PySubgraphBuilder& self) {
  604. return (std::shared_ptr<OpDef>)CompiledOp::make(
  605. JITFusionOp::make(self.build()));
  606. });
  607. m.def("set_jit_enabled", &JITFusionOp::set_enabled);
  608. auto custom = submodule(m, "_custom");
  609. init_custom(custom);
  610. }
// switch-case generators used by make_custom_op. Both expand to
// `case custom::ParamDynType::dyn_type: {...}` and expect `kv` (a
// (name, value) item of the Python params dict) and `param_val` (the target
// custom param) to be in scope at the expansion site.
//
// Scalar / string param: convert the Python value directly to `static_type`.
#define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type) \
    case custom::ParamDynType::dyn_type: { \
        param_val = py::handle(kv.second).cast<static_type>(); \
        break; \
    }
// List param: convert the Python value to a py::list, then convert each
// element to the vector's element type and collect them into `static_type`.
#define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type) \
    case custom::ParamDynType::dyn_type: { \
        auto pyvals = py::handle(kv.second).cast<py::list>(); \
        static_type vals; \
        using basic_type = custom::get_vector_template_arg_type<static_type>::type; \
        for (auto& pyval : pyvals) { \
            vals.push_back(py::handle(pyval).cast<basic_type>()); \
        } \
        param_val = vals; \
        break; \
    }
  627. PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) {
  628. #if MGB_CUSTOM_OP
  629. auto op_name = py::handle(args[0]).cast<std::string>();
  630. auto kwargs = py::handle(args[1]).cast<py::dict>();
  631. std::shared_ptr<OpDef> opdef = CustomOpDefFactory::inst()->create_opdef(op_name);
  632. auto& custom_opdef = static_cast<mgb::imperative::CustomOpDef&>(*opdef);
  633. auto& param = custom_opdef.param();
  634. for (auto&& kv : kwargs) {
  635. std::string param_name = py::handle(kv.first).cast<std::string>();
  636. std::string type_name = py::handle(kv.second).ptr()->ob_type->tp_name;
  637. if (!param.exist(param_name)) {
  638. mgb_log_warn(
  639. "op %s have no param named %s, ignore this param parsed from "
  640. "python",
  641. op_name.c_str(), param_name.c_str());
  642. continue;
  643. }
  644. auto& param_val = param[param_name];
  645. switch (param_val.type()) {
  646. CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
  647. CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
  648. CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
  649. CUSTOM_FOR_BOOL_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
  650. CUSTOM_FOR_STRING_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
  651. default: {
  652. mgb_assert(
  653. false, "param dtype of %s:%s is invalid", op_name.c_str(),
  654. param_name.c_str());
  655. }
  656. }
  657. }
  658. PyTypeObject* pytype;
  659. pytype = &PyOpType(OpDef);
  660. PyObject* obj = pytype->tp_alloc(pytype, 0);
  661. reinterpret_cast<PyOp(OpDef)*>(obj)->op = opdef;
  662. return obj;
  663. #else
  664. mgb_assert(
  665. false,
  666. "Custom Op is disabled now, please build megengine with Custom Op open");
  667. return nullptr;
  668. #endif
  669. }
  670. #undef CUSTOM_CASE_TO_PARSE_LIST
  671. #undef CUSTOM_CASE_TO_PARSE_NON_LIST
  672. py::list install_custom(const std::string& name, const std::string& path) {
  673. #if MGB_CUSTOM_OP
  674. py::list ret;
  675. const auto& ops_in_lib = custom::LibManager::inst()->install(name, path);
  676. for (const auto& op : ops_in_lib) {
  677. ret.append(op);
  678. }
  679. return ret;
  680. #else
  681. mgb_assert(
  682. false,
  683. "Custom Op is disabled now, please build megengine with Custom Op open");
  684. py::list ret;
  685. return ret;
  686. #endif
  687. }
  688. bool uninstall_custom(const std::string& name) {
  689. #if MGB_CUSTOM_OP
  690. return custom::LibManager::inst()->uninstall(name);
  691. #else
  692. mgb_assert(
  693. false,
  694. "Custom Op is disabled now, please build megengine with Custom Op open");
  695. return false;
  696. #endif
  697. }
  698. py::list get_custom_op_list(void) {
  699. #if MGB_CUSTOM_OP
  700. std::vector<std::string> all_ops = CustomOpDefFactory::inst()->op_list();
  701. py::list ret;
  702. for (auto& op : all_ops) {
  703. ret.append(op);
  704. }
  705. return ret;
  706. #else
  707. mgb_assert(
  708. false,
  709. "Custom Op is disabled now, please build megengine with Custom Op open");
  710. py::list ret;
  711. return ret;
  712. #endif
  713. }
  714. #ifndef METH_FASTCALL
  715. PyObject* py35_make_custom_op(PyObject* self, PyObject* args) {
  716. auto* arr = &PyTuple_GET_ITEM(args, 0);
  717. auto size = PyTuple_GET_SIZE(args);
  718. return make_custom_op(self, arr, size);
  719. };
  720. #endif
  721. void init_custom(pybind11::module m) {
  722. m.def("_install", &install_custom);
  723. m.def("_uninstall", &uninstall_custom);
  724. m.def("_get_custom_op_list", &get_custom_op_list);
  725. m.def("get_custom_op_abi_tag", [](void) -> int {
  726. int ret = 0;
  727. #ifdef _GLIBCXX_USE_CXX11_ABI
  728. ret = _GLIBCXX_USE_CXX11_ABI;
  729. #endif
  730. return ret;
  731. });
  732. static PyMethodDef method_def = {
  733. #ifdef METH_FASTCALL
  734. "_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, ""
  735. #else
  736. "_make_custom_op", (PyCFunction)py35_make_custom_op, METH_VARARGS, ""
  737. #endif
  738. };
  739. auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr);
  740. pybind11::setattr(m, method_def.ml_name, func);
  741. }