You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

generate_stubs.py 18 kB

10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
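"""Generate Python type stubs for a module.

The module is imported and inspected at runtime; parameter and return types
are read from the ``:type <name>:`` and ``:rtype:`` fields of its docstrings.
"""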
import argparse
import ast
import importlib
import inspect
import logging
import re
import subprocess
import sys
from collections.abc import Mapping
from functools import reduce
from typing import Any, Dict, List, Optional, Set, Tuple, Union
  12. def path_to_type(*elements: str) -> ast.AST:
  13. """TODO: Add docstring."""
  14. base: ast.AST = ast.Name(id=elements[0], ctx=ast.Load())
  15. for e in elements[1:]:
  16. base = ast.Attribute(value=base, attr=e, ctx=ast.Load())
  17. return base
  18. OBJECT_MEMBERS = dict(inspect.getmembers(object))
  19. BUILTINS: Dict[str, Union[None, Tuple[List[ast.AST], ast.AST]]] = {
  20. "__annotations__": None,
  21. "__bool__": ([], path_to_type("bool")),
  22. "__bytes__": ([], path_to_type("bytes")),
  23. "__class__": None,
  24. "__contains__": ([path_to_type("typing", "Any")], path_to_type("bool")),
  25. "__del__": None,
  26. "__delattr__": ([path_to_type("str")], path_to_type("None")),
  27. "__delitem__": ([path_to_type("typing", "Any")], path_to_type("typing", "Any")),
  28. "__dict__": None,
  29. "__dir__": None,
  30. "__doc__": None,
  31. "__eq__": ([path_to_type("typing", "Any")], path_to_type("bool")),
  32. "__format__": ([path_to_type("str")], path_to_type("str")),
  33. "__ge__": ([path_to_type("typing", "Any")], path_to_type("bool")),
  34. "__getattribute__": ([path_to_type("str")], path_to_type("typing", "Any")),
  35. "__getitem__": ([path_to_type("typing", "Any")], path_to_type("typing", "Any")),
  36. "__gt__": ([path_to_type("typing", "Any")], path_to_type("bool")),
  37. "__hash__": ([], path_to_type("int")),
  38. "__init__": ([], path_to_type("None")),
  39. "__init_subclass__": None,
  40. "__iter__": ([], path_to_type("typing", "Any")),
  41. "__le__": ([path_to_type("typing", "Any")], path_to_type("bool")),
  42. "__len__": ([], path_to_type("int")),
  43. "__lt__": ([path_to_type("typing", "Any")], path_to_type("bool")),
  44. "__module__": None,
  45. "__ne__": ([path_to_type("typing", "Any")], path_to_type("bool")),
  46. "__new__": None,
  47. "__next__": ([], path_to_type("typing", "Any")),
  48. "__int__": ([], path_to_type("None")),
  49. "__reduce__": None,
  50. "__reduce_ex__": None,
  51. "__repr__": ([], path_to_type("str")),
  52. "__setattr__": (
  53. [path_to_type("str"), path_to_type("typing", "Any")],
  54. path_to_type("None"),
  55. ),
  56. "__setitem__": (
  57. [path_to_type("typing", "Any"), path_to_type("typing", "Any")],
  58. path_to_type("typing", "Any"),
  59. ),
  60. "__sizeof__": None,
  61. "__str__": ([], path_to_type("str")),
  62. "__subclasshook__": None,
  63. }
  64. def module_stubs(module: Any) -> ast.Module:
  65. """TODO: Add docstring."""
  66. types_to_import = {"typing"}
  67. classes = []
  68. functions = []
  69. for member_name, member_value in inspect.getmembers(module):
  70. element_path = [module.__name__, member_name]
  71. if member_name.startswith("__") or member_name.startswith("DoraStatus"):
  72. pass
  73. elif inspect.isclass(member_value):
  74. classes.append(
  75. class_stubs(member_name, member_value, element_path, types_to_import),
  76. )
  77. elif inspect.isbuiltin(member_value):
  78. functions.append(
  79. function_stub(
  80. member_name,
  81. member_value,
  82. element_path,
  83. types_to_import,
  84. in_class=False,
  85. ),
  86. )
  87. else:
  88. logging.warning(f"Unsupported root construction {member_name}")
  89. return ast.Module(
  90. body=[ast.Import(names=[ast.alias(name=t)]) for t in sorted(types_to_import)]
  91. + classes
  92. + functions,
  93. type_ignores=[],
  94. )
  95. def class_stubs(
  96. cls_name: str, cls_def: Any, element_path: List[str], types_to_import: Set[str],
  97. ) -> ast.ClassDef:
  98. """TODO: Add docstring."""
  99. attributes: List[ast.AST] = []
  100. methods: List[ast.AST] = []
  101. magic_methods: List[ast.AST] = []
  102. constants: List[ast.AST] = []
  103. for member_name, member_value in inspect.getmembers(cls_def):
  104. current_element_path = [*element_path, member_name]
  105. if member_name == "__init__":
  106. try:
  107. inspect.signature(cls_def) # we check it actually exists
  108. methods = [
  109. function_stub(
  110. member_name,
  111. cls_def,
  112. current_element_path,
  113. types_to_import,
  114. in_class=True,
  115. ),
  116. *methods,
  117. ]
  118. except ValueError as e:
  119. if "no signature found" not in str(e):
  120. raise ValueError(
  121. f"Error while parsing signature of {cls_name}.__init_",
  122. ) from e
  123. elif (
  124. member_value == OBJECT_MEMBERS.get(member_name)
  125. or BUILTINS.get(member_name, ()) is None
  126. ):
  127. pass
  128. elif inspect.isdatadescriptor(member_value):
  129. attributes.extend(
  130. data_descriptor_stub(
  131. member_name, member_value, current_element_path, types_to_import,
  132. ),
  133. )
  134. elif inspect.isroutine(member_value):
  135. (magic_methods if member_name.startswith("__") else methods).append(
  136. function_stub(
  137. member_name,
  138. member_value,
  139. current_element_path,
  140. types_to_import,
  141. in_class=True,
  142. ),
  143. )
  144. elif member_name == "__match_args__":
  145. constants.append(
  146. ast.AnnAssign(
  147. target=ast.Name(id=member_name, ctx=ast.Store()),
  148. annotation=ast.Subscript(
  149. value=path_to_type("tuple"),
  150. slice=ast.Tuple(
  151. elts=[path_to_type("str"), ast.Ellipsis()], ctx=ast.Load(),
  152. ),
  153. ctx=ast.Load(),
  154. ),
  155. value=ast.Constant(member_value),
  156. simple=1,
  157. ),
  158. )
  159. elif member_value is not None:
  160. constants.append(
  161. ast.AnnAssign(
  162. target=ast.Name(id=member_name, ctx=ast.Store()),
  163. annotation=concatenated_path_to_type(
  164. member_value.__class__.__name__, element_path, types_to_import,
  165. ),
  166. value=ast.Ellipsis(),
  167. simple=1,
  168. ),
  169. )
  170. else:
  171. logging.warning(
  172. f"Unsupported member {member_name} of class {'.'.join(element_path)}",
  173. )
  174. doc = inspect.getdoc(cls_def)
  175. doc_comment = build_doc_comment(doc) if doc else None
  176. return ast.ClassDef(
  177. cls_name,
  178. bases=[],
  179. keywords=[],
  180. body=(
  181. ([doc_comment] if doc_comment else [])
  182. + attributes
  183. + methods
  184. + magic_methods
  185. + constants
  186. )
  187. or [ast.Ellipsis()],
  188. decorator_list=[path_to_type("typing", "final")],
  189. )
  190. def data_descriptor_stub(
  191. data_desc_name: str,
  192. data_desc_def: Any,
  193. element_path: List[str],
  194. types_to_import: Set[str],
  195. ) -> Union[Tuple[ast.AnnAssign, ast.Expr], Tuple[ast.AnnAssign]]:
  196. """TODO: Add docstring."""
  197. annotation = None
  198. doc_comment = None
  199. doc = inspect.getdoc(data_desc_def)
  200. if doc is not None:
  201. annotation = returns_stub(data_desc_name, doc, element_path, types_to_import)
  202. m = re.findall(r"^ *:return: *(.*) *$", doc, re.MULTILINE)
  203. if len(m) == 1:
  204. doc_comment = m[0]
  205. elif len(m) > 1:
  206. raise ValueError(
  207. f"Multiple return annotations found with :return: in {'.'.join(element_path)} documentation",
  208. )
  209. assign = ast.AnnAssign(
  210. target=ast.Name(id=data_desc_name, ctx=ast.Store()),
  211. annotation=annotation or path_to_type("typing", "Any"),
  212. simple=1,
  213. )
  214. doc_comment = build_doc_comment(doc_comment) if doc_comment else None
  215. return (assign, doc_comment) if doc_comment else (assign,)
  216. def function_stub(
  217. fn_name: str,
  218. fn_def: Any,
  219. element_path: List[str],
  220. types_to_import: Set[str],
  221. *,
  222. in_class: bool,
  223. ) -> ast.FunctionDef:
  224. """TODO: Add docstring."""
  225. body: List[ast.AST] = []
  226. doc = inspect.getdoc(fn_def)
  227. if doc is not None:
  228. doc_comment = build_doc_comment(doc)
  229. if doc_comment is not None:
  230. body.append(doc_comment)
  231. decorator_list = []
  232. if in_class and hasattr(fn_def, "__self__"):
  233. decorator_list.append(ast.Name("staticmethod"))
  234. return ast.FunctionDef(
  235. fn_name,
  236. arguments_stub(fn_name, fn_def, doc or "", element_path, types_to_import),
  237. body or [ast.Ellipsis()],
  238. decorator_list=decorator_list,
  239. returns=(
  240. returns_stub(fn_name, doc, element_path, types_to_import) if doc else None
  241. ),
  242. lineno=0,
  243. )
  244. def arguments_stub(
  245. callable_name: str,
  246. callable_def: Any,
  247. doc: str,
  248. element_path: List[str],
  249. types_to_import: Set[str],
  250. ) -> ast.arguments:
  251. """TODO: Add docstring."""
  252. real_parameters: Mapping[str, inspect.Parameter] = inspect.signature(
  253. callable_def,
  254. ).parameters
  255. if callable_name == "__init__":
  256. real_parameters = {
  257. "self": inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY),
  258. **real_parameters,
  259. }
  260. parsed_param_types = {}
  261. optional_params = set()
  262. # Types for magic functions types
  263. builtin = BUILTINS.get(callable_name)
  264. if isinstance(builtin, tuple):
  265. param_names = list(real_parameters.keys())
  266. if param_names and param_names[0] == "self":
  267. del param_names[0]
  268. parsed_param_types = {name: t for name, t in zip(param_names, builtin[0])}
  269. # Types from comment
  270. for match in re.findall(
  271. r"^ *:type *([a-zA-Z0-9_]+): ([^\n]*) *$", doc, re.MULTILINE,
  272. ):
  273. if match[0] not in real_parameters:
  274. raise ValueError(
  275. f"The parameter {match[0]} of {'.'.join(element_path)} "
  276. "is defined in the documentation but not in the function signature",
  277. )
  278. type = match[1]
  279. if type.endswith(", optional"):
  280. optional_params.add(match[0])
  281. type = type[:-10]
  282. parsed_param_types[match[0]] = convert_type_from_doc(
  283. type, element_path, types_to_import,
  284. )
  285. # we parse the parameters
  286. posonlyargs = []
  287. args = []
  288. vararg = None
  289. kwonlyargs = []
  290. kw_defaults = []
  291. kwarg = None
  292. defaults = []
  293. for param in real_parameters.values():
  294. if param.name != "self" and param.name not in parsed_param_types:
  295. raise ValueError(
  296. f"The parameter {param.name} of {'.'.join(element_path)} "
  297. "has no type definition in the function documentation",
  298. )
  299. param_ast = ast.arg(
  300. arg=param.name, annotation=parsed_param_types.get(param.name),
  301. )
  302. default_ast = None
  303. if param.default != param.empty:
  304. default_ast = ast.Constant(param.default)
  305. if param.name not in optional_params:
  306. raise ValueError(
  307. f"Parameter {param.name} of {'.'.join(element_path)} "
  308. "is optional according to the type but not flagged as such in the doc",
  309. )
  310. elif param.name in optional_params:
  311. raise ValueError(
  312. f"Parameter {param.name} of {'.'.join(element_path)} "
  313. "is optional according to the documentation but has no default value",
  314. )
  315. if param.kind == param.POSITIONAL_ONLY:
  316. args.append(param_ast)
  317. # posonlyargs.append(param_ast)
  318. # defaults.append(default_ast)
  319. elif param.kind == param.POSITIONAL_OR_KEYWORD:
  320. args.append(param_ast)
  321. defaults.append(default_ast)
  322. elif param.kind == param.VAR_POSITIONAL:
  323. vararg = param_ast
  324. elif param.kind == param.KEYWORD_ONLY:
  325. kwonlyargs.append(param_ast)
  326. kw_defaults.append(default_ast)
  327. elif param.kind == param.VAR_KEYWORD:
  328. kwarg = param_ast
  329. return ast.arguments(
  330. posonlyargs=posonlyargs,
  331. args=args,
  332. vararg=vararg,
  333. kwonlyargs=kwonlyargs,
  334. kw_defaults=kw_defaults,
  335. defaults=defaults,
  336. kwarg=kwarg,
  337. )
  338. def returns_stub(
  339. callable_name: str, doc: str, element_path: List[str], types_to_import: Set[str],
  340. ) -> Optional[ast.AST]:
  341. """TODO: Add docstring."""
  342. m = re.findall(r"^ *:rtype: *([^\n]*) *$", doc, re.MULTILINE)
  343. if len(m) == 0:
  344. builtin = BUILTINS.get(callable_name)
  345. if isinstance(builtin, tuple) and builtin[1] is not None:
  346. return builtin[1]
  347. raise ValueError(
  348. f"The return type of {'.'.join(element_path)} "
  349. "has no type definition using :rtype: in the function documentation",
  350. )
  351. if len(m) > 1:
  352. raise ValueError(
  353. f"Multiple return type annotations found with :rtype: for {'.'.join(element_path)}",
  354. )
  355. return convert_type_from_doc(m[0], element_path, types_to_import)
  356. def convert_type_from_doc(
  357. type_str: str, element_path: List[str], types_to_import: Set[str],
  358. ) -> ast.AST:
  359. """TODO: Add docstring."""
  360. type_str = type_str.strip()
  361. return parse_type_to_ast(type_str, element_path, types_to_import)
  362. def parse_type_to_ast(
  363. type_str: str, element_path: List[str], types_to_import: Set[str],
  364. ) -> ast.AST:
  365. # let's tokenize
  366. """TODO: Add docstring."""
  367. tokens = []
  368. current_token = ""
  369. for c in type_str:
  370. if "a" <= c <= "z" or "A" <= c <= "Z" or c == ".":
  371. current_token += c
  372. else:
  373. if current_token:
  374. tokens.append(current_token)
  375. current_token = ""
  376. if c != " ":
  377. tokens.append(c)
  378. if current_token:
  379. tokens.append(current_token)
  380. # let's first parse nested parenthesis
  381. stack: List[List[Any]] = [[]]
  382. for token in tokens:
  383. if token == "[":
  384. children: List[str] = []
  385. stack[-1].append(children)
  386. stack.append(children)
  387. elif token == "]":
  388. stack.pop()
  389. else:
  390. stack[-1].append(token)
  391. # then it's easy
  392. def parse_sequence(sequence: List[Any]) -> ast.AST:
  393. # we split based on "or"
  394. """TODO: Add docstring."""
  395. or_groups: List[List[str]] = [[]]
  396. print(sequence)
  397. # TODO: Fix sequence
  398. if ("Ros" in sequence and "2" in sequence) or ("dora.Ros" in sequence and "2" in sequence):
  399. sequence = ["".join(sequence)]
  400. for e in sequence:
  401. if e == "or":
  402. or_groups.append([])
  403. else:
  404. or_groups[-1].append(e)
  405. if any(not g for g in or_groups):
  406. raise ValueError(
  407. f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}",
  408. )
  409. new_elements: List[ast.AST] = []
  410. for group in or_groups:
  411. if len(group) == 1 and isinstance(group[0], str):
  412. new_elements.append(
  413. concatenated_path_to_type(group[0], element_path, types_to_import),
  414. )
  415. elif (
  416. len(group) == 2
  417. and isinstance(group[0], str)
  418. and isinstance(group[1], list)
  419. ):
  420. new_elements.append(
  421. ast.Subscript(
  422. value=concatenated_path_to_type(
  423. group[0], element_path, types_to_import,
  424. ),
  425. slice=parse_sequence(group[1]),
  426. ctx=ast.Load(),
  427. ),
  428. )
  429. else:
  430. raise ValueError(
  431. f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}",
  432. )
  433. return reduce(
  434. lambda left, right: ast.BinOp(left=left, op=ast.BitOr(), right=right),
  435. new_elements,
  436. )
  437. return parse_sequence(stack[0])
  438. def concatenated_path_to_type(
  439. path: str, element_path: List[str], types_to_import: Set[str],
  440. ) -> ast.AST:
  441. """TODO: Add docstring."""
  442. parts = path.split(".")
  443. if any(not p for p in parts):
  444. raise ValueError(
  445. f"Not able to parse type '{path}' used by {'.'.join(element_path)}",
  446. )
  447. if len(parts) > 1:
  448. types_to_import.add(".".join(parts[:-1]))
  449. return path_to_type(*parts)
  450. def build_doc_comment(doc: str) -> Optional[ast.Expr]:
  451. """TODO: Add docstring."""
  452. lines = [line.strip() for line in doc.split("\n")]
  453. clean_lines = []
  454. for line in lines:
  455. if line.startswith((":type", ":rtype")):
  456. continue
  457. clean_lines.append(line)
  458. text = "\n".join(clean_lines).strip()
  459. return ast.Expr(value=ast.Constant(text)) if text else None
  460. def format_with_ruff(file: str) -> None:
  461. """TODO: Add docstring."""
  462. subprocess.check_call(["python", "-m", "ruff", "format", file])
  463. if __name__ == "__main__":
  464. parser = argparse.ArgumentParser(
  465. description="Extract Python type stub from a python module.",
  466. )
  467. parser.add_argument(
  468. "module_name", help="Name of the Python module for which generate stubs",
  469. )
  470. parser.add_argument(
  471. "out",
  472. help="Name of the Python stub file to write to",
  473. type=argparse.FileType("wt"),
  474. )
  475. parser.add_argument(
  476. "--ruff", help="Formats the generated stubs using Ruff", action="store_true",
  477. )
  478. args = parser.parse_args()
  479. stub_content = ast.unparse(module_stubs(importlib.import_module(args.module_name)))
  480. args.out.write(stub_content)
  481. if args.ruff:
  482. format_with_ruff(args.out.name)