GitOrigin-RevId: b81b085762
tags/v1.3.1
| @@ -1,6 +1,7 @@ | |||
| # mgb tablegen executable | |||
| set(TABLE_TARGET mgb-mlir-autogen) | |||
| add_executable(${TABLE_TARGET} autogen.cpp) | |||
| file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.h ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) | |||
| add_executable(${TABLE_TARGET} ${SRCS}) | |||
| target_include_directories(${TABLE_TARGET} PRIVATE ${MLIR_LLVM_INCLUDE_DIR}) | |||
| target_link_libraries(${TABLE_TARGET} PRIVATE LLVMTableGen MLIRTableGen LLVMSupport) | |||
| set(MGB_TABLEGEN_EXE ${TABLE_TARGET}) | |||
| @@ -1,8 +1,17 @@ | |||
| #include <iostream> | |||
| #include <unordered_map> | |||
| #include <functional> | |||
| #include "./helper.h" | |||
| /** | |||
| * \file imperative/tablegen/autogen.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./targets/cpp_class.h" | |||
| #include "./targets/pybind11.h" | |||
| #include "./targets/python_c_extension.h" | |||
| using llvm::raw_ostream; | |||
| using llvm::RecordKeeper; | |||
| @@ -27,731 +36,7 @@ llvm::cl::opt<ActionType> action( | |||
| clEnumValN(CPython, "gen-python-c-extension", | |||
| "Generate python c extensions"))); | |||
| using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; | |||
| using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; | |||
| using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin; | |||
| using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin; | |||
| using MgbOp = mlir::tblgen::MgbOpBase; | |||
| using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin; | |||
| llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) { | |||
| // Note: we have already registered the corresponding attr wrappers | |||
| // for following basic ctypes so we needn't handle them here | |||
| /* auto&& attr_type_name = attr.getAttrDefName(); | |||
| if (attr_type_name == "UI32Attr") { | |||
| return "uint32_t"; | |||
| } | |||
| if (attr_type_name == "UI64Attr") { | |||
| return "uint64_t"; | |||
| } | |||
| if (attr_type_name == "I32Attr") { | |||
| return "int32_t"; | |||
| } | |||
| if (attr_type_name == "F32Attr") { | |||
| return "float"; | |||
| } | |||
| if (attr_type_name == "F64Attr") { | |||
| return "double"; | |||
| } | |||
| if (attr_type_name == "StrAttr") { | |||
| return "std::string"; | |||
| } | |||
| if (attr_type_name == "BoolAttr") { | |||
| return "bool"; | |||
| }*/ | |||
| auto&& attr = llvm::cast<MgbAttrWrapper>(attr_); | |||
| if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) { | |||
| return e->getEnumName(); | |||
| } | |||
| return attr.getUnderlyingType(); | |||
| } | |||
| static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { | |||
| os << formatv( | |||
| "class {0} : public OpDefImplBase<{0}> {{\n" | |||
| " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n" | |||
| "public:\n", | |||
| op.getCppClassName() | |||
| ); | |||
| // handle enum alias | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| os << formatv( | |||
| " using {0} = {1};\n", | |||
| attr->getEnumName(), attr->getUnderlyingType() | |||
| ); | |||
| } | |||
| } | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| auto defaultValue = i.attr.getDefaultValue().str(); | |||
| if (!defaultValue.empty()) { | |||
| defaultValue = formatv(" = {0}", defaultValue); | |||
| } | |||
| os << formatv( | |||
| " {0} {1}{2};\n", | |||
| attr_to_ctype(i.attr), i.name, defaultValue | |||
| ); | |||
| } | |||
| auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) { | |||
| os << formatv( | |||
| " {0}({1}){2}{3}\n", | |||
| op.getCppClassName(), paramList, memInitList, body | |||
| ); | |||
| }; | |||
| gen_ctor("", "", " = default;"); | |||
| if (!op.getMgbAttributes().empty()) { | |||
| std::vector<std::string> paramList, initList; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| paramList.push_back(formatv( | |||
| "{0} {1}_", attr_to_ctype(i.attr), i.name | |||
| )); | |||
| initList.push_back(formatv( | |||
| "{0}({0}_)", i.name | |||
| )); | |||
| } | |||
| paramList.push_back("std::string scope_ = {}"); | |||
| gen_ctor(llvm::join(paramList, ", "), | |||
| ": " + llvm::join(initList, ", "), | |||
| " { set_scope(scope_); }"); | |||
| } | |||
| auto packedParams = op.getPackedParams(); | |||
| if (!packedParams.empty()) { | |||
| std::vector<std::string> paramList, initList; | |||
| for (auto &&p : packedParams) { | |||
| auto&& paramFields = p.getFields(); | |||
| auto&& paramType = p.getFullName(); | |||
| auto&& paramName = formatv("packed_param_{0}", paramList.size()); | |||
| paramList.push_back( | |||
| paramFields.empty() ? paramType.str() | |||
| : formatv("{0} {1}", paramType, paramName) | |||
| ); | |||
| for (auto&& i : paramFields) { | |||
| initList.push_back(formatv( | |||
| "{0}({1}.{0})", i.name, paramName | |||
| )); | |||
| } | |||
| } | |||
| for (auto&& i : op.getExtraArguments()) { | |||
| paramList.push_back(formatv( | |||
| "{0} {1}_", attr_to_ctype(i.attr), i.name | |||
| )); | |||
| initList.push_back(formatv( | |||
| "{0}({0}_)", i.name | |||
| )); | |||
| } | |||
| gen_ctor(llvm::join(paramList, ", "), | |||
| initList.empty() ? "" : ": " + llvm::join(initList, ", "), | |||
| " {}"); | |||
| } | |||
| if (!packedParams.empty()) { | |||
| for (auto&& p : packedParams) { | |||
| auto accessor = p.getAccessor(); | |||
| if (!accessor.empty()) { | |||
| os << formatv( | |||
| " {0} {1}() const {{\n", | |||
| p.getFullName(), accessor | |||
| ); | |||
| std::vector<llvm::StringRef> fields; | |||
| for (auto&& i : p.getFields()) { | |||
| fields.push_back(i.name); | |||
| } | |||
| os << formatv( | |||
| " return {{{0}};\n", | |||
| llvm::join(fields, ", ") | |||
| ); | |||
| os << " }\n"; | |||
| } | |||
| } | |||
| } | |||
| if (auto decl = op.getExtraOpdefDecl()) { | |||
| os << decl.getValue(); | |||
| } | |||
| os << formatv( | |||
| "};\n\n" | |||
| ); | |||
| } | |||
| static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) { | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| if (attr->supportToString()) { | |||
| std::vector<std::string> case_body; | |||
| std::string ename = formatv("{0}::{1}", | |||
| op.getCppClassName(), attr->getEnumName()); | |||
| llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ | |||
| case_body.push_back(formatv( | |||
| "case {0}::{1}: return \"{1}\";", ename, v)); | |||
| }); | |||
| os << formatv(R"( | |||
| template <> | |||
| struct ToStringTrait<{0}> { | |||
| std::string operator()({0} e) const { | |||
| switch (e) { | |||
| {1} | |||
| default: | |||
| return "{0}::Unknown"; | |||
| } | |||
| } | |||
| }; | |||
| )", ename, llvm::join(case_body, "\n")); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
| auto&& className = op.getCppClassName(); | |||
| os << formatv( | |||
| "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className | |||
| ); | |||
| auto formatMethImpl = [&](auto&& meth) { | |||
| return formatv( | |||
| "{0}_{1}_impl", className, meth | |||
| ); | |||
| }; | |||
| std::vector<std::string> methods; | |||
| if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) { | |||
| os << "namespace {\n"; | |||
| // generate hash() | |||
| mlir::tblgen::FmtContext ctx; | |||
| os << formatv( | |||
| "size_t {0}(const OpDef& def_) {{\n", | |||
| formatMethImpl("hash") | |||
| ); | |||
| os << formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| // generate is_same_st() | |||
| os << formatv( | |||
| "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n", | |||
| formatMethImpl("is_same_st") | |||
| ); | |||
| os << formatv( | |||
| " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n" | |||
| " &&b_ = rhs_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(a_);\n" | |||
| " static_cast<void>(b_);\n", | |||
| className | |||
| ); | |||
| os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); | |||
| os << "}\n"; | |||
| // generate props() | |||
| os << formatv( | |||
| "std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n", | |||
| formatMethImpl("props") | |||
| ); | |||
| os << formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| // generate make_name() | |||
| os << formatv( | |||
| "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") | |||
| ); | |||
| os << formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| os << "} // anonymous namespace\n"; | |||
| methods.push_back("hash"); | |||
| methods.push_back("is_same_st"); | |||
| methods.push_back("props"); | |||
| methods.push_back("make_name"); | |||
| } | |||
| if (!methods.empty()) { | |||
| os << formatv( | |||
| "OP_TRAIT_REG({0}, {0})", op.getCppClassName() | |||
| ); | |||
| for (auto&& i : methods) { | |||
| os << formatv( | |||
| "\n .{0}({1})", i, formatMethImpl(i) | |||
| ); | |||
| } | |||
| os << ";\n\n"; | |||
| } | |||
| } | |||
| struct EnumContext { | |||
| std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias; | |||
| }; | |||
| static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||
| auto className = op.getCppClassName(); | |||
| os << formatv( | |||
| "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", | |||
| className | |||
| ); | |||
| for (auto&& i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = | |||
| llvm::cast<MgbEnumAttr>(aliasBase) | |||
| .getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| auto&& enumAlias = ctx.enumAlias; | |||
| auto&& iter = enumAlias.find(enumID); | |||
| if (iter == enumAlias.end()) { | |||
| os << formatv( | |||
| "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", | |||
| className, attr->getEnumName() | |||
| ); | |||
| std::vector<std::string> body; | |||
| for (auto&& i: attr->getEnumMembers()) { | |||
| os << formatv( | |||
| "\n .value(\"{2}\", {0}::{1}::{2})", | |||
| className, attr->getEnumName(), i | |||
| ); | |||
| body.push_back(formatv( | |||
| "if (str == \"{2}\") return {0}::{1}::{2};", | |||
| className, attr->getEnumName(), i | |||
| )); | |||
| } | |||
| if (attr->getEnumCombinedFlag()) { | |||
| //! define operator | | |||
| os << formatv( | |||
| "\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ " | |||
| "\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));" | |||
| "\n })", | |||
| className, attr->getEnumName()); | |||
| //! define operator & | |||
| os << formatv( | |||
| "\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{" | |||
| "\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));" | |||
| "\n })", | |||
| className, attr->getEnumName()); | |||
| } | |||
| os << formatv( | |||
| "\n .def(py::init([](const std::string& in) {" | |||
| "\n auto&& str = normalize_enum(in);" | |||
| "\n {0}" | |||
| "\n throw py::cast_error(\"invalid enum value \" + in);" | |||
| "\n }));\n", | |||
| llvm::join(body, "\n ") | |||
| ); | |||
| os << formatv( | |||
| "py::implicitly_convertible<std::string, {0}::{1}>();\n\n", | |||
| className, attr->getEnumName() | |||
| ); | |||
| enumAlias.emplace(enumID, | |||
| std::make_pair(className, attr->getEnumName())); | |||
| } else { | |||
| os << formatv( | |||
| "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", | |||
| className, attr->getEnumName(), | |||
| iter->second.first, iter->second.second | |||
| ); | |||
| } | |||
| } | |||
| } | |||
| // generate op class binding | |||
| os << formatv("{0}Inst", className); | |||
| bool hasDefaultCtor = op.getMgbAttributes().empty(); | |||
| if (!hasDefaultCtor) { | |||
| os << "\n .def(py::init<"; | |||
| std::vector<llvm::StringRef> targs; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| targs.push_back(i.attr.getReturnType()); | |||
| } | |||
| os << llvm::join(targs, ", "); | |||
| os << ", std::string>()"; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| os << formatv(", py::arg(\"{0}\")", i.name); | |||
| auto defaultValue = i.attr.getDefaultValue(); | |||
| if (!defaultValue.empty()) { | |||
| os << formatv(" = {0}", defaultValue); | |||
| } else { | |||
| hasDefaultCtor = true; | |||
| } | |||
| } | |||
| os << ", py::arg(\"scope\") = {})"; | |||
| } | |||
| if (hasDefaultCtor) { | |||
| os << "\n .def(py::init<>())"; | |||
| } | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| os << formatv( | |||
| "\n .def_readwrite(\"{0}\", &{1}::{0})", | |||
| i.name, className | |||
| ); | |||
| } | |||
| os << ";\n\n"; | |||
| } | |||
| static std::string gen_op_def_python_c_extension_enum( | |||
| raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||
| llvm::StringRef className) { | |||
| std::string body; | |||
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| auto&& enumAlias = ctx.enumAlias; | |||
| auto&& iter = enumAlias.find(enumID); | |||
| auto enumName = attr->getEnumName(); | |||
| body += "{\n"; | |||
| body += formatv("auto& e_type = EnumWrapper<{0}::{1}>::type;", className, | |||
| enumName); | |||
| if (iter == enumAlias.end()) { | |||
| os << formatv( | |||
| "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||
| className, enumName); | |||
| os << formatv( | |||
| "template<> const char* EnumWrapper<{0}::{1}>::name = " | |||
| "\"{0}.{1}\";\n", | |||
| className, enumName); | |||
| std::vector<std::string> pairStr; | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| pairStr.push_back( | |||
| formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| template<> std::unordered_map<std::string, {0}::{1}> | |||
| EnumWrapper<{0}::{1}>::str2type = {{ | |||
| {2} | |||
| }; | |||
| )", | |||
| className, enumName, llvm::join(pairStr, ", ")); | |||
| pairStr.clear(); | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| pairStr.push_back( | |||
| formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| template<> std::unordered_map<{0}::{1}, std::string> | |||
| EnumWrapper<{0}::{1}>::type2str = {{ | |||
| {2} | |||
| }; | |||
| )", | |||
| className, enumName, llvm::join(pairStr, ", ")); | |||
| body += formatv(R"( | |||
| e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||
| e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); | |||
| e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| e_type.tp_doc = "{0}.{1}"; | |||
| e_type.tp_base = &PyBaseObject_Type; | |||
| e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; | |||
| e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; | |||
| mgb_assert(PyType_Ready(&e_type) >= 0); | |||
| )", | |||
| className, enumName); | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| body += formatv(R"({{ | |||
| PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
| reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||
| mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||
| })", | |||
| className, enumName, i); | |||
| } | |||
| enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||
| } | |||
| body += formatv(R"( | |||
| PyType_Modified(&e_type); | |||
| mgb_assert(PyDict_SetItemString( | |||
| py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
| )", | |||
| enumName); | |||
| body += "}\n"; | |||
| return body; | |||
| } | |||
| static std::string gen_op_def_python_c_extension_bit_combined_enum( | |||
| raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||
| llvm::StringRef className) { | |||
| std::string body; | |||
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| auto&& enumAlias = ctx.enumAlias; | |||
| auto&& iter = enumAlias.find(enumID); | |||
| auto enumName = attr->getEnumName(); | |||
| body += "{\n"; | |||
| body += formatv("auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;", | |||
| className, enumName); | |||
| if (iter == enumAlias.end()) { | |||
| os << formatv( | |||
| "template<> PyTypeObject " | |||
| "BitCombinedEnumWrapper<{0}::{1}>::type={{};\n", | |||
| className, enumName); | |||
| os << formatv( | |||
| "template<> PyNumberMethods " | |||
| "BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};\n", | |||
| className, enumName); | |||
| os << formatv( | |||
| "template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name " | |||
| "= \"{0}.{1}\";\n", | |||
| className, enumName); | |||
| os << formatv( | |||
| "template<> struct EnumTrait<{0}::{1}> {{ static constexpr " | |||
| "bool is_bit_combined = true;};\n", | |||
| className, enumName); | |||
| std::vector<std::string> pairStr; | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| pairStr.push_back( | |||
| formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| template<> std::unordered_map<std::string, {0}::{1}> | |||
| BitCombinedEnumWrapper<{0}::{1}>::str2type = {{ | |||
| {2} | |||
| }; | |||
| )", | |||
| className, enumName, llvm::join(pairStr, ", ")); | |||
| pairStr.clear(); | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| pairStr.push_back( | |||
| formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| template<> std::unordered_map<{0}::{1}, std::string> | |||
| BitCombinedEnumWrapper<{0}::{1}>::type2str = {{ | |||
| {2} | |||
| }; | |||
| )", | |||
| className, enumName, llvm::join(pairStr, ", ")); | |||
| body += formatv(R"( | |||
| e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||
| e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>); | |||
| e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| e_type.tp_doc = "{0}.{1}"; | |||
| e_type.tp_base = &PyBaseObject_Type; | |||
| e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum; | |||
| e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init; | |||
| e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr; | |||
| e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare; | |||
| auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods; | |||
| number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or; | |||
| number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and; | |||
| e_type.tp_as_number = &number_method; | |||
| mgb_assert(PyType_Ready(&e_type) >= 0); | |||
| )", | |||
| className, enumName); | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| body += formatv(R"({{ | |||
| PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
| reinterpret_cast<BitCombinedEnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||
| mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||
| })", | |||
| className, enumName, i); | |||
| } | |||
| enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||
| } | |||
| body += formatv(R"( | |||
| PyType_Modified(&e_type); | |||
| mgb_assert(PyDict_SetItemString( | |||
| py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
| )", | |||
| enumName); | |||
| body += "}\n"; | |||
| return body; | |||
| } | |||
| static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||
| auto className = op.getCppClassName(); | |||
| std::string body; | |||
| // generate PyType for enum class member | |||
| for (auto&& i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| if (attr->getEnumCombinedFlag()) { | |||
| body += gen_op_def_python_c_extension_bit_combined_enum( | |||
| os, ctx, attr, className); | |||
| } else { | |||
| body += gen_op_def_python_c_extension_enum(os, ctx, attr, | |||
| className); | |||
| } | |||
| } | |||
| } | |||
| // generate getsetters | |||
| std::vector<std::string> getsetters; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| getsetters.push_back(formatv( | |||
| "{{const_cast<char*>(\"{1}\"), py_get_generic({0}, {1}), py_set_generic({0}, {1}), const_cast<char*>(\"{1}\"), NULL},", | |||
| className, i.name)); | |||
| } | |||
| // generate tp_init | |||
| std::string initBody; | |||
| if (!op.getMgbAttributes().empty()) { | |||
| initBody += "static const char* kwlist[] = {"; | |||
| std::vector<llvm::StringRef> attr_name_list; | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| attr_name_list.push_back(attr.name); | |||
| }); | |||
| attr_name_list.push_back("scope"); | |||
| llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
| initBody += formatv("\"{0}\", ", attr); | |||
| }); | |||
| initBody += "NULL};\n"; | |||
| initBody += " PyObject "; | |||
| std::vector<std::string> attr_init; | |||
| llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
| attr_init.push_back(formatv("*{0} = NULL", attr)); | |||
| }); | |||
| initBody += llvm::join(attr_init, ", ") + ";\n"; | |||
| initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; | |||
| // an extra slot created for name | |||
| initBody += std::string(attr_name_list.size(), 'O'); | |||
| initBody += "\", const_cast<char**>(kwlist)"; | |||
| llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
| initBody += formatv(", &{0}", attr); | |||
| }); | |||
| initBody += "))\n"; | |||
| initBody += " return -1;\n"; | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| initBody += formatv(R"( | |||
| if ({1}) {{ | |||
| try {{ | |||
| reinterpret_cast<PyOp({0})*>(self)->inst().{1} = | |||
| pyobj_convert_generic<decltype({0}::{1})>::from({1}); | |||
| } CATCH_ALL(-1) | |||
| } | |||
| )", className, attr.name); | |||
| }); | |||
| initBody += formatv(R"( | |||
| if (scope) {{ | |||
| try {{ | |||
| reinterpret_cast<PyOp(OpDef)*>(self)->op | |||
| ->set_scope(pyobj_convert_generic<std::string>::from(scope)); | |||
| } CATCH_ALL(-1) | |||
| } | |||
| )", className); | |||
| } | |||
| initBody += "\n return 0;"; | |||
| os << formatv(R"( | |||
| PyOpDefBegin({0}) // {{ | |||
| static PyGetSetDef py_getsetters[]; | |||
| static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||
| // }; | |||
| PyOpDefEnd({0}) | |||
| PyGetSetDef PyOp({0})::py_getsetters[] = {{ | |||
| {1} | |||
| {{NULL} /* Sentinel */ | |||
| }; | |||
| int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{ | |||
| {2} | |||
| } | |||
| void _init_py_{0}(py::module m) {{ | |||
| using py_op = PyOp({0}); | |||
| auto& py_type = PyOpType({0}); | |||
| py_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| py_type.tp_name = "megengine.core._imperative_rt.ops.{0}"; | |||
| py_type.tp_basicsize = sizeof(PyOp({0})); | |||
| py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| py_type.tp_doc = "{0}"; | |||
| py_type.tp_base = &PyOpType(OpDef); | |||
| py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
| py_type.tp_new = py_new_generic<py_op>; | |||
| py_type.tp_init = py_op::py_init; | |||
| py_type.tp_getset = py_op::py_getsetters; | |||
| mgb_assert(PyType_Ready(&py_type) >= 0); | |||
| {3} | |||
| PyType_Modified(&py_type); | |||
| m.add_object("{0}", reinterpret_cast<PyObject*>(&py_type)); | |||
| mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second); | |||
| } | |||
| )", | |||
| op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body); | |||
| } | |||
| static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, | |||
| std::function<void(raw_ostream&, MgbOp&)> callback) { | |||
| auto op_base_class = keeper.getClass("Op"); | |||
| ASSERT(op_base_class, "could not find base class Op"); | |||
| for (auto&& i: keeper.getDefs()) { | |||
| auto&& r = i.second; | |||
| if (r->isSubClassOf(op_base_class)) { | |||
| auto op = mlir::tblgen::Operator(r.get()); | |||
| if (op.getDialectName().str() == "mgb") { | |||
| std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl; | |||
| callback(os, llvm::cast<MgbOp>(op)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { | |||
| for_each_operator(os, keeper, gen_op_def_c_header_single); | |||
| for_each_operator(os, keeper, gen_to_string_trait_for_enum); | |||
| return false; | |||
| } | |||
| static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) { | |||
| for_each_operator(os, keeper, gen_op_def_c_body_single); | |||
| return false; | |||
| } | |||
| static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) { | |||
| EnumContext ctx; | |||
| using namespace std::placeholders; | |||
| for_each_operator(os, keeper, | |||
| std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx))); | |||
| return false; | |||
| } | |||
| static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) { | |||
| EnumContext ctx; | |||
| using namespace std::placeholders; | |||
| for_each_operator(os, keeper, | |||
| std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx))); | |||
| os << "#define INIT_ALL_OP(m)"; | |||
| for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) { | |||
| os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName()); | |||
| }); | |||
| os << "\n"; | |||
| return false; | |||
| } | |||
| using namespace mlir::tblgen; | |||
| int main(int argc, char **argv) { | |||
| llvm::InitLLVM y(argc, argv); | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * \file imperative/tablegen/emitter.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <unordered_map> | |||
| #include <stdexcept> | |||
| #include "llvm/ADT/StringRef.h" | |||
| #include "llvm/Support/raw_ostream.h" | |||
| namespace mlir::tblgen { | |||
| struct Environment { | |||
| std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias; | |||
| }; | |||
| struct EmitterBase { | |||
| EmitterBase(raw_ostream& os_): os(os_) {} | |||
| EmitterBase(raw_ostream& os_, Environment& env): os(os_), env_p(&env) {} | |||
| protected: | |||
| void newline() { os << "\n"; } | |||
| Environment& env() { | |||
| if (env_p) { | |||
| return *env_p; | |||
| } | |||
| throw std::runtime_error("access global environment via non-environment emitter"); | |||
| } | |||
| raw_ostream& os; | |||
| Environment* env_p = nullptr; | |||
| }; | |||
| } // namespace mlir::tblgen | |||
| @@ -1,3 +1,16 @@ | |||
| /** | |||
| * \file imperative/tablegen/helper.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <iostream> | |||
| #include <string> | |||
| #include <vector> | |||
| @@ -278,5 +291,28 @@ public: | |||
| } | |||
| }; | |||
| using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; | |||
| using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; | |||
| using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin; | |||
| using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin; | |||
| using MgbOp = mlir::tblgen::MgbOpBase; | |||
| using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin; | |||
| static inline void foreach_operator(llvm::RecordKeeper &keeper, | |||
| std::function<void(MgbOp&)> callback) { | |||
| auto op_base_class = keeper.getClass("Op"); | |||
| ASSERT(op_base_class, "could not find base class Op"); | |||
| for (auto&& i: keeper.getDefs()) { | |||
| auto&& r = i.second; | |||
| if (r->isSubClassOf(op_base_class)) { | |||
| auto op = mlir::tblgen::Operator(r.get()); | |||
| if (op.getDialectName().str() == "mgb") { | |||
| std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl; | |||
| callback(llvm::cast<MgbOp>(op)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } // namespace tblgen | |||
| } // namespace mlir | |||
| @@ -0,0 +1,309 @@ | |||
| /** | |||
| * \file imperative/tablegen/targets/cpp_class.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./cpp_class.h" | |||
| #include "../emitter.h" | |||
| namespace mlir::tblgen { | |||
| namespace { | |||
| llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) { | |||
| // Note: we have already registered the corresponding attr wrappers | |||
| // for following basic ctypes so we needn't handle them here | |||
| /* auto&& attr_type_name = attr.getAttrDefName(); | |||
| if (attr_type_name == "UI32Attr") { | |||
| return "uint32_t"; | |||
| } | |||
| if (attr_type_name == "UI64Attr") { | |||
| return "uint64_t"; | |||
| } | |||
| if (attr_type_name == "I32Attr") { | |||
| return "int32_t"; | |||
| } | |||
| if (attr_type_name == "F32Attr") { | |||
| return "float"; | |||
| } | |||
| if (attr_type_name == "F64Attr") { | |||
| return "double"; | |||
| } | |||
| if (attr_type_name == "StrAttr") { | |||
| return "std::string"; | |||
| } | |||
| if (attr_type_name == "BoolAttr") { | |||
| return "bool"; | |||
| }*/ | |||
| auto&& attr = llvm::cast<MgbAttrWrapper>(attr_); | |||
| if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) { | |||
| return e->getEnumName(); | |||
| } | |||
| return attr.getUnderlyingType(); | |||
| } | |||
| class OpDefEmitter final: public EmitterBase { | |||
| public: | |||
| OpDefEmitter(MgbOp& op_, raw_ostream& os_): | |||
| EmitterBase(os_), op(op_) {} | |||
| void emit_header(); | |||
| void emit_tpl_spl(); | |||
| void emit_body(); | |||
| private: | |||
| MgbOp& op; | |||
| }; | |||
| void OpDefEmitter::emit_header() { | |||
| os << formatv( | |||
| "class {0} : public OpDefImplBase<{0}> {{\n" | |||
| " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n" | |||
| "public:\n", | |||
| op.getCppClassName() | |||
| ); | |||
| // handle enum alias | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| os << formatv( | |||
| " using {0} = {1};\n", | |||
| attr->getEnumName(), attr->getUnderlyingType() | |||
| ); | |||
| } | |||
| } | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| auto defaultValue = i.attr.getDefaultValue().str(); | |||
| if (!defaultValue.empty()) { | |||
| defaultValue = formatv(" = {0}", defaultValue); | |||
| } | |||
| os << formatv( | |||
| " {0} {1}{2};\n", | |||
| attr_to_ctype(i.attr), i.name, defaultValue | |||
| ); | |||
| } | |||
| auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) { | |||
| os << formatv( | |||
| " {0}({1}){2}{3}\n", | |||
| op.getCppClassName(), paramList, memInitList, body | |||
| ); | |||
| }; | |||
| gen_ctor("", "", " = default;"); | |||
| if (!op.getMgbAttributes().empty()) { | |||
| std::vector<std::string> paramList, initList; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| paramList.push_back(formatv( | |||
| "{0} {1}_", attr_to_ctype(i.attr), i.name | |||
| )); | |||
| initList.push_back(formatv( | |||
| "{0}({0}_)", i.name | |||
| )); | |||
| } | |||
| paramList.push_back("std::string scope_ = {}"); | |||
| gen_ctor(llvm::join(paramList, ", "), | |||
| ": " + llvm::join(initList, ", "), | |||
| " { set_scope(scope_); }"); | |||
| } | |||
| auto packedParams = op.getPackedParams(); | |||
| if (!packedParams.empty()) { | |||
| std::vector<std::string> paramList, initList; | |||
| for (auto &&p : packedParams) { | |||
| auto&& paramFields = p.getFields(); | |||
| auto&& paramType = p.getFullName(); | |||
| auto&& paramName = formatv("packed_param_{0}", paramList.size()); | |||
| paramList.push_back( | |||
| paramFields.empty() ? paramType.str() | |||
| : formatv("{0} {1}", paramType, paramName) | |||
| ); | |||
| for (auto&& i : paramFields) { | |||
| initList.push_back(formatv( | |||
| "{0}({1}.{0})", i.name, paramName | |||
| )); | |||
| } | |||
| } | |||
| for (auto&& i : op.getExtraArguments()) { | |||
| paramList.push_back(formatv( | |||
| "{0} {1}_", attr_to_ctype(i.attr), i.name | |||
| )); | |||
| initList.push_back(formatv( | |||
| "{0}({0}_)", i.name | |||
| )); | |||
| } | |||
| gen_ctor(llvm::join(paramList, ", "), | |||
| initList.empty() ? "" : ": " + llvm::join(initList, ", "), | |||
| " {}"); | |||
| } | |||
| if (!packedParams.empty()) { | |||
| for (auto&& p : packedParams) { | |||
| auto accessor = p.getAccessor(); | |||
| if (!accessor.empty()) { | |||
| os << formatv( | |||
| " {0} {1}() const {{\n", | |||
| p.getFullName(), accessor | |||
| ); | |||
| std::vector<llvm::StringRef> fields; | |||
| for (auto&& i : p.getFields()) { | |||
| fields.push_back(i.name); | |||
| } | |||
| os << formatv( | |||
| " return {{{0}};\n", | |||
| llvm::join(fields, ", ") | |||
| ); | |||
| os << " }\n"; | |||
| } | |||
| } | |||
| } | |||
| if (auto decl = op.getExtraOpdefDecl()) { | |||
| os << decl.getValue(); | |||
| } | |||
| os << formatv( | |||
| "};\n\n" | |||
| ); | |||
| } | |||
| void OpDefEmitter::emit_tpl_spl() { | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| if (attr->supportToString()) { | |||
| std::vector<std::string> case_body; | |||
| std::string ename = formatv("{0}::{1}", | |||
| op.getCppClassName(), attr->getEnumName()); | |||
| llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ | |||
| case_body.push_back(formatv( | |||
| "case {0}::{1}: return \"{1}\";", ename, v)); | |||
| }); | |||
| os << formatv(R"( | |||
| template <> | |||
| struct ToStringTrait<{0}> { | |||
| std::string operator()({0} e) const { | |||
| switch (e) { | |||
| {1} | |||
| default: | |||
| return "{0}::Unknown"; | |||
| } | |||
| } | |||
| }; | |||
| )", ename, llvm::join(case_body, "\n")); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void OpDefEmitter::emit_body() { | |||
| auto&& className = op.getCppClassName(); | |||
| os << formatv( | |||
| "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className | |||
| ); | |||
| auto formatMethImpl = [&](auto&& meth) { | |||
| return formatv( | |||
| "{0}_{1}_impl", className, meth | |||
| ); | |||
| }; | |||
| std::vector<std::string> methods; | |||
| if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) { | |||
| os << "namespace {\n"; | |||
| // generate hash() | |||
| mlir::tblgen::FmtContext ctx; | |||
| os << formatv( | |||
| "size_t {0}(const OpDef& def_) {{\n", | |||
| formatMethImpl("hash") | |||
| ); | |||
| os << formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| // generate is_same_st() | |||
| os << formatv( | |||
| "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n", | |||
| formatMethImpl("is_same_st") | |||
| ); | |||
| os << formatv( | |||
| " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n" | |||
| " &&b_ = rhs_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(a_);\n" | |||
| " static_cast<void>(b_);\n", | |||
| className | |||
| ); | |||
| os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); | |||
| os << "}\n"; | |||
| // generate props() | |||
| os << formatv( | |||
| "std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n", | |||
| formatMethImpl("props") | |||
| ); | |||
| os << formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| // generate make_name() | |||
| os << formatv( | |||
| "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") | |||
| ); | |||
| os << formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| os << "} // anonymous namespace\n"; | |||
| methods.push_back("hash"); | |||
| methods.push_back("is_same_st"); | |||
| methods.push_back("props"); | |||
| methods.push_back("make_name"); | |||
| } | |||
| if (!methods.empty()) { | |||
| os << formatv( | |||
| "OP_TRAIT_REG({0}, {0})", op.getCppClassName() | |||
| ); | |||
| for (auto&& i : methods) { | |||
| os << formatv( | |||
| "\n .{0}({1})", i, formatMethImpl(i) | |||
| ); | |||
| } | |||
| os << ";\n\n"; | |||
| } | |||
| } | |||
| } // namespace | |||
| bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper) { | |||
| foreach_operator(keeper, [&](MgbOp& op) { | |||
| OpDefEmitter emitter(op, os); | |||
| emitter.emit_header(); | |||
| emitter.emit_tpl_spl(); | |||
| }); | |||
| return false; | |||
| } | |||
| bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper) { | |||
| foreach_operator(keeper, [&](MgbOp& op) { | |||
| OpDefEmitter emitter(op, os); | |||
| emitter.emit_body(); | |||
| }); | |||
| return false; | |||
| } | |||
| } // namespace mlir::tblgen | |||
| @@ -0,0 +1,21 @@ | |||
| /** | |||
| * \file imperative/tablegen/targets/cpp_class.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "../helper.h" | |||
| namespace mlir::tblgen { | |||
| bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper); | |||
| bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper); | |||
| } // namespace mlir::tblgen | |||
| @@ -0,0 +1,142 @@ | |||
| /** | |||
| * \file imperative/tablegen/targets/pybind11.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./pybind11.h" | |||
| #include "../emitter.h" | |||
| namespace mlir::tblgen { | |||
| namespace { | |||
| class OpDefEmitter final: public EmitterBase { | |||
| public: | |||
| OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_): | |||
| EmitterBase(os_, env_), op(op_) {} | |||
| void emit(); | |||
| private: | |||
| MgbOp& op; | |||
| }; | |||
| void OpDefEmitter::emit() { | |||
| auto className = op.getCppClassName(); | |||
| os << formatv( | |||
| "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", | |||
| className | |||
| ); | |||
| for (auto&& i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = | |||
| llvm::cast<MgbEnumAttr>(aliasBase) | |||
| .getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| auto&& enumAlias = env().enumAlias; | |||
| auto&& iter = enumAlias.find(enumID); | |||
| if (iter == enumAlias.end()) { | |||
| os << formatv( | |||
| "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", | |||
| className, attr->getEnumName() | |||
| ); | |||
| std::vector<std::string> body; | |||
| for (auto&& i: attr->getEnumMembers()) { | |||
| os << formatv( | |||
| "\n .value(\"{2}\", {0}::{1}::{2})", | |||
| className, attr->getEnumName(), i | |||
| ); | |||
| body.push_back(formatv( | |||
| "if (str == \"{2}\") return {0}::{1}::{2};", | |||
| className, attr->getEnumName(), i | |||
| )); | |||
| } | |||
| if (attr->getEnumCombinedFlag()) { | |||
| //! define operator | | |||
| os << formatv( | |||
| "\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ " | |||
| "\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));" | |||
| "\n })", | |||
| className, attr->getEnumName()); | |||
| //! define operator & | |||
| os << formatv( | |||
| "\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{" | |||
| "\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));" | |||
| "\n })", | |||
| className, attr->getEnumName()); | |||
| } | |||
| os << formatv( | |||
| "\n .def(py::init([](const std::string& in) {" | |||
| "\n auto&& str = normalize_enum(in);" | |||
| "\n {0}" | |||
| "\n throw py::cast_error(\"invalid enum value \" + in);" | |||
| "\n }));\n", | |||
| llvm::join(body, "\n ") | |||
| ); | |||
| os << formatv( | |||
| "py::implicitly_convertible<std::string, {0}::{1}>();\n\n", | |||
| className, attr->getEnumName() | |||
| ); | |||
| enumAlias.emplace(enumID, | |||
| std::make_pair(className, attr->getEnumName())); | |||
| } else { | |||
| os << formatv( | |||
| "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", | |||
| className, attr->getEnumName(), | |||
| iter->second.first, iter->second.second | |||
| ); | |||
| } | |||
| } | |||
| } | |||
| // generate op class binding | |||
| os << formatv("{0}Inst", className); | |||
| bool hasDefaultCtor = op.getMgbAttributes().empty(); | |||
| if (!hasDefaultCtor) { | |||
| os << "\n .def(py::init<"; | |||
| std::vector<llvm::StringRef> targs; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| targs.push_back(i.attr.getReturnType()); | |||
| } | |||
| os << llvm::join(targs, ", "); | |||
| os << ", std::string>()"; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| os << formatv(", py::arg(\"{0}\")", i.name); | |||
| auto defaultValue = i.attr.getDefaultValue(); | |||
| if (!defaultValue.empty()) { | |||
| os << formatv(" = {0}", defaultValue); | |||
| } else { | |||
| hasDefaultCtor = true; | |||
| } | |||
| } | |||
| os << ", py::arg(\"scope\") = {})"; | |||
| } | |||
| if (hasDefaultCtor) { | |||
| os << "\n .def(py::init<>())"; | |||
| } | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| os << formatv( | |||
| "\n .def_readwrite(\"{0}\", &{1}::{0})", | |||
| i.name, className | |||
| ); | |||
| } | |||
| os << ";\n\n"; | |||
| } | |||
| } // namespace | |||
| bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper) { | |||
| Environment env; | |||
| using namespace std::placeholders; | |||
| foreach_operator(keeper, [&](MgbOp& op) { | |||
| OpDefEmitter(op, os, env).emit(); | |||
| }); | |||
| return false; | |||
| } | |||
| } // namespace mlir::tblgen | |||
| @@ -0,0 +1,19 @@ | |||
| /** | |||
| * \file imperative/tablegen/targets/pybind11.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "../helper.h" | |||
| namespace mlir::tblgen { | |||
| bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper); | |||
| } // namespace mlir::tblgen | |||
| @@ -0,0 +1,313 @@ | |||
| /** | |||
| * \file imperative/tablegen/targets/python_c_extension.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "python_c_extension.h" | |||
| #include "../emitter.h" | |||
| namespace mlir::tblgen { | |||
| namespace { | |||
| struct Initproc { | |||
| std::string func; | |||
| Initproc(std::string&& s): func(std::move(s)) {} | |||
| std::string operator()(std::string argument) { | |||
| return formatv("{0}({1})", func, argument); | |||
| } | |||
| }; | |||
| class OpDefEmitter: public EmitterBase { | |||
| public: | |||
| OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_): | |||
| EmitterBase(os_, env_), op(op_) { | |||
| ctx.withSelf(op.getCppClassName()); | |||
| } | |||
| Initproc emit(); | |||
| private: | |||
| void emit_class(); | |||
| void emit_py_init(); | |||
| void emit_py_getsetters(); | |||
| Initproc emit_initproc(); | |||
| MgbOp& op; | |||
| std::vector<Initproc> subclasses; | |||
| mlir::tblgen::FmtContext ctx; | |||
| }; | |||
| class EnumAttrEmitter: public EmitterBase { | |||
| public: | |||
| EnumAttrEmitter(llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_, Environment& env_): | |||
| EmitterBase(os_, env_), attr(attr_) { | |||
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| ctx.addSubst("enumTpl", attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper"); | |||
| ctx.addSubst("opClass", parent); | |||
| ctx.addSubst("enumClass", attr->getEnumName()); | |||
| firstOccur = env().enumAlias.emplace(enumID, std::make_pair(parent, attr->getEnumName())).second; | |||
| } | |||
| Initproc emit(); | |||
| protected: | |||
| void emit_tpl_spl(); | |||
| Initproc emit_initproc(); | |||
| MgbEnumAttr* attr; | |||
| bool firstOccur; | |||
| mlir::tblgen::FmtContext ctx; | |||
| }; | |||
| Initproc EnumAttrEmitter::emit() { | |||
| emit_tpl_spl(); | |||
| return emit_initproc(); | |||
| } | |||
| void EnumAttrEmitter::emit_tpl_spl() { | |||
| if (!firstOccur) return; | |||
| os << tgfmt( | |||
| "template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type={};\n", | |||
| &ctx); | |||
| os << tgfmt( | |||
| "template<> const char* $enumTpl<$opClass::$enumClass>::name = " | |||
| "\"$opClass.$enumClass\";\n", | |||
| &ctx); | |||
| if (attr->getEnumCombinedFlag()) { | |||
| os << tgfmt( | |||
| "template<> PyNumberMethods " | |||
| "$enumTpl<$opClass::$enumClass>::number_methods={};\n", | |||
| &ctx); | |||
| os << tgfmt( | |||
| "template<> struct EnumTrait<$opClass::$enumClass> { static constexpr " | |||
| "bool is_bit_combined = true;};\n", | |||
| &ctx); | |||
| } | |||
| auto str2type = [&](auto&& i) -> std::string { | |||
| return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i); | |||
| }; | |||
| os << tgfmt(R"( | |||
| template<> std::unordered_map<std::string, $opClass::$enumClass> | |||
| $enumTpl<$opClass::$enumClass>::str2type = {$0}; | |||
| )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), str2type), ", ")); | |||
| auto type2str = [&](auto&& i) -> std::string { | |||
| return tgfmt("{$opClass::$enumClass::$0, normalize_enum(\"$0\")}", &ctx, i); | |||
| }; | |||
| os << tgfmt(R"( | |||
| template<> std::unordered_map<$opClass::$enumClass, std::string> | |||
| $enumTpl<$opClass::$enumClass>::type2str = {$0}; | |||
| )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), type2str), ", ")); | |||
| } | |||
| Initproc EnumAttrEmitter::emit_initproc() { | |||
| std::string initproc = formatv("_init_py_{0}_{1}", | |||
| ctx.getSubstFor("opClass"), ctx.getSubstFor("enumClass")); | |||
| os << tgfmt(R"( | |||
| void $0(PyTypeObject& py_type) { | |||
| auto& e_type = $enumTpl<$opClass::$enumClass>::type; | |||
| )", &ctx, initproc); | |||
| if (firstOccur) { | |||
| os << tgfmt(R"( | |||
| e_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| e_type.tp_name = "megengine.core._imperative_rt.ops.$opClass.$enumClass"; | |||
| e_type.tp_basicsize = sizeof($enumTpl<$opClass::$enumClass>); | |||
| e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| e_type.tp_doc = "$opClass.$enumClass"; | |||
| e_type.tp_base = &PyBaseObject_Type; | |||
| e_type.tp_repr = $enumTpl<$opClass::$enumClass>::py_repr; | |||
| e_type.tp_richcompare = $enumTpl<$opClass::$enumClass>::tp_richcompare; | |||
| )", &ctx); | |||
| if (attr->getEnumCombinedFlag()) { | |||
| // only bit combined enum could new instance because bitwise operation, | |||
| // others should always use singleton | |||
| os << tgfmt(R"( | |||
| e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum; | |||
| e_type.tp_init = $enumTpl<$opClass::$enumClass>::py_init; | |||
| auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods; | |||
| number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or; | |||
| number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and; | |||
| e_type.tp_as_number = &number_method; | |||
| )", &ctx); | |||
| } | |||
| os << " mgb_assert(PyType_Ready(&e_type) >= 0);\n"; | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| os << tgfmt(R"({ | |||
| PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
| reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0; | |||
| mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0); | |||
| PyType_Modified(&e_type); | |||
| })", &ctx, i); | |||
| } | |||
| } | |||
| os << tgfmt(R"( | |||
| mgb_assert(PyDict_SetItemString( | |||
| py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
| )", &ctx); | |||
| os << "}\n"; | |||
| return initproc; | |||
| } | |||
| Initproc OpDefEmitter::emit() { | |||
| for (auto&& i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| subclasses.push_back(EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit()); | |||
| } | |||
| } | |||
| emit_class(); | |||
| emit_py_init(); | |||
| emit_py_getsetters(); | |||
| return emit_initproc(); | |||
| } | |||
| void OpDefEmitter::emit_class() { | |||
| os << tgfmt(R"( | |||
| PyOpDefBegin($_self) // { | |||
| static PyGetSetDef py_getsetters[]; | |||
| static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||
| // }; | |||
| PyOpDefEnd($_self) | |||
| )", &ctx); | |||
| } | |||
| void OpDefEmitter::emit_py_init() { | |||
| std::string initBody; | |||
| if (!op.getMgbAttributes().empty()) { | |||
| initBody += "static const char* kwlist[] = {"; | |||
| std::vector<llvm::StringRef> attr_name_list; | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| attr_name_list.push_back(attr.name); | |||
| }); | |||
| attr_name_list.push_back("scope"); | |||
| llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
| initBody += formatv("\"{0}\", ", attr); | |||
| }); | |||
| initBody += "NULL};\n"; | |||
| initBody += " PyObject "; | |||
| auto initializer = [&](auto&& attr) -> std::string { | |||
| return formatv("*{0} = NULL", attr); | |||
| }; | |||
| initBody += llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n"; | |||
| initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; | |||
| // an extra slot created for name | |||
| initBody += std::string(attr_name_list.size(), 'O'); | |||
| initBody += "\", const_cast<char**>(kwlist)"; | |||
| llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
| initBody += formatv(", &{0}", attr); | |||
| }); | |||
| initBody += "))\n"; | |||
| initBody += " return -1;\n"; | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| initBody += tgfmt(R"( | |||
| if ($0) { | |||
| try { | |||
| reinterpret_cast<PyOp($_self)*>(self)->inst().$0 = | |||
| pyobj_convert_generic<decltype($_self::$0)>::from($0); | |||
| } CATCH_ALL(-1) | |||
| } | |||
| )", &ctx, attr.name); | |||
| }); | |||
| initBody += tgfmt(R"( | |||
| if (scope) { | |||
| try { | |||
| reinterpret_cast<PyOp(OpDef)*>(self)->op | |||
| ->set_scope(pyobj_convert_generic<std::string>::from(scope)); | |||
| } CATCH_ALL(-1) | |||
| } | |||
| )", &ctx); | |||
| } | |||
| initBody += "\n return 0;"; | |||
| os << tgfmt(R"( | |||
| int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { | |||
| $0 | |||
| } | |||
| )", &ctx, initBody); | |||
| } | |||
| void OpDefEmitter::emit_py_getsetters() { | |||
| auto f = [&](auto&& attr) -> std::string { | |||
| return tgfmt( | |||
| "{const_cast<char*>(\"$0\"), py_get_generic($_self, $0), py_set_generic($_self, $0), const_cast<char*>(\"$0\"), NULL},", | |||
| &ctx, attr.name); | |||
| }; | |||
| os << tgfmt(R"( | |||
| PyGetSetDef PyOp($_self)::py_getsetters[] = { | |||
| $0 | |||
| {NULL} /* Sentinel */ | |||
| }; | |||
| )", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n ")); | |||
| } | |||
| Initproc OpDefEmitter::emit_initproc() { | |||
| std::string initproc = formatv("_init_py_{0}", op.getCppClassName()); | |||
| std::string subclass_init_call; | |||
| for (auto&& i : subclasses) { | |||
| subclass_init_call += formatv(" {0};\n", i("py_type")); | |||
| } | |||
| os << tgfmt(R"( | |||
| void $0(py::module m) { | |||
| using py_op = PyOp($_self); | |||
| auto& py_type = PyOpType($_self); | |||
| py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| py_type.tp_name = "megengine.core._imperative_rt.ops.$_self"; | |||
| py_type.tp_basicsize = sizeof(PyOp($_self)); | |||
| py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| py_type.tp_doc = "$_self"; | |||
| py_type.tp_base = &PyOpType(OpDef); | |||
| py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
| py_type.tp_new = py_new_generic<py_op>; | |||
| py_type.tp_init = py_op::py_init; | |||
| py_type.tp_getset = py_op::py_getsetters; | |||
| mgb_assert(PyType_Ready(&py_type) >= 0); | |||
| $1 | |||
| PyType_Modified(&py_type); | |||
| m.add_object("$_self", reinterpret_cast<PyObject*>(&py_type)); | |||
| mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second); | |||
| } | |||
| )", &ctx, initproc, subclass_init_call); | |||
| return initproc; | |||
| } | |||
| } // namespace | |||
| bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper) { | |||
| Environment env; | |||
| using namespace std::placeholders; | |||
| std::vector<Initproc> initprocs; | |||
| foreach_operator(keeper, [&](MgbOp& op) { | |||
| initprocs.emplace_back(OpDefEmitter(op, os, env).emit()); | |||
| }); | |||
| os << "#define INIT_ALL_OP(m)"; | |||
| for(auto&& init : initprocs) { | |||
| os << formatv(" \\\n {0};", init("m")); | |||
| } | |||
| os << "\n"; | |||
| return false; | |||
| } | |||
| } // namespace mlir::tblgen | |||
| @@ -0,0 +1,19 @@ | |||
| /** | |||
| * \file imperative/tablegen/targets/python_c_extension.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "../helper.h" | |||
| namespace mlir::tblgen { | |||
| bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper); | |||
| } // namespace mlir::tblgen | |||