You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

python_c_extension.cpp 14 kB


  1. /**
  2. * \file imperative/tablegen/targets/python_c_extension.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "python_c_extension.h"
  12. #include "../emitter.h"
  13. namespace mlir::tblgen {
  14. namespace {
  15. struct Initproc {
  16. std::string func;
  17. Initproc(std::string&& s): func(std::move(s)) {}
  18. std::string operator()(std::string argument) {
  19. return formatv("{0}({1})", func, argument);
  20. }
  21. };
  22. class OpDefEmitter: public EmitterBase {
  23. public:
  24. OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_):
  25. EmitterBase(os_, env_), op(op_) {
  26. ctx.withSelf(op.getCppClassName());
  27. }
  28. Initproc emit();
  29. private:
  30. void emit_class();
  31. void emit_py_init();
  32. void emit_py_getsetters();
  33. void emit_py_methods();
  34. Initproc emit_initproc();
  35. MgbOp& op;
  36. std::vector<Initproc> subclasses;
  37. mlir::tblgen::FmtContext ctx;
  38. };
  39. class EnumAttrEmitter: public EmitterBase {
  40. public:
  41. EnumAttrEmitter(llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_, Environment& env_):
  42. EmitterBase(os_, env_), attr(attr_) {
  43. unsigned int enumID;
  44. if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
  45. auto&& aliasBase = alias->getAliasBase();
  46. enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
  47. } else {
  48. enumID = attr->getBaseRecord()->getID();
  49. }
  50. ctx.addSubst("enumTpl", attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper");
  51. ctx.addSubst("opClass", parent);
  52. ctx.addSubst("enumClass", attr->getEnumName());
  53. firstOccur = env().enumAlias.emplace(enumID, std::make_pair(parent, attr->getEnumName())).second;
  54. }
  55. Initproc emit();
  56. protected:
  57. void emit_trait();
  58. void emit_tpl_spl();
  59. Initproc emit_initproc();
  60. MgbEnumAttr* attr;
  61. bool firstOccur;
  62. mlir::tblgen::FmtContext ctx;
  63. };
  64. Initproc EnumAttrEmitter::emit() {
  65. emit_trait();
  66. emit_tpl_spl();
  67. return emit_initproc();
  68. }
  69. void EnumAttrEmitter::emit_trait() {
  70. if (!firstOccur) return;
  71. auto enumMax = [&] {
  72. if (attr->getEnumCombinedFlag()) {
  73. return formatv("(1llu << {0}) - 1", attr->getEnumMembers().size());
  74. } else {
  75. return formatv("{0} - 1", attr->getEnumMembers().size());
  76. }
  77. };
  78. os << tgfmt(R"(
  79. template<> struct EnumTrait<$opClass::$enumClass> {
  80. static constexpr const char *name = "$opClass.$enumClass";
  81. static constexpr std::underlying_type_t<$opClass::$enumClass> max = $0;
  82. };
  83. )", &ctx, enumMax());
  84. }
  85. void EnumAttrEmitter::emit_tpl_spl() {
  86. if (!firstOccur) return;
  87. os << tgfmt(
  88. "template<> PyTypeObject* $enumTpl<$opClass::$enumClass>::type = nullptr;\n",
  89. &ctx);
  90. auto quote = [&](auto&& i) -> std::string {
  91. return formatv("\"{0}\"", i);
  92. };
  93. os << tgfmt(R"(
  94. template<> const char*
  95. $enumTpl<$opClass::$enumClass>::members[] = {$0};
  96. )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", "));
  97. auto mem2value = [&](auto&& i) -> std::string {
  98. return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i);
  99. };
  100. os << tgfmt(R"(
  101. template<> std::unordered_map<std::string, $opClass::$enumClass>
  102. $enumTpl<$opClass::$enumClass>::mem2value = {$0};
  103. )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), mem2value), ", "));
  104. os << tgfmt(
  105. "template<> PyObject* "
  106. "$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n",
  107. &ctx, attr->getEnumMembers().size());
  108. }
  109. Initproc EnumAttrEmitter::emit_initproc() {
  110. std::string initproc = formatv("_init_py_{0}_{1}",
  111. ctx.getSubstFor("opClass"), ctx.getSubstFor("enumClass"));
  112. os << tgfmt(R"(
  113. void $0(PyTypeObject& py_type) {
  114. auto& e_type = $enumTpl<$opClass::$enumClass>::type;
  115. )", &ctx, initproc);
  116. if (firstOccur) {
  117. os << tgfmt(R"(
  118. static PyMethodDef tp_methods[] = {
  119. {const_cast<char*>("dump"), (PyCFunction)$enumTpl<$opClass::$enumClass>::py_dump, METH_NOARGS, NULL},
  120. {NULL} /* Sentinel */
  121. };
  122. )", &ctx);
  123. os << tgfmt(R"(
  124. static PyType_Slot slots[] = {
  125. {Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr},
  126. {Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare},
  127. {Py_tp_methods, tp_methods},
  128. )", &ctx);
  129. if (attr->getEnumCombinedFlag()) {
  130. // only bit combined enum could new instance because bitwise operation,
  131. // others should always use singleton
  132. os << tgfmt(R"(
  133. {Py_tp_new, (void*)$enumTpl<$opClass::$enumClass>::py_new_combined_enum},
  134. {Py_nb_or, (void*)$enumTpl<$opClass::$enumClass>::py_or},
  135. {Py_nb_and, (void*)$enumTpl<$opClass::$enumClass>::py_and},
  136. )", &ctx);
  137. }
  138. os << R"(
  139. {0, NULL}
  140. };)";
  141. os << tgfmt(R"(
  142. static PyType_Spec spec = {
  143. // name
  144. "megengine.core._imperative_rt.ops.$opClass.$enumClass",
  145. // basicsize
  146. sizeof($enumTpl<$opClass::$enumClass>),
  147. // itemsize
  148. 0,
  149. // flags
  150. Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
  151. // slots
  152. slots
  153. };)", &ctx);
  154. os << tgfmt(R"(
  155. e_type = reinterpret_cast<PyTypeObject*>(PyType_FromSpec(&spec));
  156. )", &ctx);
  157. for (auto&& i : {
  158. std::pair<std::string, std::string>{"__name__", tgfmt("$enumClass", &ctx)},
  159. {"__module__", "megengine.core._imperative_rt.ops"},
  160. {"__qualname__", tgfmt("$opClass.$enumClass", &ctx)}}) {
  161. os << formatv(R"(
  162. mgb_assert(
  163. e_type->tp_setattro(
  164. reinterpret_cast<PyObject*>(e_type),
  165. py::cast("{0}").release().ptr(),
  166. py::cast("{1}").release().ptr()) >= 0);
  167. )", i.first, i.second);
  168. }
  169. auto&& members = attr->getEnumMembers();
  170. for (size_t idx = 0; idx < members.size(); ++ idx) {
  171. os << tgfmt(R"({
  172. PyObject* inst = e_type->tp_alloc(e_type, 0);
  173. reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
  174. mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
  175. $enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
  176. })", &ctx, members[idx], idx);
  177. }
  178. }
  179. os << tgfmt(R"(
  180. Py_INCREF(e_type);
  181. mgb_assert(PyDict_SetItemString(
  182. py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(e_type)) >= 0);
  183. )", &ctx);
  184. os << "}\n";
  185. return initproc;
  186. }
  187. Initproc OpDefEmitter::emit() {
  188. for (auto&& i : op.getMgbAttributes()) {
  189. if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
  190. subclasses.push_back(EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit());
  191. }
  192. }
  193. emit_class();
  194. emit_py_init();
  195. emit_py_getsetters();
  196. emit_py_methods();
  197. return emit_initproc();
  198. }
  199. void OpDefEmitter::emit_class() {
  200. auto&& className = op.getCppClassName();
  201. std::string method_defs;
  202. std::vector<std::string> body;
  203. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  204. body.push_back(formatv(R"(
  205. {{"{0}", serialization<decltype(opdef.{0})>::dump(opdef.{0})})"
  206. , attr.name));
  207. });
  208. method_defs += formatv(R"(
  209. static PyObject* getstate(PyObject* self, PyObject*) {{
  210. auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
  211. static_cast<void>(opdef);
  212. std::unordered_map<std::string, py::object> state {{
  213. {1}
  214. };
  215. return py::cast(state).release().ptr();
  216. })", className, llvm::join(body, ","));
  217. body.clear();
  218. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  219. body.push_back(formatv(R"(
  220. {{
  221. auto&& iter = state.find("{0}");
  222. if (iter != state.end()) {
  223. opdef.{0} = serialization<decltype(opdef.{0})>::load(iter->second);
  224. }
  225. })", attr.name));
  226. });
  227. method_defs += formatv(R"(
  228. static PyObject* setstate(PyObject* self, PyObject* args) {{
  229. PyObject* dict = PyTuple_GetItem(args, 0);
  230. if (!dict) return NULL;
  231. auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
  232. auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
  233. static_cast<void>(opdef);
  234. {1}
  235. Py_RETURN_NONE;
  236. })", className, llvm::join(body, "\n"));
  237. os << tgfmt(R"(
  238. PyOpDefBegin($_self) // {
  239. static PyGetSetDef py_getsetters[];
  240. static PyMethodDef tp_methods[];
  241. $0
  242. static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
  243. // };
  244. PyOpDefEnd($_self)
  245. )", &ctx, method_defs);
  246. }
  247. void OpDefEmitter::emit_py_init() {
  248. std::string initBody;
  249. if (!op.getMgbAttributes().empty()) {
  250. initBody += "static const char* kwlist[] = {";
  251. std::vector<llvm::StringRef> attr_name_list;
  252. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  253. attr_name_list.push_back(attr.name);
  254. });
  255. attr_name_list.push_back("scope");
  256. llvm::for_each(attr_name_list, [&](auto&& attr) {
  257. initBody += formatv("\"{0}\", ", attr);
  258. });
  259. initBody += "NULL};\n";
  260. initBody += " PyObject ";
  261. auto initializer = [&](auto&& attr) -> std::string {
  262. return formatv("*{0} = NULL", attr);
  263. };
  264. initBody += llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n";
  265. initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
  266. // an extra slot created for name
  267. initBody += std::string(attr_name_list.size(), 'O');
  268. initBody += "\", const_cast<char**>(kwlist)";
  269. llvm::for_each(attr_name_list, [&](auto&& attr) {
  270. initBody += formatv(", &{0}", attr);
  271. });
  272. initBody += "))\n";
  273. initBody += " return -1;\n";
  274. llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
  275. initBody += tgfmt(R"(
  276. if ($0) {
  277. try {
  278. // TODO: remove this guard which is used for pybind11 implicit conversion
  279. py::detail::loader_life_support guard{};
  280. reinterpret_cast<PyOp($_self)*>(self)->inst().$0 =
  281. py::cast<decltype($_self::$0)>(py::handle($0));
  282. } CATCH_ALL(-1)
  283. }
  284. )", &ctx, attr.name);
  285. });
  286. initBody += tgfmt(R"(
  287. if (scope) {
  288. try {
  289. reinterpret_cast<PyOp(OpDef)*>(self)->op
  290. ->set_scope(py::cast<std::string>(py::handle(scope)));
  291. } CATCH_ALL(-1)
  292. }
  293. )", &ctx);
  294. }
  295. initBody += "\n return 0;";
  296. os << tgfmt(R"(
  297. int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
  298. $0
  299. }
  300. )", &ctx, initBody);
  301. }
  302. void OpDefEmitter::emit_py_getsetters() {
  303. auto f = [&](auto&& attr) -> std::string {
  304. return tgfmt(
  305. "{const_cast<char*>(\"$0\"), py_get_generic($_self, $0), py_set_generic($_self, $0), const_cast<char*>(\"$0\"), NULL},",
  306. &ctx, attr.name);
  307. };
  308. os << tgfmt(R"(
  309. PyGetSetDef PyOp($_self)::py_getsetters[] = {
  310. $0
  311. {NULL} /* Sentinel */
  312. };
  313. )", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n "));
  314. }
  315. void OpDefEmitter::emit_py_methods(){
  316. // generate methods
  317. std::string method_defs;
  318. std::vector<std::string> method_items;
  319. {
  320. auto&& className = op.getCppClassName();
  321. // generate getstate
  322. method_items.push_back(formatv(
  323. "{{const_cast<char*>(\"__getstate__\"), PyOp({0})::getstate, METH_NOARGS, \"{0} getstate\"},",
  324. className));
  325. // generate setstate
  326. method_items.push_back(formatv(
  327. "{{const_cast<char*>(\"__setstate__\"), PyOp({0})::setstate, METH_VARARGS, \"{0} setstate\"},",
  328. className));
  329. }
  330. os << tgfmt(R"(
  331. PyMethodDef PyOp($_self)::tp_methods[] = {
  332. $0
  333. {NULL} /* Sentinel */
  334. };
  335. )", &ctx, llvm::join(method_items, "\n "));
  336. }
  337. Initproc OpDefEmitter::emit_initproc() {
  338. std::string initproc = formatv("_init_py_{0}", op.getCppClassName());
  339. std::string subclass_init_call;
  340. for (auto&& i : subclasses) {
  341. subclass_init_call += formatv(" {0};\n", i("py_type"));
  342. }
  343. os << tgfmt(R"(
  344. void $0(py::module m) {
  345. using py_op = PyOp($_self);
  346. auto& py_type = PyOpType($_self);
  347. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  348. py_type.tp_name = "megengine.core._imperative_rt.ops.$_self";
  349. py_type.tp_basicsize = sizeof(PyOp($_self));
  350. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  351. py_type.tp_doc = "$_self";
  352. py_type.tp_base = &PyOpType(OpDef);
  353. py_type.tp_dealloc = py_dealloc_generic<py_op>;
  354. py_type.tp_new = py_new_generic<py_op>;
  355. py_type.tp_init = py_op::py_init;
  356. py_type.tp_methods = py_op::tp_methods;
  357. py_type.tp_getset = py_op::py_getsetters;
  358. mgb_assert(PyType_Ready(&py_type) >= 0);
  359. $1
  360. PyType_Modified(&py_type);
  361. m.add_object("$_self", reinterpret_cast<PyObject*>(&py_type));
  362. mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second);
  363. }
  364. )", &ctx, initproc, subclass_init_call);
  365. return initproc;
  366. }
  367. } // namespace
  368. bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper) {
  369. Environment env;
  370. using namespace std::placeholders;
  371. std::vector<Initproc> initprocs;
  372. foreach_operator(keeper, [&](MgbOp& op) {
  373. initprocs.emplace_back(OpDefEmitter(op, os, env).emit());
  374. });
  375. os << "#define INIT_ALL_OP(m)";
  376. for(auto&& init : initprocs) {
  377. os << formatv(" \\\n {0};", init("m"));
  378. }
  379. os << "\n";
  380. return false;
  381. }
  382. } // namespace mlir::tblgen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。