
ops.cpp 19 kB

/**
 * \file imperative/python/src/ops.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./ops.h"
#include "./helper.h"
#include "./tensor.h"

#include "megbrain/common.h"
#include "megbrain/imperative.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/rng.h"

#include <Python.h>
#include <unordered_map>

namespace py = pybind11;
using namespace mgb::imperative;

namespace {
auto normalize_enum(const std::string& in) {
    std::string ret;
    for (auto&& c : in) {
        ret += toupper(c);
    }
    return ret;
}
}  // anonymous namespace
#define CATCH_ALL(RETVAL) \
    catch(py::error_already_set& e) { \
        e.restore(); \
        return RETVAL; \
    } catch(py::builtin_exception& e) { \
        e.set_error(); \
        return RETVAL; \
    } catch(std::exception& e) { \
        PyErr_SetString(PyExc_RuntimeError, e.what()); \
        return RETVAL; \
    }
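// Usage sketch: CATCH_ALL is designed to follow a try-block inside a CPython slot
// function, converting pybind11 and std exceptions into a Python error plus an error
// return value. A hypothetical setter would look roughly like:
//
//     int some_setter(PyObject* obj, PyObject* value, void* /* closure */) {
//         try {
//             // conversion code that may throw
//         } CATCH_ALL(-1)  // on exception: set the Python error state, return -1
//         return 0;
//     }
//
// (py_set_generic_impl and py_set_scope below follow this pattern.)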
namespace {

#define PyOp(name)     Py##name
#define PyOpType(name) PyOp(name)::py_type

#define PyOpDefBegin(name)                                   \
    struct PyOp(name) : PyOpDef {                            \
        using Ty = name;                                     \
        Ty& inst() { return op->cast_final_safe<Ty>(); }     \
        static PyTypeObject py_type;

#define PyOpDefEnd(name) \
    };                   \
    PyTypeObject PyOpType(name);

#define RETURN_RICHCOMPARE(val1, val2, op)                                      \
    do {                                                                        \
        switch (op) {                                                           \
            case Py_EQ: if ((val1) == (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE;  \
            case Py_NE: if ((val1) != (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE;  \
            case Py_LT: if ((val1) < (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE;   \
            case Py_GT: if ((val1) > (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE;   \
            case Py_LE: if ((val1) <= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE;  \
            case Py_GE: if ((val1) >= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE;  \
            default:                                                            \
                Py_FatalError("Unreachable C code path reached");               \
        }                                                                       \
    } while (0)
template <typename T>
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
    PyObject* obj = type->tp_alloc(type, 0);
    T* self = reinterpret_cast<T*>(obj);
    if (self != NULL) {
        self->op = T::Ty::make();
    }
    return obj;
}

template <typename T>
void py_dealloc_generic(PyObject* obj) {
    reinterpret_cast<T*>(obj)->op.reset();
    Py_TYPE(obj)->tp_free(obj);
}

template <typename T, typename U, U T::Ty::*attr>
PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
    auto& op = reinterpret_cast<T*>(obj)->inst();
    return py::cast(op.*attr).release().ptr();
}

#define py_get_generic(name, attr) \
    py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

template <typename T, typename U, U T::Ty::*attr>
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    auto& op = reinterpret_cast<T*>(obj)->inst();
    try {
        // TODO: remove this guard which is used for pybind11 implicit conversion
        py::detail::loader_life_support guard{};
        op.*attr = py::cast<U>(py::handle(value));
    } CATCH_ALL(-1)
    return 0;
}

#define py_set_generic(name, attr) \
    py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
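// Usage sketch: the auto-generated bindings in opdef.cpy.inl are expected to combine the
// macros above roughly as follows for a hypothetical op `Foo` with an attribute `axis`
// (the names here are illustrative; the real generated code may differ):
//
//     PyOpDefBegin(Foo)
//         static PyGetSetDef py_getsetters[];
//     PyOpDefEnd(Foo)
//
//     PyGetSetDef PyOp(Foo)::py_getsetters[] = {
//         {const_cast<char*>("axis"),
//          py_get_generic(Foo, axis),   // getter reads Foo::axis via py_get_generic_impl
//          py_set_generic(Foo, axis),   // setter writes Foo::axis; errors go through CATCH_ALL
//          const_cast<char*>("axis"), NULL},
//         {NULL},
//     };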
struct PyOpDef {
    PyObject_HEAD
    std::shared_ptr<OpDef> op;
    static PyTypeObject py_type;
    static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
    static PyGetSetDef py_getsetters[];
    static Py_hash_t tp_hash(PyObject* obj);
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op);
};
PyTypeObject PyOpType(OpDef);
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;

PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
    return py::cast(
            reinterpret_cast<PyOp(OpDef)*>(obj)->op->scope()).release().ptr();
}

int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    try {
        reinterpret_cast<PyOp(OpDef)*>(obj)->op
                ->set_scope(py::cast<std::string>(py::handle(value)));
    } CATCH_ALL(-1)
    return 0;
}

PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
    {const_cast<char*>("scope"), py_get_scope, py_set_scope, "scope", NULL},
    {NULL}
};

Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) {
    return static_cast<Py_hash_t>(
            reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash());
}

PyObject* PyOp(OpDef)::tp_richcompare(PyObject* self, PyObject* other, int op) {
    bool same = reinterpret_cast<PyOp(OpDef)*>(self)->op->is_same(
            *reinterpret_cast<PyOp(OpDef)*>(other)->op);
    if (op == Py_EQ || op == Py_NE) {
        RETURN_RICHCOMPARE(same, true, op);
    }
    Py_RETURN_NOTIMPLEMENTED;
}
template <typename T>
struct EnumTrait;

#define PyEnumHead                                             \
    static_assert(std::is_enum_v<T>);                          \
    PyObject_HEAD                                              \
    T value;                                                   \
    constexpr static const char* name = EnumTrait<T>::name;    \
    static PyTypeObject* type;                                 \
    static const char* members[];                              \
    static std::unordered_map<std::string, T> mem2value;       \
    static PyObject* pyobj_insts[];

template <typename T>
struct EnumWrapper {
    PyEnumHead
    std::string to_string() const {
        return members[static_cast<size_t>(value)];
    }
    static PyObject* py_repr(PyObject* self) {
        return py::cast(
                std::string(name) + "." + reinterpret_cast<EnumWrapper*>(self)->to_string())
                .release().ptr();
    }
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
        if (op == Py_EQ || op == Py_NE) {
            T lhs, rhs;
            if (load(other, rhs) && load(self, lhs)) {
                RETURN_RICHCOMPARE(lhs, rhs, op);
            } else {
                RETURN_RICHCOMPARE(0, 1, op);
            }
        }
        Py_RETURN_NOTIMPLEMENTED;
    }
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
        if (PyObject_TypeCheck(obj, type)) {
            value = reinterpret_cast<EnumWrapper*>(obj)->value;
            return true;
        }
        if (py::isinstance<py::str>(src)) {
            auto&& iter = mem2value.find(
                    normalize_enum(py::cast<std::string>(src)));
            if (iter != mem2value.end()) {
                value = iter->second;
                return true;
            } else {
                return false;
            }
        }
        return false;
    }
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        PyObject* obj = pyobj_insts[v];
        Py_INCREF(obj);
        return obj;
    }
};
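// Illustrative sketch (hypothetical names): the generated code is expected to specialize
// EnumTrait and fill in EnumWrapper's statics for each concrete parameter enum, e.g. for
// an assumed `enum class Mode { ADD, SUB }`:
//
//     template <> struct EnumTrait<Mode> {
//         static constexpr const char* name = "Mode";
//         static constexpr std::underlying_type_t<Mode> max = 1;
//     };
//     template <> PyTypeObject* EnumWrapper<Mode>::type = nullptr;
//     template <> const char* EnumWrapper<Mode>::members[] = {"ADD", "SUB"};
//     template <> std::unordered_map<std::string, Mode> EnumWrapper<Mode>::mem2value = {
//             {normalize_enum("ADD"), Mode::ADD}, {normalize_enum("SUB"), Mode::SUB}};
//     template <> PyObject* EnumWrapper<Mode>::pyobj_insts[2] = {nullptr, nullptr};
//
// With these in place, load() accepts either an existing wrapper instance or a
// case-insensitive member name such as "add", and cast() hands out the pre-built
// singleton stored in pyobj_insts.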
template <typename T>
struct BitCombinedEnumWrapper {
    PyEnumHead
    std::string to_string() const {
        uint32_t value_int = static_cast<uint32_t>(value);
        if (value_int == 0) {
            return "None";
        } else {
            std::string ret;
            bool first = true;
            for (uint32_t i = 0; i < 32; i++) {
                if (value_int >> i & 1) {
                    if (!first) {
                        ret += " + ";
                    } else {
                        first = false;
                    }
                    ret += (std::string(name) + "." + members[i]);
                }
            }
            return ret;
        }
    }
    static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject* args, PyObject*) {
        if (!PyTuple_Size(args)) {
            PyObject* obj = type->tp_alloc(type, 0);
            reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
            return obj;
        } else {
            PyObject* input;
            if (!PyArg_ParseTuple(args, "|O", &input)) {
                return nullptr;
            }
            T value;
            if (load(input, value)) {
                return cast(value);
            } else {
                PyErr_SetString(PyExc_RuntimeError,
                        mgb::ssprintf("Cannot convert type %s to type %s\n",
                                input->ob_type->tp_name, name).c_str());
                return nullptr;
            }
        }
    }
    static PyObject* py_repr(PyObject* self) {
        return py::cast(
                reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string())
                .release().ptr();
    }
    static PyObject* py_or(PyObject* self, PyObject* other) {
        if (!(self->ob_type == other->ob_type)) {
            return PyErr_Format(
                    PyExc_RuntimeError,
                    "Operand in or operator must be the same type.");
        }
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
        return cast(lhs | rhs);
    }
    static PyObject* py_and(PyObject* self, PyObject* other) {
        if (!(self->ob_type == other->ob_type)) {
            return PyErr_Format(
                    PyExc_RuntimeError,
                    "Operand in and operator must be the same type.");
        }
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
        return cast(lhs & rhs);
    }
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
        if (op == Py_EQ || op == Py_NE) {
            T lhs, rhs;
            if (load(other, rhs) && load(self, lhs)) {
                RETURN_RICHCOMPARE(lhs, rhs, op);
            } else {
                RETURN_RICHCOMPARE(0, 1, op);
            }
        }
        Py_RETURN_NOTIMPLEMENTED;
    }
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
        if (PyObject_TypeCheck(obj, type)) {
            value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value;
            return true;
        }
        if (py::isinstance<py::str>(src)) {
            auto&& iter = mem2value.find(
                    normalize_enum(py::cast<std::string>(src)));
            if (iter != mem2value.end()) {
                value = iter->second;
                return true;
            } else {
                return false;
            }
        }
        if (py::isinstance<py::int_>(obj)) {
            auto v = py::cast<std::underlying_type_t<T>>(src);
            if (v > EnumTrait<T>::max) {
                return false;
            }
            value = static_cast<T>(v);
            return true;
        }
        return false;
    }
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        if ((!v) || (v & (v - 1))) {
            PyObject* obj = type->tp_alloc(type, 0);
            reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
            return obj;
        } else {
            PyObject* obj = pyobj_insts[__builtin_ctz(v)];
            Py_INCREF(obj);
            return obj;
        }
    }
};
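// Behavior note for the wrapper above (illustrative, assuming members {A = 1 << 0,
// B = 1 << 1}): cast() only reuses the pre-built singletons in pyobj_insts for exact
// single-bit values; zero and multi-bit combinations get a freshly allocated object.
//
//     cast(T(1));  // returns pyobj_insts[0] ("Name.A"), with its refcount bumped
//     cast(T(3));  // new object; to_string() yields "Name.A + Name.B"
//     cast(T(0));  // new object; to_string() yields "None"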
void _init_py_op_def(py::module m) {
    using py_op = PyOp(OpDef);
    auto& py_type = PyOpType(OpDef);
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    py_type.tp_name = "megengine.core._imperative_rt.OpDef";
    py_type.tp_basicsize = sizeof(PyOp(OpDef));
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_doc = "OpDef";
    py_type.tp_base = &PyBaseObject_Type;
    py_type.tp_hash = PyOp(OpDef)::tp_hash;
    py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare;
    py_type.tp_getset = py_op::py_getsetters;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
}
/*********** begin of hand-written opdefs **************/
struct PyOpBase : PyOpDef {
    static PyTypeObject py_type;

    static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) {
        auto* obj = type->tp_alloc(type, 0);
        if (obj) {
            auto* self = reinterpret_cast<PyOpBase*>(obj);
            new(&self->op) decltype(self->op);
        }
        return obj;
    }
};
PyTypeObject PyOpBase::py_type;

void _init_py_op_base(py::module m) {
    using py_op = PyOpBase;
    auto& py_type = PyOpBase::py_type;
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    py_type.tp_name = "megengine.core._imperative_rt.ops.PyOpBase";
    py_type.tp_basicsize = sizeof(py_op);
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_doc = "PyOpBase";
    py_type.tp_base = &PyOpType(OpDef);
    py_type.tp_dealloc = py_dealloc_generic<py_op>;
    py_type.tp_new = py_op::tp_new;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&py_type));
}
/*********** end of hand-written opdefs **************/

// auto generated opdefs
#include "opdef.cpy.inl"

#undef CATCH_ALL
}  // anonymous namespace
namespace PYBIND11_NAMESPACE {
namespace detail {

bool type_caster<OpDef>::load(handle src, bool convert) {
    PyObject* obj = src.ptr();
    if (!PyObject_TypeCheck(obj, &PyOpType(OpDef))) {
        return false;
    }
    value = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
    if (!value) {
        // opdef only defined in Python
        value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src));
    }
    return true;
}

handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
    if (auto* pyop = op.try_cast_final<GenericPyOp>()) {
        return object(pyop->obj).release();
    }
    PyTypeObject* pytype;
    auto& c2p = PyOp(OpDef)::ctype2pytype;
    auto&& iter = c2p.find(op.dyn_typeinfo());
    if (iter != c2p.end()) {  // FIXME: should always meet this condition
        pytype = iter->second;
    } else {  // unregistered op type: just expose it as an opaque OpDef
        // currently, only OprAttr goes into this branch
        pytype = &PyOpType(OpDef);
    }
    PyObject* obj = pytype->tp_alloc(pytype, 0);
    mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef)));
    reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this();
    return py::handle(obj);
}

#define ENUM_CASTER_IMPL(T)                                                     \
    bool type_caster<T>::load(handle src, bool) {                               \
        return EnumWrapper<T>::load(src, value);                                \
    }                                                                           \
    handle type_caster<T>::cast(const T& value, return_value_policy, handle) {  \
        return EnumWrapper<T>::cast(value);                                     \
    }
FOR_EACH_ENUM_PARAM(ENUM_CASTER_IMPL)

#define BIT_COMBINED_ENUM_CASTER_IMPL(T)                                        \
    bool type_caster<T>::load(handle src, bool) {                               \
        return BitCombinedEnumWrapper<T>::load(src, value);                     \
    }                                                                           \
    handle type_caster<T>::cast(const T& value, return_value_policy, handle) {  \
        return BitCombinedEnumWrapper<T>::cast(value);                          \
    }
FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)

}  // namespace detail
}  // namespace PYBIND11_NAMESPACE
void init_ops(py::module m) {
    _init_py_op_def(m);
    _init_py_op_base(m);
    INIT_ALL_OP(m)

    m.def("new_rng_handle", &rng::new_handle);
    m.def("delete_rng_handle", [](size_t handle) {
        // RNG ops might still be executing after the handle is released due to
        // async dispatch, so sync before deleting a handle to avoid memory
        // leaks or use-after-free
        if (python::interpreter_for_py->check_available()) {
            python::interpreter_for_py->sync();
        }
        mgb::CompNode::sync_all();
        py_task_q.wait_all_task_finish();
        rng::delete_handle(handle);
    }, py::call_guard<py::gil_scoped_release>());
    m.def("set_global_rng_seed", &rng::set_global_rng_seed);
    m.def("get_global_rng_seed", &rng::get_global_rng_seed);
    m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode);

    struct PySubgraphBuilder {
        explicit PySubgraphBuilder(std::string name) : name{name} {}
        std::string name;
        std::shared_ptr<Subgraph> graph_storage = std::make_shared<Subgraph>();
        std::shared_ptr<UniqueKey> graph_key = std::make_shared<UniqueKey>();
        Subgraph& graph = *graph_storage;
        mgb::SmallVector<bool> output_grad_mask;
        Subgraph::var_t next_var = 1;

        std::shared_ptr<OpDef> build() const {
            return SubgraphOp::make(name, graph_storage, output_grad_mask, graph_key);
        }
    };

    py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
            .def(py::init<std::string>())
            .def("input", [](PySubgraphBuilder& self) {
                auto var = self.next_var++;
                self.graph.inputs.push_back(var);
                return var;
            })
            .def("apply", [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op,
                             Subgraph::vars_t inputs, size_t nr_outputs) {
                Subgraph::vars_t outputs;
                for (size_t i = 0; i < nr_outputs; ++i) {
                    outputs.push_back(self.next_var++);
                }
                self.graph.exprs.push_back({op, inputs, outputs});
                return outputs;
            })
            .def("apply_const", [](PySubgraphBuilder& self, py::object value,
                                   mgb::DType dtype, mgb::CompNode cn) {
                auto var = self.next_var++;
                mgb::HostTensorND hvalue(cn);
                npy::np2tensor(value.cast<py::array>().ptr(), npy::Meth::copy_into(&hvalue), dtype);
                self.graph.constants.push_back({var, Tensor::make(hvalue)});
                return var;
            })
            .def("outputs", [](PySubgraphBuilder& self, Subgraph::vars_t outputs) {
                self.graph.outputs = outputs;
                self.output_grad_mask.resize(outputs.size(), true);
            })
            .def("outputs_has_grad", [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad) {
                mgb_assert(self.graph.outputs.size() == self.output_grad_mask.size());
                self.output_grad_mask = outputs_has_grad;
            })
            .def("get", [](PySubgraphBuilder& self) {
                return (std::shared_ptr<OpDef>)self.build();
            })
            .def("compile", [](PySubgraphBuilder& self, int gopt_level) {
                return (std::shared_ptr<OpDef>)CompiledOp::make(self.build(), gopt_level);
            });
}

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.