fix(imperative): fix refcount management on cpython opdef
refactor(mge/imperative): fix compilation for python3.6
GitOrigin-RevId: 332a516895
tags/v1.2.0
| @@ -48,7 +48,7 @@ def _(op: OpDef, inputs, outputs, input_requires_grad): | |||
| isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD | |||
| ): | |||
| grad_fn = elemwise_add_grad_fn | |||
| elif isinstance(op, Reduce) and op.mode.name == "SUM": | |||
| elif isinstance(op, Reduce) and op.mode == Reduce.Mode.SUM: | |||
| grad_fn = reduce_sum_grad_fn | |||
| else: | |||
| grad_fn = default_grad_fn | |||
| @@ -447,8 +447,8 @@ def _(op: OpDef, *args: VarNode): | |||
| def _(op: BackwardGraph, *args: VarNode): | |||
| assert args | |||
| graph = args[0].graph | |||
| return op.interpret( | |||
| lambda op, args: apply(op, *args), graph._make_const_for_backward, args | |||
| return BackwardGraph.interpret( | |||
| op, lambda op, args: apply(op, *args), graph._make_const_for_backward, args | |||
| ) | |||
| @@ -13,6 +13,7 @@ | |||
| #include "megbrain/graph.h" | |||
| #include "megbrain/utils/persistent_cache.h" | |||
| #include "megbrain/imperative/op_def.h" | |||
| #include <Python.h> | |||
| #include <string> | |||
| @@ -376,6 +377,32 @@ namespace detail { | |||
| } | |||
| }; | |||
| template<> struct type_caster<mgb::imperative::OpDef> { | |||
| protected: | |||
| std::shared_ptr<mgb::imperative::OpDef> value; | |||
| public: | |||
| static constexpr auto name = _("OpDef"); | |||
| operator mgb::imperative::OpDef&() { return *value; } | |||
| operator const mgb::imperative::OpDef&() { return *value; } | |||
| operator std::shared_ptr<mgb::imperative::OpDef>&() { return value; } | |||
| operator std::shared_ptr<mgb::imperative::OpDef>&&() && { return std::move(value); } | |||
| template <typename T> using cast_op_type = T; | |||
| bool load(handle src, bool convert); | |||
| static handle cast(const mgb::imperative::OpDef& op, return_value_policy /* policy */, handle /* parent */); | |||
| static handle cast(std::shared_ptr<mgb::imperative::OpDef> op, return_value_policy policy, handle parent) { | |||
| return cast(*op, policy, parent); | |||
| } | |||
| }; | |||
| template <> struct type_caster<std::shared_ptr<mgb::imperative::OpDef>> : | |||
| public type_caster<mgb::imperative::OpDef> { | |||
| template <typename T> using cast_op_type = pybind11::detail::movable_cast_op_type<T>; | |||
| }; | |||
| } // detail | |||
| } // PYBIND11_NAMESPACE | |||
| @@ -106,13 +106,4 @@ void init_imperative_rt(py::module m) { | |||
| }); | |||
| m.def("make_backward_graph", &make_backward_graph); | |||
| py::class_<OpDef, std::shared_ptr<OpDef>>(m, "OpDef") | |||
| .def("ctype", [](const OpDef& opdef) { | |||
| return opdef.dyn_typeinfo()->name; | |||
| }) | |||
| .def("__eq__", [](const OpDef& lhs, const OpDef& rhs) { | |||
| return lhs.is_same(rhs); | |||
| }) | |||
| .def("__hash__", &OpDef::hash); | |||
| } | |||
| @@ -63,6 +63,7 @@ PYBIND11_MODULE(MODULE_NAME, m) { | |||
| from .utils import * | |||
| from .imperative import * | |||
| from .graph import * | |||
| from .ops import OpDef | |||
| )", | |||
| py::getattr(m, "__dict__")); | |||
| @@ -16,7 +16,11 @@ | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include <Python.h> | |||
| #include <unordered_map> | |||
| namespace py = pybind11; | |||
| using namespace mgb::imperative; | |||
| namespace { | |||
| auto normalize_enum(const std::string& in) { | |||
| @@ -28,20 +32,256 @@ auto normalize_enum(const std::string& in) { | |||
| } | |||
| } // anonymous namespace | |||
| namespace { | |||
| #define PyOp(name) Py##name | |||
| #define PyOpType(name) PyOp(name)::py_type | |||
| #define PyOpDefBegin(name) \ | |||
| struct PyOp(name) : PyOpDef { \ | |||
| using Ty = name; \ | |||
| Ty& inst() { return op->cast_final_safe<Ty>(); } \ | |||
| static PyTypeObject py_type; | |||
| #define PyOpDefEnd(name) \ | |||
| }; \ | |||
| PyTypeObject PyOpType(name); | |||
| #define RETURN_RICHCOMPARE(val1, val2, op) \ | |||
| do { \ | |||
| switch (op) { \ | |||
| case Py_EQ: if ((val1) == (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||
| case Py_NE: if ((val1) != (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||
| case Py_LT: if ((val1) < (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||
| case Py_GT: if ((val1) > (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||
| case Py_LE: if ((val1) <= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||
| case Py_GE: if ((val1) >= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||
| default: \ | |||
| Py_FatalError("Unreachable C code path reached"); \ | |||
| } \ | |||
| } while (0) | |||
| template<typename T, typename SFINAE=void> | |||
| struct pyobj_convert_generic { | |||
| static T from(PyObject* obj) { | |||
| // TODO: remove this guard which is used for pybind11 implicit conversion | |||
| py::detail::loader_life_support guard{}; | |||
| return py::cast<T>(py::handle(obj)); | |||
| } | |||
| template<typename U, | |||
| typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>> | |||
| static PyObject* to(U&& t) { | |||
| return py::cast(std::forward<U>(t)).release().ptr(); | |||
| } | |||
| }; | |||
| template<typename T> | |||
| PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { | |||
| PyObject* obj = type->tp_alloc(type, 0); | |||
| T* self = reinterpret_cast<T*>(obj); | |||
| if (self != NULL) { | |||
| self->op = T::Ty::make(); | |||
| } | |||
| return obj; | |||
| } | |||
| template<typename T> | |||
| void py_dealloc_generic(PyObject* obj) { | |||
| reinterpret_cast<T*>(obj)->op.reset(); | |||
| Py_TYPE(obj)->tp_free(obj); | |||
| } | |||
| template<typename T, typename U, U T::Ty::*attr> | |||
| PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) { | |||
| auto& op = reinterpret_cast<T*>(obj)->inst(); | |||
| return pyobj_convert_generic<U>::to(op.*attr); | |||
| } | |||
| #define py_get_generic(name, attr) \ | |||
| py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr> | |||
| template<typename T, typename U, U T::Ty::*attr> | |||
| int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) { | |||
| if (value == NULL) { | |||
| PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute"); | |||
| return -1; | |||
| } | |||
| auto& op = reinterpret_cast<T*>(obj)->inst(); | |||
| try { | |||
| op.*attr = pyobj_convert_generic<U>::from(value); | |||
| return 0; | |||
| } catch(py::error_already_set& e) { | |||
| e.restore(); | |||
| } catch(py::builtin_exception& e) { | |||
| e.set_error(); | |||
| } catch(...) { | |||
| PyErr_SetString(PyExc_RuntimeError, "Unknown Error"); | |||
| } | |||
| return -1; | |||
| } | |||
| #define py_set_generic(name, attr) \ | |||
| py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr> | |||
| struct PyOpDef { | |||
| PyObject_HEAD | |||
| std::shared_ptr<OpDef> op; | |||
| static PyTypeObject py_type; | |||
| static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype; | |||
| static Py_hash_t tp_hash(PyObject *obj); | |||
| static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op); | |||
| }; | |||
| PyTypeObject PyOpType(OpDef); | |||
| std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype; | |||
| Py_hash_t PyOp(OpDef)::tp_hash(PyObject *obj) { | |||
| return static_cast<Py_hash_t>( | |||
| reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash()); | |||
| } | |||
| PyObject* PyOp(OpDef)::tp_richcompare(PyObject *self, PyObject *other, int op) { | |||
| bool same = reinterpret_cast<PyOp(OpDef)*>(self)->op->is_same( | |||
| *reinterpret_cast<PyOp(OpDef)*>(other)->op); | |||
| if (op == Py_EQ || op == Py_NE) { | |||
| RETURN_RICHCOMPARE(same, true, op); | |||
| } | |||
| Py_RETURN_NOTIMPLEMENTED; | |||
| } | |||
| template<typename T> | |||
| struct EnumWrapper { | |||
| static_assert(std::is_enum_v<T>); | |||
| PyObject_HEAD | |||
| T value; | |||
| static const char* name; | |||
| static PyTypeObject type; | |||
| static std::unordered_map<T, std::string> type2str; | |||
| static std::unordered_map<std::string, T> str2type; | |||
| EnumWrapper() = default; | |||
| EnumWrapper(T v): value(v) {} | |||
| EnumWrapper(std::string&& str): EnumWrapper(str2type.at(normalize_enum(str))) {} | |||
| std::string to_string() const { | |||
| return type2str.at(value); | |||
| } | |||
| static PyObject* py_repr(PyObject* self) { | |||
| return pyobj_convert_generic<std::string>::to( | |||
| std::string(name) + "." + reinterpret_cast<EnumWrapper*>(self)->to_string()); | |||
| } | |||
| static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) { | |||
| T lhs = reinterpret_cast<EnumWrapper*>(self)->value, | |||
| rhs = reinterpret_cast<EnumWrapper*>(other)->value; | |||
| if (op == Py_EQ || op == Py_NE) { | |||
| RETURN_RICHCOMPARE(lhs, rhs, op); | |||
| } | |||
| Py_RETURN_NOTIMPLEMENTED; | |||
| } | |||
| }; | |||
| template<typename T> | |||
| struct pyobj_convert_generic<T, | |||
| std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> { | |||
| using Wrapper = EnumWrapper<T>; | |||
| static T from(PyObject* obj) { | |||
| if (PyObject_TypeCheck(obj, &Wrapper::type)) { | |||
| return reinterpret_cast<Wrapper*>(obj)->value; | |||
| } | |||
| // try as string | |||
| // TODO: type checkcd | |||
| return Wrapper(pyobj_convert_generic<std::string>::from(obj)).value; | |||
| } | |||
| static PyObject* to(T t) { | |||
| PyTypeObject* pytype = &Wrapper::type; | |||
| PyObject* obj = pytype->tp_alloc(pytype, 0); | |||
| reinterpret_cast<Wrapper*>(obj)->value = t; | |||
| return obj; | |||
| } | |||
| }; | |||
| void _init_py_op_def(py::module m) { | |||
| auto& py_type = PyOpType(OpDef); | |||
| py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| py_type.tp_name = "megengine.core._imperative_rt.OpDef"; | |||
| py_type.tp_basicsize = sizeof(PyOp(OpDef)); | |||
| py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| py_type.tp_doc = "OpDef"; | |||
| py_type.tp_base = &PyBaseObject_Type; | |||
| py_type.tp_hash = PyOp(OpDef)::tp_hash; | |||
| py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare; | |||
| mgb_assert(PyType_Ready(&py_type) >= 0); | |||
| m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type)); | |||
| } | |||
| /*********** begin of hand-write opdefs **************/ | |||
| PyOpDefBegin(BackwardGraph) // {{ | |||
| // }; | |||
| PyOpDefEnd(BackwardGraph) | |||
| void _init_py_backward_graph(py::module m) { | |||
| using py_op = PyOp(BackwardGraph); | |||
| auto& py_type = PyOpType(BackwardGraph); | |||
| py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| py_type.tp_name = "megengine.core._imperative_rt.ops.BackwardGraph"; | |||
| py_type.tp_basicsize = sizeof(PyOp(BackwardGraph)); | |||
| py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| py_type.tp_doc = "BackwardGraph"; | |||
| py_type.tp_base = &PyOpType(OpDef); | |||
| py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
| py_type.tp_new = py_new_generic<py_op>; | |||
| mgb_assert(PyType_Ready(&py_type) >= 0); | |||
| // FIXME: rewrite interpret function in cpython instead wrap directly by pybind11::cppfunction | |||
| auto interpret = py::cpp_function( | |||
| [](OpDef& self, py::object pyf, py::object pyc, | |||
| const mgb::SmallVector<py::object>& inputs) { | |||
| auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) { | |||
| return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs)); | |||
| }; | |||
| auto c = [pyc](const TensorPtr& tensor) { | |||
| return pyc(tensor->dev_tensor()); | |||
| }; | |||
| return self.cast_final_safe<BackwardGraph>().graph().interpret<py::object>(f, c, inputs); | |||
| }); | |||
| mgb_assert(PyDict_SetItemString( | |||
| py_type.tp_dict, "interpret", interpret.release().ptr()) >= 0); | |||
| PyType_Modified(&py_type); | |||
| m.add_object("BackwardGraph", reinterpret_cast<PyObject*>(&py_type)); | |||
| mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second); | |||
| } | |||
| /*********** end of hand-write opdefs **************/ | |||
| // auto generated opdefs | |||
| #include "opdef.cpy.inl" | |||
| } // anonymous namespace | |||
| namespace PYBIND11_NAMESPACE { | |||
| namespace detail { | |||
| bool type_caster<OpDef>::load(handle src, bool convert) { | |||
| PyObject* obj = src.ptr(); | |||
| if (!PyObject_TypeCheck(obj, &PyOpType(OpDef))) { | |||
| return false; | |||
| } | |||
| value = reinterpret_cast<PyOp(OpDef)*>(obj)->op; | |||
| return true; | |||
| } | |||
| handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) { | |||
| PyTypeObject* pytype; | |||
| auto& c2p = PyOp(OpDef)::ctype2pytype; | |||
| auto&& iter = c2p.find(op.dyn_typeinfo()); | |||
| if (iter != c2p.end()) { // FIXME: should always meet this condition | |||
| pytype = iter->second; | |||
| } else { // which means unregistered op type, jsut make it as an opaque op type | |||
| // currently, only OprAttr goes into this branch | |||
| pytype = &PyOpType(OpDef); | |||
| } | |||
| PyObject* obj = pytype->tp_alloc(pytype, 0); | |||
| mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef))); | |||
| reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this(); | |||
| return py::handle(obj); | |||
| } | |||
| } // detail | |||
| } // PYBIND11_NAMESPACE | |||
| void init_ops(py::module m) { | |||
| using namespace mgb::imperative; | |||
| py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph") | |||
| .def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc, | |||
| const mgb::SmallVector<py::object>& inputs) { | |||
| auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) { | |||
| return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs)); | |||
| }; | |||
| auto c = [pyc](const TensorPtr& tensor) { | |||
| return pyc(tensor->dev_tensor()); | |||
| }; | |||
| return self.graph().interpret<py::object>(f, c, inputs); | |||
| }); | |||
| #include "opdef.py.inl" | |||
| _init_py_op_def(m); | |||
| _init_py_backward_graph(m); | |||
| INIT_ALL_OP(m) | |||
| } | |||
| @@ -76,7 +76,7 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs) { | |||
| auto opr = def.cast_final_safe<CondTake>(); | |||
| auto&& opr = def.cast_final_safe<CondTake>(); | |||
| mgb_assert(opr.same_type<CondTake>()); | |||
| mgb_assert(inputs.size() == 2, "CondTake take 2 inputs, got %lu", | |||
| inputs.size()); | |||
| @@ -111,7 +111,7 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node( | |||
| SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor( | |||
| const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs) { | |||
| auto param = def.cast_final_safe<ParamPackSplit>(); | |||
| auto&& param = def.cast_final_safe<ParamPackSplit>(); | |||
| mgb_assert(inputs.size() == 1, "ParamPackSplit take 1 input, got %lu", inputs.size()); | |||
| auto&& inp = inputs[0]; | |||
| auto&& shp = inp->layout(); | |||
| @@ -27,6 +27,7 @@ struct BackwardGraphResult { | |||
| }; | |||
| class OpDef : public Hashable, | |||
| public NonCopyableObj, | |||
| public std::enable_shared_from_this<OpDef> { | |||
| mutable const OpTrait* m_trait = nullptr; | |||
| public: | |||
| @@ -64,7 +65,7 @@ template<typename T> | |||
| class OpDefImplBase : public OpDef { | |||
| public: | |||
| template<typename ...Args> | |||
| static std::shared_ptr<OpDef> make(Args&& ...args) { | |||
| static std::shared_ptr<T> make(Args&& ...args) { | |||
| return std::make_shared<T>(std::forward<Args>(args)...); | |||
| } | |||
| }; | |||
| @@ -10,5 +10,6 @@ set(LLVM_TARGET_DEFINITIONS ${MGE_IR_DIR}/ops.td) | |||
| tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header") | |||
| tablegen(MGB opdef.cpp.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-body") | |||
| tablegen(MGB opdef.py.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-binding") | |||
| add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl param_defs_tblgen) | |||
| tablegen(MGB opdef.cpy.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-c-extension") | |||
| add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl opdef.cpy.inl param_defs_tblgen) | |||
| set(MGB_OPDEF_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE) | |||
| @@ -11,7 +11,8 @@ enum ActionType { | |||
| None, | |||
| CppHeader, | |||
| CppBody, | |||
| Pybind | |||
| Pybind, | |||
| CPython | |||
| }; | |||
| // NOLINTNEXTLINE | |||
| @@ -22,7 +23,9 @@ llvm::cl::opt<ActionType> action( | |||
| clEnumValN(CppBody, "gen-cpp-body", | |||
| "Generate operator cpp body"), | |||
| clEnumValN(Pybind, "gen-python-binding", | |||
| "Generate pybind11 python bindings"))); | |||
| "Generate pybind11 python bindings"), | |||
| clEnumValN(CPython, "gen-python-c-extension", | |||
| "Generate python c extensions"))); | |||
| using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; | |||
| using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; | |||
| @@ -196,7 +199,7 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
| formatMethImpl("hash") | |||
| ); | |||
| os << formatv( | |||
| " auto op_ = def_.cast_final_safe<{0}>();\n" | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| @@ -210,8 +213,8 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
| formatMethImpl("is_same_st") | |||
| ); | |||
| os << formatv( | |||
| " auto a_ = lhs_.cast_final_safe<{0}>(),\n" | |||
| " b_ = rhs_.cast_final_safe<{0}>();\n" | |||
| " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n" | |||
| " &&b_ = rhs_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(a_);\n" | |||
| " static_cast<void>(b_);\n", | |||
| className | |||
| @@ -237,15 +240,15 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
| } | |||
| } | |||
| struct PybindContext { | |||
| std::unordered_map<unsigned int, std::string> enumAlias; | |||
| struct EnumContext { | |||
| std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias; | |||
| }; | |||
| static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext& ctx) { | |||
| auto class_name = op.getCppClassName(); | |||
| static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||
| auto className = op.getCppClassName(); | |||
| os << formatv( | |||
| "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", | |||
| class_name | |||
| className | |||
| ); | |||
| for (auto&& i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| @@ -263,17 +266,17 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext | |||
| if (iter == enumAlias.end()) { | |||
| os << formatv( | |||
| "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", | |||
| class_name, attr->getEnumName() | |||
| className, attr->getEnumName() | |||
| ); | |||
| std::vector<std::string> body; | |||
| for (auto&& i: attr->getEnumMembers()) { | |||
| os << formatv( | |||
| "\n .value(\"{2}\", {0}::{1}::{2})", | |||
| class_name, attr->getEnumName(), i | |||
| className, attr->getEnumName(), i | |||
| ); | |||
| body.push_back(formatv( | |||
| "if (str == \"{2}\") return {0}::{1}::{2};", | |||
| class_name, attr->getEnumName(), i | |||
| className, attr->getEnumName(), i | |||
| )); | |||
| } | |||
| os << formatv( | |||
| @@ -286,21 +289,21 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext | |||
| ); | |||
| os << formatv( | |||
| "py::implicitly_convertible<std::string, {0}::{1}>();\n\n", | |||
| class_name, attr->getEnumName() | |||
| className, attr->getEnumName() | |||
| ); | |||
| enumAlias.emplace(enumID, formatv( | |||
| "{0}Inst.attr(\"{1}\")", class_name, attr->getEnumName() | |||
| )); | |||
| enumAlias.emplace(enumID, | |||
| std::make_pair(className, attr->getEnumName())); | |||
| } else { | |||
| os << formatv( | |||
| "{0}Inst.attr(\"{1}\") = {2};\n\n", | |||
| class_name, attr->getEnumName(), iter->second | |||
| "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", | |||
| className, attr->getEnumName(), | |||
| iter->second.first, iter->second.second | |||
| ); | |||
| } | |||
| } | |||
| } | |||
| // generate op class binding | |||
| os << formatv("{0}Inst", class_name); | |||
| os << formatv("{0}Inst", className); | |||
| bool hasDefaultCtor = op.getMgbAttributes().empty(); | |||
| if (!hasDefaultCtor) { | |||
| os << "\n .def(py::init<"; | |||
| @@ -327,12 +330,184 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| os << formatv( | |||
| "\n .def_readwrite(\"{0}\", &{1}::{0})", | |||
| i.name, class_name | |||
| i.name, className | |||
| ); | |||
| } | |||
| os << ";\n\n"; | |||
| } | |||
| static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||
| auto className = op.getCppClassName(); | |||
| std::string body; | |||
| // generate PyType for enum class member | |||
| for (auto&& i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = | |||
| llvm::cast<MgbEnumAttr>(aliasBase) | |||
| .getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| auto&& enumAlias = ctx.enumAlias; | |||
| auto&& iter = enumAlias.find(enumID); | |||
| auto enumName = attr->getEnumName(); | |||
| body += "{\n"; | |||
| body += formatv( | |||
| "auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName | |||
| ); | |||
| if (iter == enumAlias.end()) { | |||
| os << formatv( | |||
| "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||
| className, enumName); | |||
| os << formatv( | |||
| "template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n", | |||
| className, enumName); | |||
| std::vector<std::string> pairStr; | |||
| for (auto&& i: attr->getEnumMembers()) { | |||
| pairStr.push_back(formatv( | |||
| "{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| template<> std::unordered_map<std::string, {0}::{1}> | |||
| EnumWrapper<{0}::{1}>::str2type = {{ | |||
| {2} | |||
| }; | |||
| )", className, enumName, llvm::join(pairStr, ", ")); | |||
| pairStr.clear(); | |||
| for (auto&& i: attr->getEnumMembers()) { | |||
| pairStr.push_back(formatv( | |||
| "{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| template<> std::unordered_map<{0}::{1}, std::string> | |||
| EnumWrapper<{0}::{1}>::type2str = {{ | |||
| {2} | |||
| }; | |||
| )", className, enumName, llvm::join(pairStr, ", ")); | |||
| body += formatv(R"( | |||
| e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||
| e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); | |||
| e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| e_type.tp_doc = "{0}.{1}"; | |||
| e_type.tp_base = &PyBaseObject_Type; | |||
| e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; | |||
| e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; | |||
| mgb_assert(PyType_Ready(&e_type) >= 0); | |||
| )", className, enumName); | |||
| for (auto&& i: attr->getEnumMembers()) { | |||
| body += formatv(R"({{ | |||
| PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
| reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||
| mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||
| })", className, enumName, i); | |||
| } | |||
| enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||
| } | |||
| body += formatv(R"( | |||
| PyType_Modified(&e_type); | |||
| mgb_assert(PyDict_SetItemString( | |||
| py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
| )", enumName); | |||
| body += "}\n"; | |||
| } | |||
| } | |||
| // generate getsetters | |||
| std::vector<std::string> getsetters; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| getsetters.push_back(formatv( | |||
| "{{\"{1}\", py_get_generic({0}, {1}), py_set_generic({0}, {1}), \"{1}\", NULL},", | |||
| className, i.name)); | |||
| } | |||
| // generate tp_init | |||
| std::string initBody; | |||
| if (!op.getMgbAttributes().empty()) { | |||
| initBody += "static const char* kwlist[] = {"; | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| initBody += formatv("\"{0}\", ", attr.name); | |||
| }); | |||
| initBody += "NULL};\n"; | |||
| initBody += " PyObject "; | |||
| std::vector<std::string> attrs; | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| attrs.push_back(formatv("*{0} = NULL", attr.name)); | |||
| }); | |||
| initBody += llvm::join(attrs, ", ") + ";\n"; | |||
| initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; | |||
| initBody += std::string(op.getMgbAttributes().size(), 'O'); | |||
| initBody += "\", const_cast<char**>(kwlist)"; | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| initBody += formatv(" ,&{0}", attr.name); | |||
| }); | |||
| initBody += "))\n"; | |||
| initBody += " return -1;\n"; | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| initBody += formatv(R"( | |||
| if ({1}) {{ | |||
| try {{ | |||
| reinterpret_cast<PyOp({0})*>(self)->inst().{1} = | |||
| pyobj_convert_generic<decltype({0}::{1})>::from({1}); | |||
| } catch(py::error_already_set& e) {{ | |||
| e.restore(); | |||
| return -1; | |||
| } catch(py::builtin_exception& e) {{ | |||
| e.set_error(); | |||
| return -1; | |||
| } catch(...) {{ | |||
| PyErr_SetString(PyExc_RuntimeError, "Unknown Error"); | |||
| return -1; | |||
| } | |||
| } | |||
| )", className, attr.name); | |||
| }); | |||
| } | |||
| initBody += "\n return 0;"; | |||
| os << formatv(R"( | |||
| PyOpDefBegin({0}) // {{ | |||
| static PyGetSetDef py_getsetters[]; | |||
| static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||
| // }; | |||
| PyOpDefEnd({0}) | |||
| PyGetSetDef PyOp({0})::py_getsetters[] = {{ | |||
| {1} | |||
| {{NULL} /* Sentinel */ | |||
| }; | |||
| int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{ | |||
| {2} | |||
| } | |||
| void _init_py_{0}(py::module m) {{ | |||
| using py_op = PyOp({0}); | |||
| auto& py_type = PyOpType({0}); | |||
| py_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| py_type.tp_name = "megengine.core._imperative_rt.ops.{0}"; | |||
| py_type.tp_basicsize = sizeof(PyOp({0})); | |||
| py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| py_type.tp_doc = "{0}"; | |||
| py_type.tp_base = &PyOpType(OpDef); | |||
| py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
| py_type.tp_new = py_new_generic<py_op>; | |||
| py_type.tp_init = py_op::py_init; | |||
| py_type.tp_getset = py_op::py_getsetters; | |||
| mgb_assert(PyType_Ready(&py_type) >= 0); | |||
| {3} | |||
| PyType_Modified(&py_type); | |||
| m.add_object("{0}", reinterpret_cast<PyObject*>(&py_type)); | |||
| mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second); | |||
| } | |||
| )", | |||
| op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body); | |||
| } | |||
| static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, | |||
| std::function<void(raw_ostream&, MgbOp&)> callback) { | |||
| auto op_base_class = keeper.getClass("Op"); | |||
| @@ -360,13 +535,26 @@ static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) { | |||
| } | |||
| static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) { | |||
| PybindContext ctx; | |||
| EnumContext ctx; | |||
| using namespace std::placeholders; | |||
| for_each_operator(os, keeper, | |||
| std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx))); | |||
| return false; | |||
| } | |||
| static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) { | |||
| EnumContext ctx; | |||
| using namespace std::placeholders; | |||
| for_each_operator(os, keeper, | |||
| std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx))); | |||
| os << "#define INIT_ALL_OP(m)"; | |||
| for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) { | |||
| os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName()); | |||
| }); | |||
| os << "\n"; | |||
| return false; | |||
| } | |||
| int main(int argc, char **argv) { | |||
| llvm::InitLLVM y(argc, argv); | |||
| llvm::cl::ParseCommandLineOptions(argc, argv); | |||
| @@ -379,5 +567,8 @@ int main(int argc, char **argv) { | |||
| if (action == ActionType::Pybind) { | |||
| return TableGenMain(argv[0], &gen_op_def_pybind11); | |||
| } | |||
| if (action == ActionType::CPython) { | |||
| return TableGenMain(argv[0], &gen_op_def_python_c_extension); | |||
| } | |||
| return -1; | |||
| } | |||
| } | |||