You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

generate_stubs.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. import argparse
  2. import ast
  3. import importlib
  4. import inspect
  5. import logging
  6. import re
  7. import subprocess
  8. from functools import reduce
  9. from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union
  10. def path_to_type(*elements: str) -> ast.AST:
  11. base: ast.AST = ast.Name(id=elements[0], ctx=ast.Load())
  12. for e in elements[1:]:
  13. base = ast.Attribute(value=base, attr=e, ctx=ast.Load())
  14. return base
  15. OBJECT_MEMBERS = dict(inspect.getmembers(object))
  16. BUILTINS: Dict[str, Union[None, Tuple[List[ast.AST], ast.AST]]] = {
  17. "__annotations__": None,
  18. "__bool__": ([], path_to_type("bool")),
  19. "__bytes__": ([], path_to_type("bytes")),
  20. "__class__": None,
  21. "__contains__": ([path_to_type("typing", "Any")], path_to_type("bool")),
  22. "__del__": None,
  23. "__delattr__": ([path_to_type("str")], path_to_type("None")),
  24. "__delitem__": ([path_to_type("typing", "Any")], path_to_type("typing", "Any")),
  25. "__dict__": None,
  26. "__dir__": None,
  27. "__doc__": None,
  28. "__eq__": ([path_to_type("typing", "Any")], path_to_type("bool")),
  29. "__format__": ([path_to_type("str")], path_to_type("str")),
  30. "__ge__": ([path_to_type("typing", "Any")], path_to_type("bool")),
  31. "__getattribute__": ([path_to_type("str")], path_to_type("typing", "Any")),
  32. "__getitem__": ([path_to_type("typing", "Any")], path_to_type("typing", "Any")),
  33. "__gt__": ([path_to_type("typing", "Any")], path_to_type("bool")),
  34. "__hash__": ([], path_to_type("int")),
  35. "__init__": ([], path_to_type("None")),
  36. "__init_subclass__": None,
  37. "__iter__": ([], path_to_type("typing", "Any")),
  38. "__le__": ([path_to_type("typing", "Any")], path_to_type("bool")),
  39. "__len__": ([], path_to_type("int")),
  40. "__lt__": ([path_to_type("typing", "Any")], path_to_type("bool")),
  41. "__module__": None,
  42. "__ne__": ([path_to_type("typing", "Any")], path_to_type("bool")),
  43. "__new__": None,
  44. "__next__": ([], path_to_type("typing", "Any")),
  45. "__int__": ([], path_to_type("None")),
  46. "__reduce__": None,
  47. "__reduce_ex__": None,
  48. "__repr__": ([], path_to_type("str")),
  49. "__setattr__": (
  50. [path_to_type("str"), path_to_type("typing", "Any")],
  51. path_to_type("None"),
  52. ),
  53. "__setitem__": (
  54. [path_to_type("typing", "Any"), path_to_type("typing", "Any")],
  55. path_to_type("typing", "Any"),
  56. ),
  57. "__sizeof__": None,
  58. "__str__": ([], path_to_type("str")),
  59. "__subclasshook__": None,
  60. }
  61. def module_stubs(module: Any) -> ast.Module:
  62. types_to_import = {"typing"}
  63. classes = []
  64. functions = []
  65. for member_name, member_value in inspect.getmembers(module):
  66. element_path = [module.__name__, member_name]
  67. if member_name.startswith("__"):
  68. pass
  69. elif member_name.startswith("DoraStatus"):
  70. pass
  71. elif inspect.isclass(member_value):
  72. classes.append(
  73. class_stubs(member_name, member_value, element_path, types_to_import)
  74. )
  75. elif inspect.isbuiltin(member_value):
  76. functions.append(
  77. function_stub(
  78. member_name,
  79. member_value,
  80. element_path,
  81. types_to_import,
  82. in_class=False,
  83. )
  84. )
  85. else:
  86. logging.warning(f"Unsupported root construction {member_name}")
  87. return ast.Module(
  88. body=[ast.Import(names=[ast.alias(name=t)]) for t in sorted(types_to_import)]
  89. + classes
  90. + functions,
  91. type_ignores=[],
  92. )
  93. def class_stubs(
  94. cls_name: str, cls_def: Any, element_path: List[str], types_to_import: Set[str]
  95. ) -> ast.ClassDef:
  96. attributes: List[ast.AST] = []
  97. methods: List[ast.AST] = []
  98. magic_methods: List[ast.AST] = []
  99. constants: List[ast.AST] = []
  100. for member_name, member_value in inspect.getmembers(cls_def):
  101. current_element_path = [*element_path, member_name]
  102. if member_name == "__init__":
  103. try:
  104. inspect.signature(cls_def) # we check it actually exists
  105. methods = [
  106. function_stub(
  107. member_name,
  108. cls_def,
  109. current_element_path,
  110. types_to_import,
  111. in_class=True,
  112. ),
  113. *methods,
  114. ]
  115. except ValueError as e:
  116. if "no signature found" not in str(e):
  117. raise ValueError(
  118. f"Error while parsing signature of {cls_name}.__init_"
  119. ) from e
  120. elif (
  121. member_value == OBJECT_MEMBERS.get(member_name)
  122. or BUILTINS.get(member_name, ()) is None
  123. ):
  124. pass
  125. elif inspect.isdatadescriptor(member_value):
  126. attributes.extend(
  127. data_descriptor_stub(
  128. member_name, member_value, current_element_path, types_to_import
  129. )
  130. )
  131. elif inspect.isroutine(member_value):
  132. (magic_methods if member_name.startswith("__") else methods).append(
  133. function_stub(
  134. member_name,
  135. member_value,
  136. current_element_path,
  137. types_to_import,
  138. in_class=True,
  139. )
  140. )
  141. elif member_name == "__match_args__":
  142. constants.append(
  143. ast.AnnAssign(
  144. target=ast.Name(id=member_name, ctx=ast.Store()),
  145. annotation=ast.Subscript(
  146. value=path_to_type("tuple"),
  147. slice=ast.Tuple(
  148. elts=[path_to_type("str"), ast.Ellipsis()], ctx=ast.Load()
  149. ),
  150. ctx=ast.Load(),
  151. ),
  152. value=ast.Constant(member_value),
  153. simple=1,
  154. )
  155. )
  156. elif member_value is not None:
  157. constants.append(
  158. ast.AnnAssign(
  159. target=ast.Name(id=member_name, ctx=ast.Store()),
  160. annotation=concatenated_path_to_type(
  161. member_value.__class__.__name__, element_path, types_to_import
  162. ),
  163. value=ast.Ellipsis(),
  164. simple=1,
  165. )
  166. )
  167. else:
  168. logging.warning(
  169. f"Unsupported member {member_name} of class {'.'.join(element_path)}"
  170. )
  171. doc = inspect.getdoc(cls_def)
  172. doc_comment = build_doc_comment(doc) if doc else None
  173. return ast.ClassDef(
  174. cls_name,
  175. bases=[],
  176. keywords=[],
  177. body=(
  178. ([doc_comment] if doc_comment else [])
  179. + attributes
  180. + methods
  181. + magic_methods
  182. + constants
  183. )
  184. or [ast.Ellipsis()],
  185. decorator_list=[path_to_type("typing", "final")],
  186. )
  187. def data_descriptor_stub(
  188. data_desc_name: str,
  189. data_desc_def: Any,
  190. element_path: List[str],
  191. types_to_import: Set[str],
  192. ) -> Union[Tuple[ast.AnnAssign, ast.Expr], Tuple[ast.AnnAssign]]:
  193. annotation = None
  194. doc_comment = None
  195. doc = inspect.getdoc(data_desc_def)
  196. if doc is not None:
  197. annotation = returns_stub(data_desc_name, doc, element_path, types_to_import)
  198. m = re.findall(r"^ *:return: *(.*) *$", doc, re.MULTILINE)
  199. if len(m) == 1:
  200. doc_comment = m[0]
  201. elif len(m) > 1:
  202. raise ValueError(
  203. f"Multiple return annotations found with :return: in {'.'.join(element_path)} documentation"
  204. )
  205. assign = ast.AnnAssign(
  206. target=ast.Name(id=data_desc_name, ctx=ast.Store()),
  207. annotation=annotation or path_to_type("typing", "Any"),
  208. simple=1,
  209. )
  210. doc_comment = build_doc_comment(doc_comment) if doc_comment else None
  211. return (assign, doc_comment) if doc_comment else (assign,)
  212. def function_stub(
  213. fn_name: str,
  214. fn_def: Any,
  215. element_path: List[str],
  216. types_to_import: Set[str],
  217. *,
  218. in_class: bool,
  219. ) -> ast.FunctionDef:
  220. body: List[ast.AST] = []
  221. doc = inspect.getdoc(fn_def)
  222. if doc is not None:
  223. doc_comment = build_doc_comment(doc)
  224. if doc_comment is not None:
  225. body.append(doc_comment)
  226. decorator_list = []
  227. if in_class and hasattr(fn_def, "__self__"):
  228. decorator_list.append(ast.Name("staticmethod"))
  229. return ast.FunctionDef(
  230. fn_name,
  231. arguments_stub(fn_name, fn_def, doc or "", element_path, types_to_import),
  232. body or [ast.Ellipsis()],
  233. decorator_list=decorator_list,
  234. returns=(
  235. returns_stub(fn_name, doc, element_path, types_to_import) if doc else None
  236. ),
  237. lineno=0,
  238. )
  239. def arguments_stub(
  240. callable_name: str,
  241. callable_def: Any,
  242. doc: str,
  243. element_path: List[str],
  244. types_to_import: Set[str],
  245. ) -> ast.arguments:
  246. real_parameters: Mapping[str, inspect.Parameter] = inspect.signature(
  247. callable_def
  248. ).parameters
  249. if callable_name == "__init__":
  250. real_parameters = {
  251. "self": inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY),
  252. **real_parameters,
  253. }
  254. parsed_param_types = {}
  255. optional_params = set()
  256. # Types for magic functions types
  257. builtin = BUILTINS.get(callable_name)
  258. if isinstance(builtin, tuple):
  259. param_names = list(real_parameters.keys())
  260. if param_names and param_names[0] == "self":
  261. del param_names[0]
  262. for name, t in zip(param_names, builtin[0]):
  263. parsed_param_types[name] = t
  264. # Types from comment
  265. for match in re.findall(
  266. r"^ *:type *([a-zA-Z0-9_]+): ([^\n]*) *$", doc, re.MULTILINE
  267. ):
  268. if match[0] not in real_parameters:
  269. raise ValueError(
  270. f"The parameter {match[0]} of {'.'.join(element_path)} "
  271. "is defined in the documentation but not in the function signature"
  272. )
  273. type = match[1]
  274. if type.endswith(", optional"):
  275. optional_params.add(match[0])
  276. type = type[:-10]
  277. parsed_param_types[match[0]] = convert_type_from_doc(
  278. type, element_path, types_to_import
  279. )
  280. # we parse the parameters
  281. posonlyargs = []
  282. args = []
  283. vararg = None
  284. kwonlyargs = []
  285. kw_defaults = []
  286. kwarg = None
  287. defaults = []
  288. for param in real_parameters.values():
  289. if param.name != "self" and param.name not in parsed_param_types:
  290. raise ValueError(
  291. f"The parameter {param.name} of {'.'.join(element_path)} "
  292. "has no type definition in the function documentation"
  293. )
  294. param_ast = ast.arg(
  295. arg=param.name, annotation=parsed_param_types.get(param.name)
  296. )
  297. default_ast = None
  298. if param.default != param.empty:
  299. default_ast = ast.Constant(param.default)
  300. if param.name not in optional_params:
  301. raise ValueError(
  302. f"Parameter {param.name} of {'.'.join(element_path)} "
  303. "is optional according to the type but not flagged as such in the doc"
  304. )
  305. elif param.name in optional_params:
  306. raise ValueError(
  307. f"Parameter {param.name} of {'.'.join(element_path)} "
  308. "is optional according to the documentation but has no default value"
  309. )
  310. if param.kind == param.POSITIONAL_ONLY:
  311. args.append(param_ast)
  312. # posonlyargs.append(param_ast)
  313. # defaults.append(default_ast)
  314. elif param.kind == param.POSITIONAL_OR_KEYWORD:
  315. args.append(param_ast)
  316. defaults.append(default_ast)
  317. elif param.kind == param.VAR_POSITIONAL:
  318. vararg = param_ast
  319. elif param.kind == param.KEYWORD_ONLY:
  320. kwonlyargs.append(param_ast)
  321. kw_defaults.append(default_ast)
  322. elif param.kind == param.VAR_KEYWORD:
  323. kwarg = param_ast
  324. return ast.arguments(
  325. posonlyargs=posonlyargs,
  326. args=args,
  327. vararg=vararg,
  328. kwonlyargs=kwonlyargs,
  329. kw_defaults=kw_defaults,
  330. defaults=defaults,
  331. kwarg=kwarg,
  332. )
  333. def returns_stub(
  334. callable_name: str, doc: str, element_path: List[str], types_to_import: Set[str]
  335. ) -> Optional[ast.AST]:
  336. m = re.findall(r"^ *:rtype: *([^\n]*) *$", doc, re.MULTILINE)
  337. if len(m) == 0:
  338. builtin = BUILTINS.get(callable_name)
  339. if isinstance(builtin, tuple) and builtin[1] is not None:
  340. return builtin[1]
  341. raise ValueError(
  342. f"The return type of {'.'.join(element_path)} "
  343. "has no type definition using :rtype: in the function documentation"
  344. )
  345. if len(m) > 1:
  346. raise ValueError(
  347. f"Multiple return type annotations found with :rtype: for {'.'.join(element_path)}"
  348. )
  349. return convert_type_from_doc(m[0], element_path, types_to_import)
  350. def convert_type_from_doc(
  351. type_str: str, element_path: List[str], types_to_import: Set[str]
  352. ) -> ast.AST:
  353. type_str = type_str.strip()
  354. return parse_type_to_ast(type_str, element_path, types_to_import)
  355. def parse_type_to_ast(
  356. type_str: str, element_path: List[str], types_to_import: Set[str]
  357. ) -> ast.AST:
  358. # let's tokenize
  359. tokens = []
  360. current_token = ""
  361. for c in type_str:
  362. if "a" <= c <= "z" or "A" <= c <= "Z" or c == ".":
  363. current_token += c
  364. else:
  365. if current_token:
  366. tokens.append(current_token)
  367. current_token = ""
  368. if c != " ":
  369. tokens.append(c)
  370. if current_token:
  371. tokens.append(current_token)
  372. # let's first parse nested parenthesis
  373. stack: List[List[Any]] = [[]]
  374. for token in tokens:
  375. if token == "[":
  376. children: List[str] = []
  377. stack[-1].append(children)
  378. stack.append(children)
  379. elif token == "]":
  380. stack.pop()
  381. else:
  382. stack[-1].append(token)
  383. # then it's easy
  384. def parse_sequence(sequence: List[Any]) -> ast.AST:
  385. # we split based on "or"
  386. or_groups: List[List[str]] = [[]]
  387. print(sequence)
  388. # TODO: Fix sequence
  389. if "Ros" in sequence and "2" in sequence:
  390. sequence = ["".join(sequence)]
  391. elif "dora.Ros" in sequence and "2" in sequence:
  392. sequence = ["".join(sequence)]
  393. for e in sequence:
  394. if e == "or":
  395. or_groups.append([])
  396. else:
  397. or_groups[-1].append(e)
  398. if any(not g for g in or_groups):
  399. raise ValueError(
  400. f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}"
  401. )
  402. new_elements: List[ast.AST] = []
  403. for group in or_groups:
  404. if len(group) == 1 and isinstance(group[0], str):
  405. new_elements.append(
  406. concatenated_path_to_type(group[0], element_path, types_to_import)
  407. )
  408. elif (
  409. len(group) == 2
  410. and isinstance(group[0], str)
  411. and isinstance(group[1], list)
  412. ):
  413. new_elements.append(
  414. ast.Subscript(
  415. value=concatenated_path_to_type(
  416. group[0], element_path, types_to_import
  417. ),
  418. slice=parse_sequence(group[1]),
  419. ctx=ast.Load(),
  420. )
  421. )
  422. else:
  423. raise ValueError(
  424. f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}"
  425. )
  426. return reduce(
  427. lambda left, right: ast.BinOp(left=left, op=ast.BitOr(), right=right),
  428. new_elements,
  429. )
  430. return parse_sequence(stack[0])
  431. def concatenated_path_to_type(
  432. path: str, element_path: List[str], types_to_import: Set[str]
  433. ) -> ast.AST:
  434. parts = path.split(".")
  435. if any(not p for p in parts):
  436. raise ValueError(
  437. f"Not able to parse type '{path}' used by {'.'.join(element_path)}"
  438. )
  439. if len(parts) > 1:
  440. types_to_import.add(".".join(parts[:-1]))
  441. return path_to_type(*parts)
  442. def build_doc_comment(doc: str) -> Optional[ast.Expr]:
  443. lines = [line.strip() for line in doc.split("\n")]
  444. clean_lines = []
  445. for line in lines:
  446. if line.startswith((":type", ":rtype")):
  447. continue
  448. clean_lines.append(line)
  449. text = "\n".join(clean_lines).strip()
  450. return ast.Expr(value=ast.Constant(text)) if text else None
  451. def format_with_ruff(file: str) -> None:
  452. subprocess.check_call(["python", "-m", "ruff", "format", file])
  453. if __name__ == "__main__":
  454. parser = argparse.ArgumentParser(
  455. description="Extract Python type stub from a python module."
  456. )
  457. parser.add_argument(
  458. "module_name", help="Name of the Python module for which generate stubs"
  459. )
  460. parser.add_argument(
  461. "out",
  462. help="Name of the Python stub file to write to",
  463. type=argparse.FileType("wt"),
  464. )
  465. parser.add_argument(
  466. "--ruff", help="Formats the generated stubs using Ruff", action="store_true"
  467. )
  468. args = parser.parse_args()
  469. stub_content = ast.unparse(module_stubs(importlib.import_module(args.module_name)))
  470. args.out.write(stub_content)
  471. if args.ruff:
  472. format_with_ruff(args.out.name)