|
|
|
@@ -1,8 +1,242 @@ |
|
|
|
#include "python_c_extension.h" |
|
|
|
#include <cctype> |
|
|
|
#include <functional> |
|
|
|
#include <iostream> |
|
|
|
#include <sstream> |
|
|
|
#include <string> |
|
|
|
#include <tuple> |
|
|
|
#include <unordered_map> |
|
|
|
#include <vector> |
|
|
|
|
|
|
|
#include "../emitter.h" |
|
|
|
#include "python_c_extension.h" |
|
|
|
|
|
|
|
namespace mlir::tblgen { |
|
|
|
namespace { |
|
|
|
|
|
|
|
class TypeInfo; |
|
|
|
std::pair<TypeInfo, int> parse_type(const std::string&, const int); |
|
|
|
std::pair<std::vector<std::string>, int> parse_namespace(const std::string&, const int); |
|
|
|
|
|
|
|
struct Unit {}; |
|
|
|
Unit unit; |
|
|
|
|
|
|
|
struct ParseError {}; |
|
|
|
|
|
|
|
class TypeInfo { |
|
|
|
public: |
|
|
|
TypeInfo(std::string name) : name(name) {} |
|
|
|
|
|
|
|
std::string to_python_type_string() { |
|
|
|
std::stringstream ss; |
|
|
|
ss << translate_type_name(name); |
|
|
|
if (params.size() > 0) { |
|
|
|
ss << "[" << params[0].to_python_type_string(); |
|
|
|
for (auto i = 1; i < params.size(); i++) { |
|
|
|
ss << ", " << params[i].to_python_type_string(); |
|
|
|
} |
|
|
|
ss << "]"; |
|
|
|
} |
|
|
|
return ss.str(); |
|
|
|
} |
|
|
|
|
|
|
|
std::string translate_type_name(const std::string& cppTypeName) { |
|
|
|
auto res = translation.find(cppTypeName); |
|
|
|
if (res != translation.end()) |
|
|
|
return res->second; |
|
|
|
try { |
|
|
|
auto segments = parse_namespace(cppTypeName, 0).first; |
|
|
|
// special rules |
|
|
|
if (segments.size() > 3 && segments[0] == "megdnn" && |
|
|
|
segments[1] == "param") { |
|
|
|
segments.erase(segments.begin(), segments.begin() + 3); |
|
|
|
} else if ( |
|
|
|
segments.size() == 2 && segments[0] == "megdnn" && |
|
|
|
segments[1] == "DType") { |
|
|
|
segments.erase(segments.begin(), segments.begin() + 1); |
|
|
|
segments[0] = "str"; |
|
|
|
} else if ( |
|
|
|
segments.size() == 2 && segments[0] == "mgb" && |
|
|
|
segments[1] == "CompNode") { |
|
|
|
segments.erase(segments.begin(), segments.begin() + 1); |
|
|
|
segments[0] = "str"; |
|
|
|
} |
|
|
|
std::stringstream joined; |
|
|
|
joined << segments[0]; |
|
|
|
for (auto i = 1; i < segments.size(); i++) { |
|
|
|
joined << "." << segments[i]; |
|
|
|
} |
|
|
|
return joined.str(); |
|
|
|
} catch (ParseError) { |
|
|
|
return cppTypeName; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
std::string name; |
|
|
|
std::vector<TypeInfo> params; |
|
|
|
|
|
|
|
private: |
|
|
|
static const std::unordered_map<std::string, std::string> translation; |
|
|
|
}; |
|
|
|
|
|
|
|
const std::unordered_map<std::string, std::string> TypeInfo::translation = { |
|
|
|
{"bool", "bool"}, {"double", "float"}, {"float", "float"}, |
|
|
|
{"int32_t", "int"}, {"int8_t", "int"}, {"size_t", "int"}, |
|
|
|
{"std::string", "str"}, {"std::tuple", "tuple"}, {"std::vector", "list"}, |
|
|
|
{"uint32_t", "int"}, {"uint64_t", "int"}, |
|
|
|
}; |
|
|
|
|
|
|
|
// a parser takes: |
|
|
|
// 1. a string to parse |
|
|
|
// 2. location to parse from (index of character) |
|
|
|
// returns: |
|
|
|
// 1. parsing result (type T) |
|
|
|
// 2. end location of substring which is consumed by parsing |
|
|
|
// throws exception when failed to parse |
|
|
|
template <typename T> |
|
|
|
using Parser = std::function<std::pair<T, int>(const std::string&, const int)>; |
|
|
|
|
|
|
|
std::pair<Unit, int> parse_blank(const std::string& text, const int begin) { |
|
|
|
auto now = begin; |
|
|
|
while (now < text.length() && isblank(text[now])) |
|
|
|
now += 1; |
|
|
|
return {unit, now}; |
|
|
|
} |
|
|
|
|
|
|
|
Parser<Unit> parse_non_blank_char(char ch) { |
|
|
|
return [=](const std::string& text, const int begin) -> std::pair<Unit, int> { |
|
|
|
auto blankEnd = parse_blank(text, begin).second; |
|
|
|
if (blankEnd >= text.length() || text[blankEnd] != ch) |
|
|
|
throw ParseError{}; |
|
|
|
return {unit, blankEnd + 1}; |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
Parser<std::string> parse_allowed_chars(std::function<bool(char)> allow) { |
|
|
|
return [=](const std::string& text, |
|
|
|
const int begin) -> std::pair<std::string, int> { |
|
|
|
auto now = begin; |
|
|
|
while (now < text.length() && allow(text[now])) |
|
|
|
now += 1; |
|
|
|
return {text.substr(begin, now - begin), now}; |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
Parser<std::tuple<T>> parse_seq(Parser<T> only) { |
|
|
|
return [=](const std::string& text, |
|
|
|
const int begin) -> std::pair<std::tuple<T>, int> { |
|
|
|
auto res = only(text, begin); |
|
|
|
return {{res.first}, res.second}; |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Head, typename... Tail> |
|
|
|
Parser<std::tuple<Head, Tail...>> parse_seq(Parser<Head> head, Parser<Tail>... tail) { |
|
|
|
return [=](const std::string& text, |
|
|
|
const int begin) -> std::pair<std::tuple<Head, Tail...>, int> { |
|
|
|
std::pair<Head, int> headRes = head(text, begin); |
|
|
|
std::pair<std::tuple<Tail...>, int> tailRes = |
|
|
|
parse_seq(tail...)(text, headRes.second); |
|
|
|
return {std::tuple_cat(std::tuple<Head>(headRes.first), tailRes.first), |
|
|
|
tailRes.second}; |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
Parser<std::vector<T>> parse_many_at_least0(Parser<T> one) { |
|
|
|
return [=](const std::string& text, |
|
|
|
const int begin) -> std::pair<std::vector<T>, int> { |
|
|
|
std::vector<T> ret; |
|
|
|
auto now = begin; |
|
|
|
try { |
|
|
|
while (true) { |
|
|
|
auto oneRes = one(text, now); |
|
|
|
ret.emplace_back(oneRes.first); |
|
|
|
now = oneRes.second; |
|
|
|
} |
|
|
|
} catch (ParseError) { |
|
|
|
} |
|
|
|
return {ret, now}; |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename C> |
|
|
|
Parser<std::vector<C>> parse_sep_by_at_least1( |
|
|
|
Parser<Unit> separator, Parser<C> component) { |
|
|
|
return [=](const std::string& text, |
|
|
|
const int begin) -> std::pair<std::vector<C>, int> { |
|
|
|
std::vector<C> ret; |
|
|
|
auto headRes = component(text, begin); |
|
|
|
ret.emplace_back(headRes.first); |
|
|
|
auto tailRes = parse_many_at_least0(parse_seq(separator, component))( |
|
|
|
text, headRes.second); |
|
|
|
for (const auto& elem : tailRes.first) { |
|
|
|
ret.emplace_back(std::get<1>(elem)); |
|
|
|
} |
|
|
|
return {ret, tailRes.second}; |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
std::pair<std::string, int> parse_identifier(const std::string& text, const int begin) { |
|
|
|
auto blankEnd = parse_blank(text, begin).second; |
|
|
|
auto indentRes = parse_allowed_chars( |
|
|
|
[](char ch) { return std::isalnum(ch) || ch == '_'; })(text, blankEnd); |
|
|
|
if (indentRes.first.empty()) |
|
|
|
throw ParseError{}; |
|
|
|
return indentRes; |
|
|
|
}; |
|
|
|
|
|
|
|
std::pair<std::string, int> parse_qualified(const std::string& text, const int begin) { |
|
|
|
auto blankEnd = parse_blank(text, begin).second; |
|
|
|
auto indentRes = parse_allowed_chars([](char ch) { |
|
|
|
return std::isalnum(ch) || ch == '_' || ch == ':'; |
|
|
|
})(text, blankEnd); |
|
|
|
if (indentRes.first.empty()) |
|
|
|
throw ParseError{}; |
|
|
|
return indentRes; |
|
|
|
}; |
|
|
|
|
|
|
|
std::pair<std::vector<std::string>, int> parse_namespace( |
|
|
|
const std::string& text, const int begin) { |
|
|
|
auto res = parse_many_at_least0(parse_seq( |
|
|
|
parse_non_blank_char(':'), parse_non_blank_char(':'), |
|
|
|
Parser<std::string>(parse_identifier)))(text, begin); |
|
|
|
std::vector<std::string> ret; |
|
|
|
for (const auto& elem : res.first) { |
|
|
|
ret.emplace_back(std::get<2>(elem)); |
|
|
|
} |
|
|
|
return {ret, res.second}; |
|
|
|
} |
|
|
|
|
|
|
|
std::pair<TypeInfo, int> parse_leaf_type(const std::string& text, const int begin) { |
|
|
|
auto ret = parse_qualified(text, begin); |
|
|
|
return {TypeInfo(ret.first), ret.second}; |
|
|
|
}; |
|
|
|
|
|
|
|
std::pair<TypeInfo, int> parse_node_type(const std::string& text, const int begin) { |
|
|
|
auto nameRes = parse_qualified(text, begin); |
|
|
|
auto ret = TypeInfo(nameRes.first); |
|
|
|
auto now = parse_non_blank_char('<')(text, nameRes.second).second; |
|
|
|
auto argsRes = parse_sep_by_at_least1( |
|
|
|
parse_non_blank_char(','), Parser<TypeInfo>(parse_type))(text, now); |
|
|
|
ret.params = argsRes.first; |
|
|
|
now = parse_non_blank_char('>')(text, argsRes.second).second; |
|
|
|
return {ret, now}; |
|
|
|
}; |
|
|
|
|
|
|
|
std::pair<TypeInfo, int> parse_type(const std::string& text, const int begin) { |
|
|
|
try { |
|
|
|
return parse_node_type(text, begin); |
|
|
|
} catch (ParseError) { |
|
|
|
} |
|
|
|
return parse_leaf_type(text, begin); |
|
|
|
}; |
|
|
|
|
|
|
|
std::string cpp_type_to_python_type(const std::string& input) { |
|
|
|
auto res = parse_type(input, 0); |
|
|
|
return res.first.to_python_type_string(); |
|
|
|
} |
|
|
|
|
|
|
|
struct Initproc { |
|
|
|
std::string func; |
|
|
|
Initproc(std::string&& s) : func(std::move(s)) {} |
|
|
|
@@ -25,6 +259,10 @@ private: |
|
|
|
void emit_py_init(); |
|
|
|
void emit_py_getsetters(); |
|
|
|
void emit_py_methods(); |
|
|
|
void emit_py_init_proxy(); |
|
|
|
void emit_py_init_methoddef( |
|
|
|
const std::unordered_map<std::string, std::vector<std::string>>& |
|
|
|
enum_attr_members); |
|
|
|
Initproc emit_initproc(); |
|
|
|
|
|
|
|
MgbOp& op; |
|
|
|
@@ -248,10 +486,18 @@ void $0(PyTypeObject& py_type) { |
|
|
|
} |
|
|
|
|
|
|
|
Initproc OpDefEmitter::emit() { |
|
|
|
std::unordered_map<std::string, std::vector<std::string>> enum_attr_members; |
|
|
|
|
|
|
|
for (auto&& i : op.getMgbAttributes()) { |
|
|
|
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { |
|
|
|
subclasses.push_back( |
|
|
|
EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit()); |
|
|
|
|
|
|
|
auto retType = cpp_type_to_python_type(std::string(attr->getReturnType())); |
|
|
|
enum_attr_members[retType] = std::vector<std::string>(); |
|
|
|
for (const auto& member : attr->getEnumMembers()) { |
|
|
|
enum_attr_members[retType].emplace_back(member); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -259,6 +505,8 @@ Initproc OpDefEmitter::emit() { |
|
|
|
emit_py_init(); |
|
|
|
emit_py_getsetters(); |
|
|
|
emit_py_methods(); |
|
|
|
emit_py_init_proxy(); |
|
|
|
emit_py_init_methoddef(enum_attr_members); |
|
|
|
return emit_initproc(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -318,6 +566,8 @@ PyOpDefBegin($_self) // { |
|
|
|
static PyMethodDef tp_methods[]; |
|
|
|
$0 |
|
|
|
static int py_init(PyObject *self, PyObject *args, PyObject *kwds); |
|
|
|
static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds); |
|
|
|
static PyMethodDef py_init_methoddef; |
|
|
|
// }; |
|
|
|
PyOpDefEnd($_self) |
|
|
|
)", |
|
|
|
@@ -438,6 +688,55 @@ void OpDefEmitter::emit_py_methods() { |
|
|
|
&ctx, llvm::join(method_items, "\n ")); |
|
|
|
} |
|
|
|
|
|
|
|
void OpDefEmitter::emit_py_init_proxy() { |
|
|
|
os << tgfmt( |
|
|
|
R"( |
|
|
|
PyObject *PyOp($_self)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) { |
|
|
|
if (PyOp($_self)::py_init(self, args, kwds) < 0) { |
|
|
|
return NULL; |
|
|
|
} |
|
|
|
Py_RETURN_NONE; |
|
|
|
} |
|
|
|
)", |
|
|
|
&ctx); |
|
|
|
} |
|
|
|
|
|
|
|
void OpDefEmitter::emit_py_init_methoddef( |
|
|
|
const std::unordered_map<std::string, std::vector<std::string>>& |
|
|
|
enum_attr_members) { |
|
|
|
std::string docstring = "__init__(self"; |
|
|
|
for (const auto& attr : op.getMgbAttributes()) { |
|
|
|
if (attr.name == "workspace_limit") |
|
|
|
continue; |
|
|
|
auto pyType = cpp_type_to_python_type(std::string(attr.attr.getReturnType())); |
|
|
|
auto findRes = enum_attr_members.find(pyType); |
|
|
|
if (findRes != enum_attr_members.end()) { |
|
|
|
pyType = formatv("Union[str, {0}]", pyType); |
|
|
|
// TODO stubgen cannot handle Literal strings for now |
|
|
|
// auto members = findRes->second; |
|
|
|
// std::string enumTypeString = "Literal["; |
|
|
|
// enumTypeString += formatv("'{0}'", lowercase(members[0])); |
|
|
|
// for (auto i = 1; i < members.size(); i++) { |
|
|
|
// enumTypeString += formatv(", '{0}'", lowercase(members[i])); |
|
|
|
// } |
|
|
|
// enumTypeString += "]"; |
|
|
|
// pyType = enumTypeString; |
|
|
|
} |
|
|
|
docstring += formatv(", {0}: {1} = ...", attr.name, pyType); |
|
|
|
} |
|
|
|
docstring += ") -> None\\n"; |
|
|
|
os << tgfmt( |
|
|
|
R"( |
|
|
|
PyMethodDef PyOp($_self)::py_init_methoddef = { |
|
|
|
"__init__", |
|
|
|
(PyCFunction)PyOp($_self)::py_init_proxy, |
|
|
|
METH_VARARGS | METH_KEYWORDS, |
|
|
|
"$0" |
|
|
|
}; |
|
|
|
)", |
|
|
|
&ctx, docstring); |
|
|
|
} |
|
|
|
|
|
|
|
Initproc OpDefEmitter::emit_initproc() { |
|
|
|
std::string initproc = formatv("_init_py_{0}", op.getCppClassName()); |
|
|
|
std::string subclass_init_call; |
|
|
|
@@ -460,6 +759,10 @@ void $0(py::module m) { |
|
|
|
py_type.tp_init = py_op::py_init; |
|
|
|
py_type.tp_methods = py_op::tp_methods; |
|
|
|
py_type.tp_getset = py_op::py_getsetters; |
|
|
|
|
|
|
|
py_type.tp_dict = PyDict_New(); |
|
|
|
PyObject* descr = PyDescr_NewMethod(&PyOpType($_self), &PyOp($_self)::py_init_methoddef); |
|
|
|
PyDict_SetItemString(py_type.tp_dict, "__init__", descr); |
|
|
|
mgb_assert(PyType_Ready(&py_type) >= 0); |
|
|
|
$1 |
|
|
|
PyType_Modified(&py_type); |
|
|
|
@@ -486,4 +789,4 @@ bool gen_op_def_python_c_extension(raw_ostream& os, llvm::RecordKeeper& keeper) |
|
|
|
os << "\n"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} // namespace mlir::tblgen |
|
|
|
} // namespace mlir::tblgen |