You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops.cpp 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. /**
  2. * \file imperative/python/src/ops.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./ops.h"
  12. #include "megbrain/imperative.h"
  13. #include "megbrain/imperative/ops/backward_graph.h"
  14. #include "megbrain/imperative/ops/opr_attr.h"
  15. #include "megbrain/imperative/ops/utility.h"
  16. #include "megbrain/imperative/ops/autogen.h"
  17. #include <Python.h>
  18. #include <unordered_map>
  19. namespace py = pybind11;
  20. using namespace mgb::imperative;
  namespace {
  // Upper-cases `in` so enum member names can be matched case-insensitively
  // against the str2type lookup tables.
  // Each char goes through unsigned char first: passing a negative char
  // (any non-ASCII byte on signed-char platforms) to toupper is undefined.
  std::string normalize_enum(const std::string& in) {
      std::string ret;
      ret.reserve(in.size());
      for (char c : in) {
          ret += static_cast<char>(toupper(static_cast<unsigned char>(c)));
      }
      return ret;
  }
  } // anonymous namespace
  // Catch-clause tail for try-blocks in functions invoked by CPython.
  // It turns any escaping C++ exception into a pending Python exception and
  // returns RETVAL, the caller's error sentinel (e.g. -1 or nullptr).
  // A final catch (...) ensures no C++ exception unwinds through CPython's C frames.
  // The last line deliberately has no trailing backslash, so the macro
  // cannot swallow the following source line.
  #define CATCH_ALL(RETVAL)                                  \
      catch (py::error_already_set & e) {                    \
          e.restore();                                       \
          return RETVAL;                                     \
      }                                                      \
      catch (py::builtin_exception & e) {                    \
          e.set_error();                                     \
          return RETVAL;                                     \
      }                                                      \
      catch (std::exception & e) {                           \
          PyErr_SetString(PyExc_RuntimeError, e.what());     \
          return RETVAL;                                     \
      }                                                      \
      catch (...) {                                          \
          PyErr_SetString(PyExc_RuntimeError,                \
                          "unknown C++ exception");          \
          return RETVAL;                                     \
      }
  41. namespace {
  // Name of the Python wrapper struct for the C++ op type `name`:
  // PyOp(Foo) -> PyFoo.
  #define PyOp(name) Py##name
  // Static PyTypeObject member of that wrapper struct.
  #define PyOpType(name) PyOp(name)::py_type
  // Opens the wrapper struct for op type `name`. The struct derives from
  // PyOpDef, exposes the C++ type as `Ty`, and has an inst() accessor that
  // downcasts `op` via cast_final_safe. Extra members may follow; the struct
  // must be closed with PyOpDefEnd(name).
  #define PyOpDefBegin(name)                                \
      struct PyOp(name) : PyOpDef {                         \
          static PyTypeObject py_type;                      \
          using Ty = name;                                  \
          Ty& inst() { return op->cast_final_safe<Ty>(); }
  // Closes the wrapper struct and defines its static py_type.
  #define PyOpDefEnd(name) \
      };                   \
      PyTypeObject PyOpType(name);
  // Returns Py_True or Py_False from the enclosing function. The result is
  // the rich-comparison opcode `op` (Py_EQ, Py_NE, Py_LT, Py_GT, Py_LE, Py_GE)
  // applied to val1 and val2. Any other opcode aborts via Py_FatalError.
  #define RETURN_RICHCOMPARE(val1, val2, op)                                  \
      do {                                                                    \
          bool richcmp_result_ = false;                                       \
          switch (op) {                                                       \
              case Py_EQ: richcmp_result_ = static_cast<bool>((val1) == (val2)); break; \
              case Py_NE: richcmp_result_ = static_cast<bool>((val1) != (val2)); break; \
              case Py_LT: richcmp_result_ = static_cast<bool>((val1) < (val2)); break;  \
              case Py_GT: richcmp_result_ = static_cast<bool>((val1) > (val2)); break;  \
              case Py_LE: richcmp_result_ = static_cast<bool>((val1) <= (val2)); break; \
              case Py_GE: richcmp_result_ = static_cast<bool>((val1) >= (val2)); break; \
              default:                                                        \
                  Py_FatalError("Unreachable C code path reached");           \
          }                                                                   \
          if (richcmp_result_) Py_RETURN_TRUE;                                \
          Py_RETURN_FALSE;                                                    \
      } while (0)
  65. template <typename T, typename SFINAE = void>
  66. struct pyobj_convert_generic {
  67. static T from(PyObject* obj) {
  68. // TODO: remove this guard which is used for pybind11 implicit conversion
  69. py::detail::loader_life_support guard{};
  70. return py::cast<T>(py::handle(obj));
  71. }
  72. template<typename U,
  73. typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>>
  74. static PyObject* to(U&& t) {
  75. return py::cast(std::forward<U>(t)).release().ptr();
  76. }
  77. };
  // Per-enum traits. Only declared in general; the partial specialisation
  // below supplies defaults for every enum type. Bit-flag enums presumably
  // get explicit specialisations elsewhere (e.g. generated code) — not
  // visible in this file.
  template <typename T, typename SFINAE = void>
  struct EnumTrait;

  // Defaults for any enum: not a bit-combined enum, and max value 0.
  template <typename T>
  struct EnumTrait<T, std::enable_if_t<std::is_enum_v<T>>> {
      static constexpr std::underlying_type_t<T> max = 0;
      static constexpr bool is_bit_combined = false;
  };
  85. template <typename T>
  86. PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
  87. PyObject* obj = type->tp_alloc(type, 0);
  88. T* self = reinterpret_cast<T*>(obj);
  89. if (self != NULL) {
  90. self->op = T::Ty::make();
  91. }
  92. return obj;
  93. }
  94. template<typename T>
  95. void py_dealloc_generic(PyObject* obj) {
  96. reinterpret_cast<T*>(obj)->op.reset();
  97. Py_TYPE(obj)->tp_free(obj);
  98. }
  99. template<typename T, typename U, U T::Ty::*attr>
  100. PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
  101. auto& op = reinterpret_cast<T*>(obj)->inst();
  102. return pyobj_convert_generic<U>::to(op.*attr);
  103. }
  104. #define py_get_generic(name, attr) \
  105. py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
  106. template<typename T, typename U, U T::Ty::*attr>
  107. int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
  108. if (value == NULL) {
  109. PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
  110. return -1;
  111. }
  112. auto& op = reinterpret_cast<T*>(obj)->inst();
  113. try {
  114. op.*attr = pyobj_convert_generic<U>::from(value);
  115. } CATCH_ALL(-1)
  116. return 0;
  117. }
  118. #define py_set_generic(name, attr) \
  119. py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
  120. struct PyOpDef {
  121. PyObject_HEAD
  122. std::shared_ptr<OpDef> op;
  123. static PyTypeObject py_type;
  124. static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
  125. static PyGetSetDef py_getsetters[];
  126. static Py_hash_t tp_hash(PyObject *obj);
  127. static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op);
  128. };
  129. PyTypeObject PyOpType(OpDef);
  130. std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;
  131. PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
  132. return pyobj_convert_generic<std::string>::to(
  133. reinterpret_cast<PyOp(OpDef)*>(obj)->op->scope());
  134. }
  135. int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
  136. if (value == NULL) {
  137. PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
  138. return -1;
  139. }
  140. try {
  141. reinterpret_cast<PyOp(OpDef)*>(obj)->op
  142. ->set_scope(pyobj_convert_generic<std::string>::from(value));
  143. } CATCH_ALL(-1)
  144. return 0;
  145. }
  146. PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
  147. {const_cast<char*>("scope"), py_get_scope, py_set_scope, "scope", NULL},
  148. {NULL}
  149. };
  150. Py_hash_t PyOp(OpDef)::tp_hash(PyObject *obj) {
  151. return static_cast<Py_hash_t>(
  152. reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash());
  153. }
  154. PyObject* PyOp(OpDef)::tp_richcompare(PyObject *self, PyObject *other, int op) {
  155. bool same = reinterpret_cast<PyOp(OpDef)*>(self)->op->is_same(
  156. *reinterpret_cast<PyOp(OpDef)*>(other)->op);
  157. if (op == Py_EQ || op == Py_NE) {
  158. RETURN_RICHCOMPARE(same, true, op);
  159. }
  160. Py_RETURN_NOTIMPLEMENTED;
  161. }
  162. template<typename T>
  163. struct EnumWrapper {
  164. static_assert(std::is_enum_v<T>);
  165. PyObject_HEAD
  166. T value;
  167. static const char* name;
  168. static PyTypeObject type;
  169. static std::unordered_map<T, std::string> type2str;
  170. static std::unordered_map<std::string, T> str2type;
  171. EnumWrapper() = default;
  172. EnumWrapper(T v): value(v) {}
  173. EnumWrapper(std::string&& str): EnumWrapper(str2type.at(normalize_enum(str))) {}
  174. std::string to_string() const {
  175. return type2str.at(value);
  176. }
  177. static PyObject* py_repr(PyObject* self) {
  178. return pyobj_convert_generic<std::string>::to(
  179. std::string(name) + "." + reinterpret_cast<EnumWrapper*>(self)->to_string());
  180. }
  181. static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) {
  182. T lhs = reinterpret_cast<EnumWrapper*>(self)->value,
  183. rhs = reinterpret_cast<EnumWrapper*>(other)->value;
  184. if (op == Py_EQ || op == Py_NE) {
  185. RETURN_RICHCOMPARE(lhs, rhs, op);
  186. }
  187. Py_RETURN_NOTIMPLEMENTED;
  188. }
  189. };
  190. template <typename T>
  191. struct pyobj_convert_generic<T,
  192. std::enable_if_t<std::is_enum_v<std::decay_t<T>> &&
  193. !EnumTrait<T>::is_bit_combined>> {
  194. using Wrapper = EnumWrapper<T>;
  195. static T from(PyObject* obj) {
  196. if (PyObject_TypeCheck(obj, &Wrapper::type)) {
  197. return reinterpret_cast<Wrapper*>(obj)->value;
  198. }
  199. // try as string
  200. // TODO: type checkcd
  201. return Wrapper(pyobj_convert_generic<std::string>::from(obj)).value;
  202. }
  203. static PyObject* to(T t) {
  204. PyTypeObject* pytype = &Wrapper::type;
  205. PyObject* obj = pytype->tp_alloc(pytype, 0);
  206. reinterpret_cast<Wrapper*>(obj)->value = t;
  207. return obj;
  208. }
  209. };
  210. template<typename T>
  211. struct BitCombinedEnumWrapper {
  212. static_assert(std::is_enum_v<T>);
  213. PyObject_HEAD
  214. T value;
  215. static const char* name;
  216. static PyTypeObject type;
  217. static std::unordered_map<T, std::string> type2str;
  218. static std::unordered_map<std::string, T> str2type;
  219. static PyNumberMethods number_methods;
  220. BitCombinedEnumWrapper() = default;
  221. BitCombinedEnumWrapper(T v): value(v) {}
  222. BitCombinedEnumWrapper(std::string&& str)
  223. : BitCombinedEnumWrapper(str2type.at(normalize_enum(str))) {}
  224. std::string to_string() const {
  225. if (static_cast<uint32_t>(value) == 0) {
  226. return "None";
  227. } else {
  228. auto ret = std::string();
  229. bool first = true;
  230. for (uint32_t i = 0; i < 32; i++) {
  231. uint32_t value_int = static_cast<uint32_t>(value);
  232. auto it = type2str.find(static_cast<T>((1 << i) & value_int));
  233. if (it != type2str.end()) {
  234. if (!first) {
  235. ret += " + ";
  236. } else {
  237. first = false;
  238. }
  239. ret += (std::string(name) + "." + it->second);
  240. }
  241. }
  242. return ret;
  243. }
  244. }
  245. static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject* args, PyObject*) {
  246. if (!PyTuple_Size(args)) {
  247. PyObject* obj = type->tp_alloc(type, 0);
  248. reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
  249. return obj;
  250. }
  251. else {
  252. PyObject* input;
  253. if (!PyArg_ParseTuple(args, "|O", &input)) {
  254. return nullptr;
  255. }
  256. T value;
  257. try {
  258. value = pyobj_convert_generic<T>::from(input);
  259. } CATCH_ALL(nullptr);
  260. PyObject* obj = type->tp_alloc(type, 0);
  261. reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
  262. return obj;
  263. }
  264. }
  265. static PyObject* py_repr(PyObject* self) {
  266. return pyobj_convert_generic<std::string>::to(
  267. reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string());
  268. }
  269. static PyObject* py_or(PyObject* self, PyObject* other) {
  270. if(!(self->ob_type == other->ob_type)){
  271. return PyErr_Format(
  272. PyExc_RuntimeError,
  273. "Operand in or operator must be the same type.");
  274. }
  275. PyObject* obj = type.tp_alloc(&type, 0);
  276. T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
  277. rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
  278. reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(
  279. static_cast<uint32_t>(lhs) | static_cast<uint32_t>(rhs));
  280. return obj;
  281. }
  282. static PyObject* py_and(PyObject* self, PyObject* other) {
  283. if (!(self->ob_type == other->ob_type)) {
  284. return PyErr_Format(
  285. PyExc_RuntimeError,
  286. "Operand in and operator must be the same type.");
  287. }
  288. PyObject* obj = type.tp_alloc(&type, 0);
  289. T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
  290. rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
  291. reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(
  292. static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs));
  293. return obj;
  294. }
  295. static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
  296. T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
  297. rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
  298. if (op == Py_EQ || op == Py_NE) {
  299. RETURN_RICHCOMPARE(lhs, rhs, op);
  300. }
  301. Py_RETURN_NOTIMPLEMENTED;
  302. }
  303. };
  304. template <typename T>
  305. struct pyobj_convert_generic<T,
  306. std::enable_if_t<std::is_enum_v<std::decay_t<T>> &&
  307. EnumTrait<T>::is_bit_combined>> {
  308. using Wrapper = BitCombinedEnumWrapper<T>;
  309. static T from(PyObject* obj) {
  310. if (PyObject_TypeCheck(obj, &Wrapper::type)) {
  311. return reinterpret_cast<Wrapper*>(obj)->value;
  312. } else if(PyLong_Check(obj)) {
  313. auto value = pyobj_convert_generic<std::underlying_type_t<T>>::from(obj);
  314. mgb_throw_if(value > EnumTrait<T>::max, mgb::MegBrainError,
  315. "out of range, cannot convert %zu to %s",
  316. static_cast<uint32_t>(value), Wrapper::name);
  317. return static_cast<T>(value);
  318. }
  319. // try as string
  320. // TODO: type checkcd
  321. return Wrapper(pyobj_convert_generic<std::string>::from(obj)).value;
  322. }
  323. static PyObject* to(T t) {
  324. PyTypeObject* pytype = &Wrapper::type;
  325. PyObject* obj = pytype->tp_alloc(pytype, 0);
  326. reinterpret_cast<Wrapper*>(obj)->value = t;
  327. return obj;
  328. }
  329. };
  330. void _init_py_op_def(py::module m) {
  331. using py_op = PyOp(OpDef);
  332. auto& py_type = PyOpType(OpDef);
  333. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  334. py_type.tp_name = "megengine.core._imperative_rt.OpDef";
  335. py_type.tp_basicsize = sizeof(PyOp(OpDef));
  336. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  337. py_type.tp_doc = "OpDef";
  338. py_type.tp_base = &PyBaseObject_Type;
  339. py_type.tp_hash = PyOp(OpDef)::tp_hash;
  340. py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare;
  341. py_type.tp_getset = py_op::py_getsetters;
  342. mgb_assert(PyType_Ready(&py_type) >= 0);
  343. m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
  344. }
  345. /*********** begin of hand-write opdefs **************/
  346. PyOpDefBegin(BackwardGraph) // {{
  347. // };
  348. PyOpDefEnd(BackwardGraph)
  349. void _init_py_backward_graph(py::module m) {
  350. using py_op = PyOp(BackwardGraph);
  351. auto& py_type = PyOpType(BackwardGraph);
  352. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  353. py_type.tp_name = "megengine.core._imperative_rt.ops.BackwardGraph";
  354. py_type.tp_basicsize = sizeof(PyOp(BackwardGraph));
  355. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  356. py_type.tp_doc = "BackwardGraph";
  357. py_type.tp_base = &PyOpType(OpDef);
  358. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  359. py_type.tp_new = py_new_generic<py_op>;
  360. mgb_assert(PyType_Ready(&py_type) >= 0);
  361. // FIXME: rewrite interpret function in cpython instead wrap directly by pybind11::cppfunction
  362. auto interpret = py::cpp_function(
  363. [](OpDef& self, py::object pyf, py::object pyc,
  364. const mgb::SmallVector<py::object>& inputs) {
  365. auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) {
  366. return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs));
  367. };
  368. auto c = [pyc](const TensorPtr& tensor) {
  369. return pyc(tensor->dev_tensor());
  370. };
  371. return self.cast_final_safe<BackwardGraph>().graph().interpret<py::object>(f, c, inputs);
  372. });
  373. mgb_assert(PyDict_SetItemString(
  374. py_type.tp_dict, "interpret", interpret.release().ptr()) >= 0);
  375. PyType_Modified(&py_type);
  376. m.add_object("BackwardGraph", reinterpret_cast<PyObject*>(&py_type));
  377. mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second);
  378. }
  379. struct PyOpBase : PyOpDef {
  380. static PyTypeObject py_type;
  381. static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) {
  382. auto* obj = type->tp_alloc(type, 0);
  383. if (obj) {
  384. auto* self = reinterpret_cast<PyOpBase*>(obj);
  385. new(&self->op) decltype(self->op);
  386. }
  387. return obj;
  388. }
  389. };
  390. PyTypeObject PyOpBase::py_type;
  391. void _init_py_op_base(py::module m) {
  392. using py_op = PyOpBase;
  393. auto& py_type = PyOpBase::py_type;
  394. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  395. py_type.tp_name = "megengine.core._imperative_rt.ops.PyOpBase";
  396. py_type.tp_basicsize = sizeof(py_op);
  397. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  398. py_type.tp_doc = "PyOpBase";
  399. py_type.tp_base = &PyOpType(OpDef);
  400. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  401. py_type.tp_new = py_op::tp_new;
  402. mgb_assert(PyType_Ready(&py_type) >= 0);
  403. m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&py_type));
  404. }
  405. /*********** end of hand-write opdefs **************/
  406. // auto generated opdefs
  407. #include "opdef.cpy.inl"
  408. #undef CATCH_ALL
  409. } // anonymous namespace
  410. namespace PYBIND11_NAMESPACE {
  411. namespace detail {
  412. bool type_caster<OpDef>::load(handle src, bool convert) {
  413. PyObject* obj = src.ptr();
  414. if (!PyObject_TypeCheck(obj, &PyOpType(OpDef))) {
  415. return false;
  416. }
  417. value = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
  418. if (!value) {
  419. // opdef only defined in Python
  420. value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src));
  421. }
  422. return true;
  423. }
  424. handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
  425. if (auto* pyop = op.try_cast_final<GenericPyOp>()) {
  426. return object(pyop->obj).release();
  427. }
  428. PyTypeObject* pytype;
  429. auto& c2p = PyOp(OpDef)::ctype2pytype;
  430. auto&& iter = c2p.find(op.dyn_typeinfo());
  431. if (iter != c2p.end()) { // FIXME: should always meet this condition
  432. pytype = iter->second;
  433. } else { // which means unregistered op type, jsut make it as an opaque op type
  434. // currently, only OprAttr goes into this branch
  435. pytype = &PyOpType(OpDef);
  436. }
  437. PyObject* obj = pytype->tp_alloc(pytype, 0);
  438. mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef)));
  439. reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this();
  440. return py::handle(obj);
  441. }
  442. } // detail
  443. } // PYBIND11_NAMESPACE
  // Registers every op type on module `m`. The OpDef base type is readied
  // first because the hand-written op types below use it as their tp_base.
  void init_ops(py::module m) {
      _init_py_op_def(m);
      // hand-written op types
      _init_py_backward_graph(m);
      _init_py_op_base(m);
      // auto-generated op types; INIT_ALL_OP is presumably defined in
      // opdef.cpy.inl (TODO confirm)
      INIT_ALL_OP(m)
  }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。