
ops.cpp 30 kB

#include "./ops.h"
#include "./helper.h"
#include "./tensor.h"

#include "megbrain/common.h"
#include "megbrain/imperative.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/rng.h"
#include "megbrain/imperative/ops/utility.h"

#include <Python.h>
#include <unordered_map>

namespace py = pybind11;
using namespace mgb::imperative;

namespace {
auto normalize_enum(const std::string& in) {
    std::string ret;
    for (auto&& c : in) {
        ret += toupper(c);
    }
    return ret;
}
}  // anonymous namespace
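
// Translate any C++ exception that escapes the CPython glue below into a
// Python exception and bail out with the given sentinel return value.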
#define CATCH_ALL(RETVAL)                              \
    catch (py::error_already_set & e) {                \
        e.restore();                                   \
        return RETVAL;                                 \
    }                                                  \
    catch (py::builtin_exception & e) {                \
        e.set_error();                                 \
        return RETVAL;                                 \
    }                                                  \
    catch (std::exception & e) {                       \
        PyErr_SetString(PyExc_RuntimeError, e.what()); \
        return RETVAL;                                 \
    }
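
// The PyOp/PyOpDefBegin/PyOpDefEnd macros generate, for each OpDef subclass
// `name`, a CPython wrapper struct Py<name> that holds the shared_ptr<OpDef>
// (via PyOpDef) together with its PyTypeObject; py_get_generic/py_set_generic
// (defined further below) build tp_getset accessors that forward directly to
// the wrapped C++ attribute.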
namespace {
#define PyOp(name)     Py##name
#define PyOpType(name) PyOp(name)::py_type

#define PyOpDefBegin(name)                               \
    struct PyOp(name) : PyOpDef {                        \
        using Ty = name;                                 \
        Ty& inst() { return op->cast_final_safe<Ty>(); } \
        static PyTypeObject py_type;

#define PyOpDefEnd(name) \
    }                    \
    ;                    \
    PyTypeObject PyOpType(name);
#define RETURN_RICHCOMPARE(val1, val2, op)                        \
    do {                                                          \
        switch (op) {                                             \
            case Py_EQ:                                           \
                if ((val1) == (val2))                             \
                    Py_RETURN_TRUE;                               \
                Py_RETURN_FALSE;                                  \
            case Py_NE:                                           \
                if ((val1) != (val2))                             \
                    Py_RETURN_TRUE;                               \
                Py_RETURN_FALSE;                                  \
            case Py_LT:                                           \
                if ((val1) < (val2))                              \
                    Py_RETURN_TRUE;                               \
                Py_RETURN_FALSE;                                  \
            case Py_GT:                                           \
                if ((val1) > (val2))                              \
                    Py_RETURN_TRUE;                               \
                Py_RETURN_FALSE;                                  \
            case Py_LE:                                           \
                if ((val1) <= (val2))                             \
                    Py_RETURN_TRUE;                               \
                Py_RETURN_FALSE;                                  \
            case Py_GE:                                           \
                if ((val1) >= (val2))                             \
                    Py_RETURN_TRUE;                               \
                Py_RETURN_FALSE;                                  \
            default:                                              \
                Py_FatalError("Unreachable C code path reached"); \
        }                                                         \
    } while (0)
template <typename T>
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
    PyObject* obj = type->tp_alloc(type, 0);
    T* self = reinterpret_cast<T*>(obj);
    if (self != NULL) {
        self->op = T::Ty::make();
    }
    return obj;
}

template <typename T, typename SFINAE = void>
struct serialization {
    static T load(py::object obj) { return py::cast<T>(obj); }
    template <
            typename U, typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>>
    static py::object dump(U&& t) {
        return py::cast(std::forward<U>(t));
    }
};
template <typename T>
void py_dealloc_generic(PyObject* obj) {
    reinterpret_cast<T*>(obj)->op.reset();
    Py_TYPE(obj)->tp_free(obj);
}

template <typename T, typename U, U T::Ty::*attr>
PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
    auto& op = reinterpret_cast<T*>(obj)->inst();
    return py::cast(op.*attr).release().ptr();
}
#define py_get_generic(name, attr) \
    py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

template <typename T, typename U, U T::Ty::*attr>
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    auto& op = reinterpret_cast<T*>(obj)->inst();
    try {
        // TODO: remove this guard which is used for pybind11 implicit conversion
        py::detail::loader_life_support guard{};
        op.*attr = py::cast<U>(py::handle(value));
    }
    CATCH_ALL(-1)
    return 0;
}
#define py_set_generic(name, attr) \
    py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
struct PyOpDef {
    PyObject_HEAD std::shared_ptr<OpDef> op;
    static PyTypeObject py_type;
    static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
    static PyGetSetDef py_getsetters[];
    static Py_hash_t tp_hash(PyObject* obj);
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op);
    static PyObject* py_repr(PyObject* self) {
        return py::cast(reinterpret_cast<PyOpDef*>(self)->op->make_name())
                .release()
                .ptr();
    }
};
PyTypeObject PyOpType(OpDef);
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;

PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
    return py::cast(reinterpret_cast<PyOp(OpDef)*>(obj)->op->scope()).release().ptr();
}

int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    try {
        reinterpret_cast<PyOp(OpDef)*>(obj)->op->set_scope(
                py::cast<std::string>(py::handle(value)));
    }
    CATCH_ALL(-1)
    return 0;
}

PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
        {const_cast<char*>("scope"), py_get_scope, py_set_scope,
         const_cast<char*>("scope"), NULL},
        {NULL}};

Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) {
    return static_cast<Py_hash_t>(reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash());
}

PyObject* PyOp(OpDef)::tp_richcompare(PyObject* self, PyObject* other, int op) {
    bool same = reinterpret_cast<PyOp(OpDef)*>(self)->op->is_same(
            *reinterpret_cast<PyOp(OpDef)*>(other)->op);
    if (op == Py_EQ || op == Py_NE) {
        RETURN_RICHCOMPARE(same, true, op);
    }
    Py_RETURN_NOTIMPLEMENTED;
}
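
// EnumWrapper and BitCombinedEnumWrapper expose the generated C++ enum
// parameters to Python: plain enums map each value to a cached singleton
// instance, while bit-combined enums additionally support `|`/`&` and
// construction from strings, tuples of strings, or integers.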
template <typename T>
struct EnumTrait;

#define PyEnumHead                                          \
    static_assert(std::is_enum_v<T>);                       \
    PyObject_HEAD T value;                                  \
    constexpr static const char* name = EnumTrait<T>::name; \
    static PyTypeObject* type;                              \
    static const char* members[];                           \
    static std::unordered_map<std::string, T> mem2value;    \
    static PyObject* pyobj_insts[];

template <typename T>
struct EnumWrapper {
    PyEnumHead std::string to_string() const {
        return members[static_cast<size_t>(value)];
    }
    static PyObject* py_repr(PyObject* self) {
        return py::cast(
                       std::string(name) + "." +
                       reinterpret_cast<EnumWrapper*>(self)->to_string())
                .release()
                .ptr();
    }
    static PyObject* py_dump(PyObject* self) {
        return py::cast(reinterpret_cast<EnumWrapper*>(self)->to_string())
                .release()
                .ptr();
    }
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
        if (op == Py_EQ || op == Py_NE) {
            T lhs, rhs;
            if (load(other, rhs) && load(self, lhs)) {
                RETURN_RICHCOMPARE(lhs, rhs, op);
            } else {
                RETURN_RICHCOMPARE(0, 1, op);
            }
        }
        Py_RETURN_NOTIMPLEMENTED;
    }
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
        if (PyObject_TypeCheck(obj, type)) {
            value = reinterpret_cast<EnumWrapper*>(obj)->value;
            return true;
        }
        if (py::isinstance<py::str>(src)) {
            auto&& iter = mem2value.find(normalize_enum(py::cast<std::string>(src)));
            if (iter != mem2value.end()) {
                value = iter->second;
                return true;
            } else {
                return false;
            }
        }
        return false;
    }
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        PyObject* obj = pyobj_insts[v];
        Py_INCREF(obj);
        return obj;
    }
};
template <typename T>
struct BitCombinedEnumWrapper {
    PyEnumHead std::string to_string() const {
        uint32_t value_int = static_cast<uint32_t>(value);
        if (value_int == 0) {
            return "None";
        } else {
            std::string ret;
            bool first = true;
            for (uint32_t i = 0; i < 32; i++) {
                if (value_int >> i & 1) {
                    if (!first) {
                        ret += " + ";
                    } else {
                        first = false;
                    }
                    ret += (std::string(name) + "." + members[i]);
                }
            }
            return ret;
        }
    }
    static PyObject* py_new_combined_enum(
            PyTypeObject* type, PyObject* args, PyObject*) {
        if (!PyTuple_Size(args)) {
            PyObject* obj = type->tp_alloc(type, 0);
            reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
            return obj;
        } else {
            PyObject* input;
            if (!PyArg_ParseTuple(args, "|O", &input)) {
                return nullptr;
            }
            T value;
            if (load(input, value)) {
                return cast(value);
            } else {
                PyErr_SetString(
                        PyExc_RuntimeError,
                        mgb::ssprintf(
                                "Cannot convert type %s to type %s\n",
                                input->ob_type->tp_name, name)
                                .c_str());
                return nullptr;
            }
        }
    }
    static PyObject* py_repr(PyObject* self) {
        return py::cast(reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string())
                .release()
                .ptr();
    }
    static PyObject* py_dump(PyObject* self) {
        std::vector<std::string> result;
        auto value = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value;
        uint32_t value_int = static_cast<uint32_t>(value);
        for (uint32_t i = 0; i < 32; i++) {
            if (value_int >> i & 1) {
                result.push_back(members[i]);
            }
        }
        return py::tuple(py::cast(result)).release().ptr();
    }
    static PyObject* py_or(PyObject* self, PyObject* other) {
        if (!(self->ob_type == other->ob_type)) {
            return PyErr_Format(
                    PyExc_RuntimeError,
                    "Operand in or operator must be the same type.");
        }
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
        return cast(lhs | rhs);
    }
    static PyObject* py_and(PyObject* self, PyObject* other) {
        if (!(self->ob_type == other->ob_type)) {
            return PyErr_Format(
                    PyExc_RuntimeError,
                    "Operand in and operator must be the same type.");
        }
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
        return cast(lhs & rhs);
    }
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
        if (op == Py_EQ || op == Py_NE) {
            T lhs, rhs;
            if (load(other, rhs) && load(self, lhs)) {
                RETURN_RICHCOMPARE(lhs, rhs, op);
            } else {
                RETURN_RICHCOMPARE(0, 1, op);
            }
        }
        Py_RETURN_NOTIMPLEMENTED;
    }
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
        if (PyObject_TypeCheck(obj, type)) {
            value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value;
            return true;
        }
        if (py::isinstance<py::str>(src)) {
            auto&& iter = mem2value.find(normalize_enum(py::cast<std::string>(src)));
            if (iter != mem2value.end()) {
                value = iter->second;
                return true;
            } else {
                return false;
            }
        }
        if (py::isinstance<py::tuple>(src)) {
            auto params = py::cast<std::vector<std::string>>(src);
            bool first = true;
            for (auto s : params) {
                auto&& iter = mem2value.find(normalize_enum(s));
                if (iter != mem2value.end()) {
                    if (first) {
                        value = iter->second;
                        first = false;
                    } else {
                        value |= iter->second;
                    }
                } else {
                    return false;
                }
            }
            return true;
        }
        if (py::isinstance<py::int_>(obj)) {
            auto v = py::cast<std::underlying_type_t<T>>(src);
            if (v > EnumTrait<T>::max) {
                return false;
            }
            value = static_cast<T>(v);
            return true;
        }
        return false;
    }
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        if ((!v) || (v & (v - 1))) {
            PyObject* obj = type->tp_alloc(type, 0);
            reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
            return obj;
        } else {
            PyObject* obj = pyobj_insts[__builtin_ctz(v)];
            Py_INCREF(obj);
            return obj;
        }
    }
};
template <typename T>
struct serialization<T, std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> {
    static T load(py::object obj) {
        auto caster = pybind11::detail::type_caster<T>();
        if (caster.load(obj, true)) {
            return caster;
        } else {
            PyErr_SetString(PyExc_RuntimeError, "load failed\n");
            return caster;
        }
    }
    static py::object dump(T t) { return py::cast(t).attr("dump")(); }
};
void _init_py_op_def(py::module m) {
    using py_op = PyOp(OpDef);
    auto& py_type = PyOpType(OpDef);
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    py_type.tp_name = "megengine.core._imperative_rt.OpDef";
    py_type.tp_basicsize = sizeof(PyOp(OpDef));
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_doc = "OpDef";
    py_type.tp_base = &PyBaseObject_Type;
    py_type.tp_hash = PyOp(OpDef)::tp_hash;
    py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare;
    py_type.tp_getset = py_op::py_getsetters;
    py_type.tp_repr = py_op::py_repr;
    py_type.tp_dealloc = py_dealloc_generic<PyOp(OpDef)>;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
}
/*********** begin of hand-written opdefs **************/
struct PyOpBase : PyOpDef {
    static PyTypeObject py_type;

    static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) {
        auto* obj = type->tp_alloc(type, 0);
        if (obj) {
            auto* self = reinterpret_cast<PyOpBase*>(obj);
            new (&self->op) decltype(self->op);
        }
        return obj;
    }
};
PyTypeObject PyOpBase::py_type;

void _init_py_op_base(py::module m) {
    using py_op = PyOpBase;
    auto& py_type = PyOpBase::py_type;
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    py_type.tp_name = "megengine.core._imperative_rt.ops.PyOpBase";
    py_type.tp_basicsize = sizeof(py_op);
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_doc = "PyOpBase";
    py_type.tp_base = &PyOpType(OpDef);
    py_type.tp_dealloc = py_dealloc_generic<py_op>;
    py_type.tp_new = py_op::tp_new;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&py_type));
}
/*********** end of hand-written opdefs **************/

// auto generated opdefs
#include "opdef.cpy.inl"

#undef CATCH_ALL
}  // anonymous namespace
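
// pybind11 glue: these type_caster specializations let OpDef and the generated
// enum parameter types pass transparently between C++ and the CPython wrapper
// types defined in the anonymous namespace above.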
namespace PYBIND11_NAMESPACE {
namespace detail {
bool type_caster<OpDef>::load(handle src, bool convert) {
    PyObject* obj = src.ptr();
    if (!PyObject_TypeCheck(obj, &PyOpType(OpDef))) {
        return false;
    }
    value = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
    if (!value) {
        // opdef only defined in Python
        value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src));
    }
    return true;
}

handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
    if (auto* pyop = op.try_cast_final<GenericPyOp>()) {
        return object(pyop->obj).release();
    }
    PyTypeObject* pytype;
    auto& c2p = PyOp(OpDef)::ctype2pytype;
    auto&& iter = c2p.find(op.dyn_typeinfo());
    if (iter != c2p.end()) {  // FIXME: should always meet this condition
        pytype = iter->second;
    } else {  // unregistered op type, just make it an opaque op type
        // currently, only OprAttr goes into this branch
        pytype = &PyOpType(OpDef);
    }
    PyObject* obj = pytype->tp_alloc(pytype, 0);
    mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef)));
    reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this();
    return py::handle(obj);
}

#define ENUM_CASTER_IMPL(T)                                                    \
    bool type_caster<T>::load(handle src, bool) {                              \
        return EnumWrapper<T>::load(src, value);                               \
    }                                                                          \
    handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
        return EnumWrapper<T>::cast(value);                                    \
    }
FOR_EACH_ENUM_PARAM(ENUM_CASTER_IMPL)

#define BIT_COMBINED_ENUM_CASTER_IMPL(T)                                       \
    bool type_caster<T>::load(handle src, bool) {                              \
        return BitCombinedEnumWrapper<T>::load(src, value);                    \
    }                                                                          \
    handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
        return BitCombinedEnumWrapper<T>::cast(value);                         \
    }
FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)
}  // namespace detail
}  // namespace PYBIND11_NAMESPACE
void init_ops(py::module m) {
    _init_py_op_def(m);
    _init_py_op_base(m);
    INIT_ALL_OP(m)

    m.def("new_rng_handle", &rng::new_handle);
    m.def(
            "delete_rng_handle",
            [](size_t handle) {
                if (mgb::imperative::python::interpreter_for_py->check_available()) {
                    mgb::imperative::python::interpreter_for_py->sync();
                }
                mgb::CompNode::sync_all();
                mgb::CompNode::foreach ([](mgb::CompNode cn) {
                    auto err = cn.check_async_error();
                    mgb_assert(!err, "%s", err->what());
                });
                py_task_q.wait_all_task_finish();
                rng::delete_handle(handle);
            },
            py::call_guard<py::gil_scoped_release>());
    m.def("set_global_rng_seed", [](uint64_t seed) -> void {
        mgb_assert(
                python::interpreter_for_py->check_available(),
                "set global random seed failed since imperative interpreter has been "
                "destroyed");
        python::interpreter_for_py->sync();
        mgb::CompNode::sync_all();
        rng::set_global_rng_seed(seed);
    });
    m.def("get_global_rng_seed", &rng::get_global_rng_seed);
    m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode);
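
    // PySubgraphBuilder accumulates inputs, constants and op applications into
    // a Subgraph; build() seals it behind a UniqueKey and wraps it in a
    // SubgraphOp, which "compile"/"jit_fuse" further wrap in a CompiledOp. The
    // `self.key == nullptr` assertions below reject mutation after build().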
    struct PySubgraphBuilder {
        explicit PySubgraphBuilder(std::string name) : name{name} {}
        std::string name;
        Subgraph graph;
        mgb::SmallVector<bool> output_grad_mask;
        Subgraph::var_t next_var = 1;
        std::shared_ptr<mgb::Hashable> key = nullptr;

        std::shared_ptr<OpDef> build() {
            if (key == nullptr) {
                key = std::make_shared<UniqueKey>();
            }
            return SubgraphOp::make(
                    name, std::make_shared<Subgraph>(graph), output_grad_mask, key);
        }
    };

    py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
            .def(py::init<std::string>())
            .def(py::init<PySubgraphBuilder>())
            .def("input",
                 [](PySubgraphBuilder& self) {
                     mgb_assert(self.key == nullptr);
                     auto var = self.next_var++;
                     self.graph.inputs.push_back(var);
                     return var;
                 })
            .def("apply",
                 [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op,
                    Subgraph::vars_t inputs, size_t nr_outputs) {
                     mgb_assert(self.key == nullptr);
                     Subgraph::vars_t outputs;
                     for (size_t i = 0; i < nr_outputs; ++i) {
                         outputs.push_back(self.next_var++);
                     }
                     self.graph.exprs.push_back({op, inputs, outputs});
                     return outputs;
                 })
            .def("apply_const",
                 [](PySubgraphBuilder& self, py::object value, mgb::DType dtype,
                    mgb::CompNode cn) {
                     mgb_assert(self.key == nullptr);
                     auto var = self.next_var++;
                     mgb::HostTensorND hvalue(cn);
                     npy::np2tensor(
                             value.cast<py::array>().ptr(),
                             npy::Meth::copy_into(&hvalue), dtype);
                     self.graph.constants.push_back({var, Tensor::make(hvalue)});
                     return var;
                 })
            .def("outputs",
                 [](PySubgraphBuilder& self, Subgraph::vars_t outputs) {
                     mgb_assert(self.key == nullptr);
                     self.graph.outputs = outputs;
                     self.output_grad_mask.resize(outputs.size(), true);
                 })
            .def("outputs_has_grad",
                 [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad) {
                     mgb_assert(self.key == nullptr);
                     mgb_assert(
                             self.graph.outputs.size() == self.output_grad_mask.size());
                     self.output_grad_mask = outputs_has_grad;
                 })
            .def("get",
                 [](PySubgraphBuilder& self) {
                     return (std::shared_ptr<OpDef>)self.build();
                 })
            .def("compile",
                 [](PySubgraphBuilder& self, int gopt_level) {
                     return (std::shared_ptr<OpDef>)CompiledOp::make(
                             self.build(), gopt_level);
                 })
            .def("jit_fuse", [](PySubgraphBuilder& self) {
                return (std::shared_ptr<OpDef>)CompiledOp::make(
                        JITFusionOp::make(self.build()));
            });

    m.def("set_jit_enabled", &JITFusionOp::set_enabled);

    bool jit_supported = false;
#if MGB_JIT
    jit_supported = true;
#endif
    m.attr("jit_supported") = jit_supported;

    auto custom = submodule(m, "_custom");
    init_custom(custom);
}
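
// Helpers for make_custom_op: parse Python keyword arguments into a
// CustomOpDef's parameter set, with one switch case per supported dynamic
// parameter type (scalars and strings, plus their list counterparts).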
#define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type)    \
    case custom::ParamDynType::dyn_type: {                      \
        param_val = py::handle(kv.second).cast<static_type>();  \
        break;                                                  \
    }

#define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type)                            \
    case custom::ParamDynType::dyn_type: {                                          \
        auto pyvals = py::handle(kv.second).cast<py::list>();                       \
        static_type vals;                                                           \
        using basic_type = custom::get_vector_template_arg_type<static_type>::type; \
        for (auto& pyval : pyvals) {                                                \
            vals.push_back(py::handle(pyval).cast<basic_type>());                   \
        }                                                                           \
        param_val = vals;                                                           \
        break;                                                                      \
    }
PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) {
#if MGB_CUSTOM_OP
    auto op_name = py::handle(args[0]).cast<std::string>();
    auto kwargs = py::handle(args[1]).cast<py::dict>();

    std::shared_ptr<OpDef> opdef = CustomOpDefFactory::inst()->create_opdef(op_name);
    auto& custom_opdef = static_cast<mgb::imperative::CustomOpDef&>(*opdef);
    auto& param = custom_opdef.param();

    for (auto&& kv : kwargs) {
        std::string param_name = py::handle(kv.first).cast<std::string>();
        std::string type_name = py::handle(kv.second).ptr()->ob_type->tp_name;

        if (!param.exist(param_name)) {
            mgb_log_warn(
                    "op %s has no param named %s, ignoring this param parsed from "
                    "python",
                    op_name.c_str(), param_name.c_str());
            continue;
        }
        auto& param_val = param[param_name];
        switch (param_val.type()) {
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
            CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
            CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
            CUSTOM_FOR_BOOL_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
            CUSTOM_FOR_STRING_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
            default: {
                mgb_assert(
                        false, "param dtype of %s:%s is invalid", op_name.c_str(),
                        param_name.c_str());
            }
        }
    }
    PyTypeObject* pytype;
    pytype = &PyOpType(OpDef);
    PyObject* obj = pytype->tp_alloc(pytype, 0);
    reinterpret_cast<PyOp(OpDef)*>(obj)->op = opdef;
    return obj;
#else
    mgb_assert(
            false,
            "Custom Op is disabled now, please build MegEngine with Custom Op enabled");
    return nullptr;
#endif
}
#undef CUSTOM_CASE_TO_PARSE_LIST
#undef CUSTOM_CASE_TO_PARSE_NON_LIST
py::list install_custom(const std::string& name, const std::string& path) {
#if MGB_CUSTOM_OP
    py::list ret;
    const auto& ops_in_lib = custom::LibManager::inst()->install(name, path);
    for (const auto& op : ops_in_lib) {
        ret.append(op);
    }
    return ret;
#else
    mgb_assert(
            false,
            "Custom Op is disabled now, please build MegEngine with Custom Op enabled");
    py::list ret;
    return ret;
#endif
}

bool uninstall_custom(const std::string& name) {
#if MGB_CUSTOM_OP
    return custom::LibManager::inst()->uninstall(name);
#else
    mgb_assert(
            false,
            "Custom Op is disabled now, please build MegEngine with Custom Op enabled");
    return false;
#endif
}

py::list get_custom_op_list(void) {
#if MGB_CUSTOM_OP
    std::vector<std::string> all_ops = CustomOpDefFactory::inst()->op_list();
    py::list ret;
    for (auto& op : all_ops) {
        ret.append(op);
    }
    return ret;
#else
    mgb_assert(
            false,
            "Custom Op is disabled now, please build MegEngine with Custom Op enabled");
    py::list ret;
    return ret;
#endif
}
#ifndef METH_FASTCALL
PyObject* py35_make_custom_op(PyObject* self, PyObject* args) {
    auto* arr = &PyTuple_GET_ITEM(args, 0);
    auto size = PyTuple_GET_SIZE(args);
    return make_custom_op(self, arr, size);
};
#endif

void init_custom(pybind11::module m) {
    m.def("_install", &install_custom);
    m.def("_uninstall", &uninstall_custom);
    m.def("_get_custom_op_list", &get_custom_op_list);
    m.def("get_custom_op_abi_tag", [](void) -> int {
        int ret = 0;
#ifdef _GLIBCXX_USE_CXX11_ABI
        ret = _GLIBCXX_USE_CXX11_ABI;
#endif
        return ret;
    });

    static PyMethodDef method_def = {
#ifdef METH_FASTCALL
            "_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, ""
#else
            "_make_custom_op", (PyCFunction)py35_make_custom_op, METH_VARARGS, ""
#endif
    };
    auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr);
    pybind11::setattr(m, method_def.ml_name, func);
}