You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ops.cpp 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. /**
  2. * \file imperative/python/src/ops.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./ops.h"
  12. #include "./helper.h"
  13. #include "./tensor.h"
  14. #include "megbrain/common.h"
  15. #include "megbrain/imperative.h"
  16. #include "megbrain/imperative/ops/backward_graph.h"
  17. #include "megbrain/imperative/ops/opr_attr.h"
  18. #include "megbrain/imperative/ops/utility.h"
  19. #include "megbrain/imperative/ops/autogen.h"
  20. #include "megbrain/imperative/ops/rng.h"
  21. #include <Python.h>
  22. #include <unordered_map>
  23. namespace py = pybind11;
  24. using namespace mgb::imperative;
  25. namespace {
  /*!
   * \brief Return an upper-cased copy of \p in.
   *
   * The enum wrappers in this file use the result as the lookup key into their
   * name -> value tables, which makes string-to-enum conversion
   * case-insensitive.
   *
   * \param in text to convert; may be empty
   * \return a new string of the same length with every character upper-cased
   */
  auto normalize_enum(const std::string& in) {
      std::string ret;
      ret.reserve(in.size());
      for (char c : in) {
          // toupper() is only defined for values representable as unsigned
          // char (or EOF); a negative plain char must be converted first.
          ret += static_cast<char>(toupper(static_cast<unsigned char>(c)));
      }
      return ret;
  }
  33. } // anonymous namespace
  // Handler clauses to append to a `try { ... }` block inside a CPython
  // callback: every C++ exception is converted into a pending Python error and
  // the enclosing function returns RETVAL (the callback's error value).
  // The pybind11 exception types are handled before the generic
  // std::exception clause so that their own error state is preserved.
  // NOTE: the last line must not end with a backslash, otherwise the line that
  // follows the macro would be spliced into its body.
  #define CATCH_ALL(RETVAL)                                  \
      catch (py::error_already_set& e) {                     \
          e.restore();                                       \
          return RETVAL;                                     \
      } catch (py::builtin_exception& e) {                   \
          e.set_error();                                     \
          return RETVAL;                                     \
      } catch (const std::exception& e) {                    \
          PyErr_SetString(PyExc_RuntimeError, e.what());     \
          return RETVAL;                                     \
      }
  45. namespace {
  // Naming helpers for the Python wrapper of an op type `OpName`:
  //   PyOp(Foo)     -> PyFoo            (wrapper struct)
  //   PyOpType(Foo) -> PyFoo::py_type   (its static PyTypeObject)
  #define PyOp(OpName) Py##OpName
  #define PyOpType(OpName) PyOp(OpName)::py_type

  // Opens `struct Py<OpName> : PyOpDef`, exposing the wrapped C++ type as `Ty`
  // and `inst()`, which down-casts the held op to `Ty`. The struct body stays
  // open so members can be added before PyOpDefEnd.
  #define PyOpDefBegin(OpName)                              \
      struct PyOp(OpName) : PyOpDef {                       \
          using Ty = OpName;                                \
          Ty& inst() { return op->cast_final_safe<Ty>(); }  \
          static PyTypeObject py_type;

  // Closes the struct opened by PyOpDefBegin and defines its static py_type.
  #define PyOpDefEnd(OpName) \
      };                     \
      PyTypeObject PyOpType(OpName);
  // Returns Py_True / Py_False from the enclosing function according to the
  // result of applying the rich-comparison operator `cmp_op` (Py_EQ, Py_NE,
  // Py_LT, Py_GT, Py_LE, Py_GE) to the two operands. Any other operator value
  // aborts the interpreter.
  #define RETURN_RICHCOMPARE(lhs_expr, rhs_expr, cmp_op)                  \
      do {                                                                \
          switch (cmp_op) {                                               \
              case Py_EQ:                                                 \
                  if ((lhs_expr) == (rhs_expr)) { Py_RETURN_TRUE; }       \
                  Py_RETURN_FALSE;                                        \
              case Py_NE:                                                 \
                  if ((lhs_expr) != (rhs_expr)) { Py_RETURN_TRUE; }       \
                  Py_RETURN_FALSE;                                        \
              case Py_LT:                                                 \
                  if ((lhs_expr) < (rhs_expr)) { Py_RETURN_TRUE; }        \
                  Py_RETURN_FALSE;                                        \
              case Py_GT:                                                 \
                  if ((lhs_expr) > (rhs_expr)) { Py_RETURN_TRUE; }        \
                  Py_RETURN_FALSE;                                        \
              case Py_LE:                                                 \
                  if ((lhs_expr) <= (rhs_expr)) { Py_RETURN_TRUE; }       \
                  Py_RETURN_FALSE;                                        \
              case Py_GE:                                                 \
                  if ((lhs_expr) >= (rhs_expr)) { Py_RETURN_TRUE; }       \
                  Py_RETURN_FALSE;                                        \
              default:                                                    \
                  Py_FatalError("Unreachable C code path reached");       \
          }                                                               \
      } while (0)
  69. template <typename T>
  70. PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
  71. PyObject* obj = type->tp_alloc(type, 0);
  72. T* self = reinterpret_cast<T*>(obj);
  73. if (self != NULL) {
  74. self->op = T::Ty::make();
  75. }
  76. return obj;
  77. }
  78. template<typename T>
  79. void py_dealloc_generic(PyObject* obj) {
  80. reinterpret_cast<T*>(obj)->op.reset();
  81. Py_TYPE(obj)->tp_free(obj);
  82. }
  83. template<typename T, typename U, U T::Ty::*attr>
  84. PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
  85. auto& op = reinterpret_cast<T*>(obj)->inst();
  86. return py::cast(op.*attr).release().ptr();
  87. }
  88. #define py_get_generic(name, attr) \
  89. py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
  90. template<typename T, typename U, U T::Ty::*attr>
  91. int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
  92. if (value == NULL) {
  93. PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
  94. return -1;
  95. }
  96. auto& op = reinterpret_cast<T*>(obj)->inst();
  97. try {
  98. // TODO: remove this guard which is used for pybind11 implicit conversion
  99. py::detail::loader_life_support guard{};
  100. op.*attr = py::cast<U>(py::handle(value));
  101. } CATCH_ALL(-1)
  102. return 0;
  103. }
  104. #define py_set_generic(name, attr) \
  105. py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
  106. struct PyOpDef {
  107. PyObject_HEAD
  108. std::shared_ptr<OpDef> op;
  109. static PyTypeObject py_type;
  110. static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
  111. static PyGetSetDef py_getsetters[];
  112. static Py_hash_t tp_hash(PyObject *obj);
  113. static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op);
  114. };
  115. PyTypeObject PyOpType(OpDef);
  116. std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;
  117. PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
  118. return py::cast(
  119. reinterpret_cast<PyOp(OpDef)*>(obj)->op->scope()).release().ptr();
  120. }
  121. int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
  122. if (value == NULL) {
  123. PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
  124. return -1;
  125. }
  126. try {
  127. reinterpret_cast<PyOp(OpDef)*>(obj)->op
  128. ->set_scope(py::cast<std::string>(py::handle(value)));
  129. } CATCH_ALL(-1)
  130. return 0;
  131. }
  132. PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
  133. {const_cast<char*>("scope"), py_get_scope, py_set_scope, "scope", NULL},
  134. {NULL}
  135. };
  136. Py_hash_t PyOp(OpDef)::tp_hash(PyObject *obj) {
  137. return static_cast<Py_hash_t>(
  138. reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash());
  139. }
  140. PyObject* PyOp(OpDef)::tp_richcompare(PyObject *self, PyObject *other, int op) {
  141. bool same = reinterpret_cast<PyOp(OpDef)*>(self)->op->is_same(
  142. *reinterpret_cast<PyOp(OpDef)*>(other)->op);
  143. if (op == Py_EQ || op == Py_NE) {
  144. RETURN_RICHCOMPARE(same, true, op);
  145. }
  146. Py_RETURN_NOTIMPLEMENTED;
  147. }
  // Per-enum metadata; the code in this file reads `name` and `max` from it.
  // Only the declaration is visible here.
  template<typename T>
  struct EnumTrait;

  // Common members of the enum wrapper structs below: the CPython object
  // header, the wrapped enum value, and the static tables used for
  // name <-> value conversion and for the pre-built Python instances.
  #define PyEnumHead                                              \
      static_assert(std::is_enum_v<T>);                           \
      PyObject_HEAD                                               \
      T value;                                                    \
      static constexpr const char* name = EnumTrait<T>::name;     \
      static PyTypeObject* type;                                  \
      static const char* members[];                               \
      static std::unordered_map<std::string, T> mem2value;        \
      static PyObject* pyobj_insts[];
  159. template<typename T>
  160. struct EnumWrapper {
  161. PyEnumHead
  162. std::string to_string() const {
  163. return members[static_cast<size_t>(value)];
  164. }
  165. static PyObject* py_repr(PyObject* self) {
  166. return py::cast(
  167. std::string(name) + "." + reinterpret_cast<EnumWrapper*>(self)->to_string())
  168. .release().ptr();
  169. }
  170. static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) {
  171. if (op == Py_EQ || op == Py_NE) {
  172. T lhs, rhs;
  173. if (load(other, rhs) && load(self, lhs)) {
  174. RETURN_RICHCOMPARE(lhs, rhs, op);
  175. } else {
  176. RETURN_RICHCOMPARE(0, 1, op);
  177. }
  178. }
  179. Py_RETURN_NOTIMPLEMENTED;
  180. }
  181. static bool load(py::handle src, T& value) {
  182. PyObject* obj = src.ptr();
  183. if (PyObject_TypeCheck(obj, type)) {
  184. value = reinterpret_cast<EnumWrapper*>(obj)->value;
  185. return true;
  186. }
  187. if (py::isinstance<py::str>(src)) {
  188. auto&& iter = mem2value.find(
  189. normalize_enum(py::cast<std::string>(src)));
  190. if (iter != mem2value.end()) {
  191. value = iter->second;
  192. return true;
  193. } else {
  194. return false;
  195. }
  196. }
  197. return false;
  198. }
  199. static PyObject* cast(const T& value) {
  200. auto v = static_cast<std::underlying_type_t<T>>(value);
  201. mgb_assert(v <= EnumTrait<T>::max);
  202. PyObject* obj = pyobj_insts[v];
  203. Py_INCREF(obj);
  204. return obj;
  205. }
  206. };
  207. template<typename T>
  208. struct BitCombinedEnumWrapper {
  209. PyEnumHead
  210. std::string to_string() const {
  211. uint32_t value_int = static_cast<uint32_t>(value);
  212. if (value_int == 0) {
  213. return "None";
  214. } else {
  215. std::string ret;
  216. bool first = true;
  217. for (uint32_t i = 0; i < 32; i++) {
  218. if (value_int >> i & 1) {
  219. if (!first) {
  220. ret += " + ";
  221. } else {
  222. first = false;
  223. }
  224. ret += (std::string(name) + "." + members[i]);
  225. }
  226. }
  227. return ret;
  228. }
  229. }
  230. static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject* args, PyObject*) {
  231. if (!PyTuple_Size(args)) {
  232. PyObject* obj = type->tp_alloc(type, 0);
  233. reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
  234. return obj;
  235. }
  236. else {
  237. PyObject* input;
  238. if (!PyArg_ParseTuple(args, "|O", &input)) {
  239. return nullptr;
  240. }
  241. T value;
  242. if (load(input, value)) {
  243. return cast(value);
  244. } else {
  245. PyErr_SetString(PyExc_RuntimeError,
  246. mgb::ssprintf("Cannot convert type %s to type %s\n",
  247. input->ob_type->tp_name, name).c_str());
  248. return nullptr;
  249. }
  250. }
  251. }
  252. static PyObject* py_repr(PyObject* self) {
  253. return py::cast(
  254. reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string())
  255. .release().ptr();
  256. }
  257. static PyObject* py_or(PyObject* self, PyObject* other) {
  258. if(!(self->ob_type == other->ob_type)){
  259. return PyErr_Format(
  260. PyExc_RuntimeError,
  261. "Operand in or operator must be the same type.");
  262. }
  263. T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
  264. rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
  265. return cast(lhs | rhs);
  266. }
  267. static PyObject* py_and(PyObject* self, PyObject* other) {
  268. if (!(self->ob_type == other->ob_type)) {
  269. return PyErr_Format(
  270. PyExc_RuntimeError,
  271. "Operand in and operator must be the same type.");
  272. }
  273. T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
  274. rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
  275. return cast(lhs & rhs);
  276. }
  277. static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
  278. if (op == Py_EQ || op == Py_NE) {
  279. T lhs, rhs;
  280. if (load(other, rhs) && load(self, lhs)) {
  281. RETURN_RICHCOMPARE(lhs, rhs, op);
  282. } else {
  283. RETURN_RICHCOMPARE(0, 1, op);
  284. }
  285. }
  286. Py_RETURN_NOTIMPLEMENTED;
  287. }
  288. static bool load(py::handle src, T& value) {
  289. PyObject* obj = src.ptr();
  290. if (PyObject_TypeCheck(obj, type)) {
  291. value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value;
  292. return true;
  293. }
  294. if (py::isinstance<py::str>(src)) {
  295. auto&& iter = mem2value.find(
  296. normalize_enum(py::cast<std::string>(src)));
  297. if (iter != mem2value.end()) {
  298. value = iter->second;
  299. return true;
  300. } else {
  301. return false;
  302. }
  303. }
  304. if (py::isinstance<py::int_>(obj)) {
  305. auto v = py::cast<std::underlying_type_t<T>>(src);
  306. if(v > EnumTrait<T>::max) {
  307. return false;
  308. }
  309. value = static_cast<T>(v);
  310. return true;
  311. }
  312. return false;
  313. }
  314. static PyObject* cast(const T& value) {
  315. auto v = static_cast<std::underlying_type_t<T>>(value);
  316. mgb_assert(v <= EnumTrait<T>::max);
  317. if ((!v) || (v & (v - 1))) {
  318. PyObject* obj = type->tp_alloc(type, 0);
  319. reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
  320. return obj;
  321. } else {
  322. PyObject* obj = pyobj_insts[__builtin_ctz(v)];
  323. Py_INCREF(obj);
  324. return obj;
  325. }
  326. }
  327. };
  328. void _init_py_op_def(py::module m) {
  329. using py_op = PyOp(OpDef);
  330. auto& py_type = PyOpType(OpDef);
  331. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  332. py_type.tp_name = "megengine.core._imperative_rt.OpDef";
  333. py_type.tp_basicsize = sizeof(PyOp(OpDef));
  334. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  335. py_type.tp_doc = "OpDef";
  336. py_type.tp_base = &PyBaseObject_Type;
  337. py_type.tp_hash = PyOp(OpDef)::tp_hash;
  338. py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare;
  339. py_type.tp_getset = py_op::py_getsetters;
  340. mgb_assert(PyType_Ready(&py_type) >= 0);
  341. m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
  342. }
  343. /*********** begin of hand-write opdefs **************/
  344. PyOpDefBegin(BackwardGraph) // {{
  345. // };
  346. PyOpDefEnd(BackwardGraph)
  347. void _init_py_backward_graph(py::module m) {
  348. using py_op = PyOp(BackwardGraph);
  349. auto& py_type = PyOpType(BackwardGraph);
  350. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  351. py_type.tp_name = "megengine.core._imperative_rt.ops.BackwardGraph";
  352. py_type.tp_basicsize = sizeof(PyOp(BackwardGraph));
  353. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  354. py_type.tp_doc = "BackwardGraph";
  355. py_type.tp_base = &PyOpType(OpDef);
  356. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  357. py_type.tp_new = py_new_generic<py_op>;
  358. mgb_assert(PyType_Ready(&py_type) >= 0);
  359. // FIXME: rewrite interpret function in cpython instead wrap directly by pybind11::cppfunction
  360. auto interpret = py::cpp_function(
  361. [](OpDef& self, py::object pyf, py::object pyc,
  362. const mgb::SmallVector<py::object>& inputs) {
  363. auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) {
  364. return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs));
  365. };
  366. auto c = [pyc](const TensorPtr& tensor) {
  367. return pyc(tensor->dev_tensor());
  368. };
  369. return self.cast_final_safe<BackwardGraph>().graph().interpret<py::object>(f, c, inputs);
  370. });
  371. mgb_assert(PyDict_SetItemString(
  372. py_type.tp_dict, "interpret", interpret.release().ptr()) >= 0);
  373. PyType_Modified(&py_type);
  374. m.add_object("BackwardGraph", reinterpret_cast<PyObject*>(&py_type));
  375. mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second);
  376. }
  377. struct PyOpBase : PyOpDef {
  378. static PyTypeObject py_type;
  379. static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) {
  380. auto* obj = type->tp_alloc(type, 0);
  381. if (obj) {
  382. auto* self = reinterpret_cast<PyOpBase*>(obj);
  383. new(&self->op) decltype(self->op);
  384. }
  385. return obj;
  386. }
  387. };
  388. PyTypeObject PyOpBase::py_type;
  389. void _init_py_op_base(py::module m) {
  390. using py_op = PyOpBase;
  391. auto& py_type = PyOpBase::py_type;
  392. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  393. py_type.tp_name = "megengine.core._imperative_rt.ops.PyOpBase";
  394. py_type.tp_basicsize = sizeof(py_op);
  395. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  396. py_type.tp_doc = "PyOpBase";
  397. py_type.tp_base = &PyOpType(OpDef);
  398. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  399. py_type.tp_new = py_op::tp_new;
  400. mgb_assert(PyType_Ready(&py_type) >= 0);
  401. m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&py_type));
  402. }
  403. /*********** end of hand-write opdefs **************/
  404. // auto generated opdefs
  405. #include "opdef.cpy.inl"
  406. #undef CATCH_ALL
  407. } // anonymous namespace
  408. namespace PYBIND11_NAMESPACE {
  409. namespace detail {
  410. bool type_caster<OpDef>::load(handle src, bool convert) {
  411. PyObject* obj = src.ptr();
  412. if (!PyObject_TypeCheck(obj, &PyOpType(OpDef))) {
  413. return false;
  414. }
  415. value = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
  416. if (!value) {
  417. // opdef only defined in Python
  418. value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src));
  419. }
  420. return true;
  421. }
  422. handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
  423. if (auto* pyop = op.try_cast_final<GenericPyOp>()) {
  424. return object(pyop->obj).release();
  425. }
  426. PyTypeObject* pytype;
  427. auto& c2p = PyOp(OpDef)::ctype2pytype;
  428. auto&& iter = c2p.find(op.dyn_typeinfo());
  429. if (iter != c2p.end()) { // FIXME: should always meet this condition
  430. pytype = iter->second;
  431. } else { // which means unregistered op type, jsut make it as an opaque op type
  432. // currently, only OprAttr goes into this branch
  433. pytype = &PyOpType(OpDef);
  434. }
  435. PyObject* obj = pytype->tp_alloc(pytype, 0);
  436. mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef)));
  437. reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this();
  438. return py::handle(obj);
  439. }
  440. #define ENUM_CASTER_IMPL(T) \
  441. bool type_caster<T>::load(handle src, bool) { \
  442. return EnumWrapper<T>::load(src, value); \
  443. } \
  444. handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
  445. return EnumWrapper<T>::cast(value); \
  446. }
  447. FOR_EACH_ENUM_PARAM(ENUM_CASTER_IMPL)
  448. #define BIT_COMBINED_ENUM_CASTER_IMPL(T) \
  449. bool type_caster<T>::load(handle src, bool) { \
  450. return BitCombinedEnumWrapper<T>::load(src, value); \
  451. } \
  452. handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
  453. return BitCombinedEnumWrapper<T>::cast(value); \
  454. }
  455. FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)
  456. } // detail
  457. } // PYBIND11_NAMESPACE
  458. void init_ops(py::module m) {
  459. _init_py_op_def(m);
  460. _init_py_backward_graph(m);
  461. _init_py_op_base(m);
  462. INIT_ALL_OP(m)
  463. m.def("new_rng_handle", &rng::new_handle);
  464. m.def("delete_rng_handle", [](size_t handle){
  465. // RNG op might execute after handle released due to async dispatch, so
  466. // we need sync before delete a handle to avoid memory leak or use-after-free
  467. python::interpreter_for_py->sync();
  468. mgb::CompNode::sync_all();
  469. py_task_q.wait_all_task_finish();
  470. rng::delete_handle(handle);
  471. }, py::call_guard<py::gil_scoped_release>());
  472. m.def("set_global_rng_seed", &rng::set_global_rng_seed);
  473. m.def("get_global_rng_seed", &rng::get_global_rng_seed);
  474. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台