You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parser.py 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2020 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """The module of parser python object, called by c++."""
  18. import ast
  19. import types
  20. import inspect
  21. import hashlib
  22. from textwrap import dedent
  23. from dataclasses import is_dataclass
  24. import asttokens
  25. import mindspore.nn as nn
  26. from mindspore import log as logger
  27. from mindspore import ops
  28. from mindspore.common.dtype import pytype_to_dtype
  29. from mindspore.common.api import _MindSporeFunction
  30. from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
  31. from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
  32. # define return value
  33. RET_SUCCESS = 0
  34. RET_FAILURE = 0xFF
  35. # define resolve type
  36. RESOLVE_TYPE_NONE = 0 # resolve None
  37. RESOLVE_TYPE_FUNCTION = 1 # resolve function
  38. RESOLVE_TYPE_METHOD = 2 # resolve class method
  39. RESOLVE_TYPE_CLASS_TYPE = 3 # resolve class type
  40. RESOLVE_TYPE_CLASS_INSTANCE = 4 # resolve the class instance of common class
  41. RESOLVE_TYPE_INVALID = 0xFF
  42. # define the class instance detail type
  43. # When the type is RESOLVE_TYPE_CLASS_INSTANCE
  44. CLASS_INSTANCE_TYPE_CELL = 0 # class instance type is Cell
  45. CLASS_INSTANCE_TYPE_PRIMITIVE = 1 # class instance type is Primitive
  46. CLASS_INSTANCE_TYPE_INVALID = 0xFF
  47. # Ast main type
  48. AST_MAIN_TYPE_STMT = 0 # ast.Stmt
  49. AST_MAIN_TYPE_EXPR = 1 # ast.Expr
  50. AST_MAIN_TYPE_SLICE = 2 # ast.Slice
  51. AST_MAIN_TYPE_UNKNOWN = 0xFF # unknown
  52. # Ast sub type
  53. AST_SUB_TYPE_AND = 3 # ast.And
  54. AST_SUB_TYPE_OR = 4 # ast.Or
  55. AST_SUB_TYPE_NAME = 5 # ast.Name
  56. AST_SUB_TYPE_TUPLE = 6 # ast.Tuple
  57. AST_SUB_TYPE_SUBSCRIPT = 7 # ast.Subscript
  58. AST_SUB_TYPE_STARRED = 8 # ast.Starred
  59. AST_SUB_TYPE_UNKNOWN = 0xFF # unknown
  60. # Process expr statement white list
  61. # add as needed, eg: "clear", "extend", "insert", "remove", "reverse"
  62. parse_expr_statement_white_list = (
  63. "append",
  64. )
  65. def create_slice_obj(start, end, step):
  66. """Create slice object"""
  67. return slice(start, end, step)
  68. def parse_cb(func, parse_method=None):
  69. """Implements the function of parse."""
  70. return Parser(func, parse_method)
  71. def get_parse_method_of_class(obj, parse_method=None):
  72. """
  73. Het parse method of class.
  74. Args:
  75. obj(Object): Instance of class.
  76. parse_method(str): Save the method name. Cell object has default method named 'construct'.
  77. Returns:
  78. Function, obj's method.
  79. """
  80. method = None
  81. method_name = None
  82. if parse_method is not None:
  83. method_name = parse_method
  84. else:
  85. if isinstance(obj, nn.Cell):
  86. if obj.enable_hook:
  87. method_name = "_hook_construct"
  88. else:
  89. method_name = "construct"
  90. if method_name is not None:
  91. if hasattr(obj, method_name):
  92. method = getattr(obj, method_name)
  93. return method
  94. def get_bprop_method_of_class(obj, parse_method=None):
  95. """
  96. Get bprop method of class.
  97. Args:
  98. obj (Object): Instance of class.
  99. parse_method(str): Save the method name. Cell object has default method named 'bprop'.
  100. Returns:
  101. Function, obj's method.
  102. """
  103. method = None
  104. if isinstance(obj, nn.Cell):
  105. method_name = "bprop"
  106. if hasattr(obj, method_name):
  107. method = getattr(obj, method_name)
  108. return method
  109. def resolve_symbol(namespace, symbol):
  110. """
  111. Resolve a symbol.
  112. Note:
  113. Can't get function when use closure function. So save the fn on namespace.
  114. Args:
  115. namespace (Object): Symbol's namespace.
  116. symbol (str): Need resolve symbol.
  117. Returns:
  118. Object, resolve result of symbol.
  119. """
  120. # All exceptions need to be caught in this function
  121. try:
  122. resolve_ = namespace[symbol]
  123. # list and dict is not hashable ,it can not be key for the map, just return the result
  124. if isinstance(resolve_, (list, dict)):
  125. return resolve_
  126. # dataclass may not be hashable
  127. if getattr(resolve_, "__hash__") is None:
  128. return resolve_
  129. # If need trope the obj
  130. if resolve_ in convert_object_map:
  131. resolve_ = convert_object_map.get(resolve_)
  132. logger.debug("convert resolve = %r", resolve_)
  133. if resolve_ == NO_IMPLEMENT:
  134. raise NotImplementedError("not implemented for ", str(symbol))
  135. except Exception as e:
  136. if isinstance(e, NotImplementedError):
  137. raise e
  138. resolve_ = None
  139. logger.debug("resolve exception occurred, value = %r", e)
  140. logger.debug("resolve type is invalid, namespace = %s, symbol = %s",
  141. namespace.__str__(), symbol)
  142. if isinstance(resolve_, _MindSporeFunction):
  143. logger.debug("resolve class _MindSporeFunction, resolve fn instead.")
  144. resolve_ = resolve_.fn
  145. return resolve_
  146. def generate_scope(obj):
  147. """Generate the scope for every cell object in the network."""
  148. if isinstance(obj, nn.Cell):
  149. obj.generate_scope()
  150. def get_scope_name(obj):
  151. """Returns the scope of a cell object in one network."""
  152. if isinstance(obj, nn.Cell):
  153. return obj.get_scope()
  154. return None
  155. def get_object_key(obj):
  156. """Return the function key: module + name."""
  157. obj_key = ""
  158. if hasattr(obj, "__name__"):
  159. if hasattr(obj, "cell_init_args"):
  160. obj_key = "%s_ID" % (str(obj.__class__.__name__) + str(obj.__name__) + obj.cell_init_args)
  161. obj_id = "%s_ID%d" % (str(obj.__class__.__name__) + str(obj.__name__), id(obj))
  162. else:
  163. if hasattr(obj, "cell_init_args"):
  164. obj_key = "%s_ID" % (str(obj.__class__.__name__) + obj.cell_init_args)
  165. obj_id = "%s_ID%d" % (str(obj.__class__.__name__), id(obj))
  166. logger.debug("obj_key %s obj_id = %s", obj_key, obj_id)
  167. # method has same id of different instance
  168. if isinstance(obj, types.MethodType):
  169. method_instance = obj.__self__
  170. instance_id = "%s_ID%d" % (str(method_instance.__class__.__name__), id(method_instance))
  171. obj_id = instance_id + obj_id
  172. return obj_id, obj_key
  173. def get_default_input(obj):
  174. if hasattr(obj, '__parameter__'):
  175. return obj.default_input
  176. if isinstance(obj, tuple):
  177. convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
  178. args = tuple(convert(x) for x in obj)
  179. return args
  180. return obj
  181. def is_class_member(node):
  182. """Check the attr is class member variable."""
  183. type_ = node.__class__.__name__
  184. if type_ == "Attribute":
  185. if not hasattr(node.value, "id"):
  186. return False
  187. id_ = node.value.id
  188. if id_ == "self":
  189. return True
  190. return False
  191. def get_obj_id(obj):
  192. """Get the obj id."""
  193. return str(id(obj))
  194. def get_obj_type(obj):
  195. """Get the obj type."""
  196. obj_type = RESOLVE_TYPE_INVALID
  197. if obj is None:
  198. obj_type = RESOLVE_TYPE_NONE
  199. elif isinstance(obj, types.FunctionType):
  200. obj_type = RESOLVE_TYPE_FUNCTION
  201. elif isinstance(obj, types.MethodType):
  202. obj_type = RESOLVE_TYPE_METHOD
  203. elif isinstance(obj, type):
  204. obj_type = RESOLVE_TYPE_CLASS_TYPE
  205. elif _is_class_instance(obj):
  206. obj_type = RESOLVE_TYPE_CLASS_INSTANCE
  207. else:
  208. # here for ndarray, just print its shape (in case of the array to large and print many data in screen)
  209. is_ndarray = type(obj).__name__ == 'ndarray' and hasattr(obj, 'shape')
  210. raise TypeError(f'Invalid object with type `{type(obj)}` and {"shape" if is_ndarray else "value"} '
  211. f'`{obj.shape if is_ndarray else obj}`.')
  212. return obj_type
  213. def get_class_instance_type(obj):
  214. """Get the class instance detail type."""
  215. # check the obj type
  216. logger.debug("Get the class type(%r)", obj)
  217. class_type = CLASS_INSTANCE_TYPE_INVALID
  218. if _is_class_instance(obj):
  219. if isinstance(obj, nn.Cell):
  220. class_type = CLASS_INSTANCE_TYPE_CELL
  221. elif isinstance(obj, ops.Primitive):
  222. class_type = CLASS_INSTANCE_TYPE_PRIMITIVE
  223. # Add the other type base requirement
  224. return class_type
  225. def _is_class_instance(obj):
  226. """Confirm the obj is class instance."""
  227. return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj)
  228. def _is_dataclass_instance(obj):
  229. """check whether a class is an instance of a dataclass (and not a dataclass itself)"""
  230. return is_dataclass(obj) and not isinstance(obj, type)
  231. def create_obj_instance(cls_type, args_tuple=None):
  232. """Create python instance."""
  233. obj = None
  234. if isinstance(cls_type, type):
  235. # check the type, now only support nn.Cell and Primitive
  236. if issubclass(cls_type, (nn.Cell, ops.Primitive)):
  237. if args_tuple is not None:
  238. obj = cls_type(*args_tuple)
  239. else:
  240. obj = cls_type()
  241. return obj
  242. def get_module_namespace(obj):
  243. """Get the module's namespace."""
  244. logger.debug("get module namespace, module = %r", obj)
  245. mod_namespace = None
  246. if isinstance(obj, types.ModuleType):
  247. mod_namespace = CellNamespace(obj.__name__)
  248. else:
  249. logger.warning("Module(%r) is invalid, get namespace failure!", obj)
  250. return mod_namespace
  251. def get_class_member_namespace_symbol(obj):
  252. """Get obj class member type."""
  253. logger.debug("get class instance namespace, object = %r", obj)
  254. class_namespace = ClassMemberNamespace(obj)
  255. logger.debug("class namesapce = %r", class_namespace)
  256. return class_namespace
  257. def get_dataclass_attributes(cls):
  258. """Get attributes of dataclass."""
  259. fields = cls.__dataclass_fields__
  260. attributes = {name: pytype_to_dtype(field.type)
  261. for name, field in fields.items()}
  262. return attributes
  263. def get_dataclass_methods(cls):
  264. """Get functions of dataclass."""
  265. methods = {name: getattr(cls, name)
  266. for name in dir(cls)
  267. if isinstance(getattr(cls, name), (types.FunctionType,))}
  268. return methods
  269. class Parser:
  270. """
  271. Parser python code to ast tree.
  272. Args:
  273. fn(FunctionType/MethodType): Need parse object instance.
  274. parse_method(ExtendInfoOfParseObj): Extend information for parse the function.
  275. ast_cache: Dictionary for caching ast tree.
  276. """
  277. ast_cache = {}
  278. def __init__(self, fn: (types.FunctionType, types.MethodType), parse_method=None) -> None:
  279. self.fn = fn
  280. self.parse_method = parse_method
  281. _, self.line_offset = inspect.getsourcelines(self.fn)
  282. self.filename: str = inspect.getfile(self.fn)
  283. # Used to resolve the function's globals Namespace.
  284. self.global_namespace = CellNamespace(fn.__module__)
  285. self.function_module = fn.__module__
  286. # Used to resolve the function's nonlocals.
  287. self.closure_namespace = ClosureNamespace(fn)
  288. self.function_name = fn.__name__
  289. self.col_offset = 0
  290. def parse(self):
  291. """Parse the function or method."""
  292. logger.debug("fn = %r", self.fn)
  293. tree = None
  294. if isinstance(self.fn, (types.FunctionType, types.MethodType)):
  295. original_src = inspect.getsource(self.fn)
  296. hexstr = hashlib.sha256(original_src.encode()).hexdigest()
  297. tree = Parser.ast_cache.get(hexstr)
  298. if not tree:
  299. src = dedent(original_src)
  300. self.col_offset = \
  301. len(original_src.split('\n')[0]) - len(src.split('\n')[0])
  302. logger.debug("get source = %s", src)
  303. tree = asttokens.ASTTokens(src, parse=True).tree
  304. Parser.ast_cache[hexstr] = tree
  305. else:
  306. logger.error("Fn type is invalid")
  307. return tree
  308. def get_args(self, node):
  309. """Get the arg of parse object."""
  310. args = []
  311. # process position args
  312. for arg in node.args.args:
  313. args.append(arg)
  314. # process kwonlyargs: kwonlyargs is append after position args
  315. if node.args.kwonlyargs:
  316. for kwarg in node.args.kwonlyargs:
  317. args.append(kwarg)
  318. # process vararg: vararg is append after kwonlyargs
  319. if node.args.vararg:
  320. args.append(node.args.vararg)
  321. # process kwarg: kwarg is append after vararg
  322. if node.args.kwarg:
  323. args.append(node.args.kwarg)
  324. return args
  325. def get_args_default_values(self, node):
  326. """get the args'default values of parse object."""
  327. nondefaults = [None] * (len(node.args.args) - len(node.args.defaults))
  328. defaults = nondefaults + node.args.defaults + node.args.kw_defaults
  329. if node.args.vararg:
  330. defaults.append(None)
  331. if node.args.kwarg:
  332. defaults.append(None)
  333. return defaults
  334. def get_node_type(self, node):
  335. """Process an ast node."""
  336. method_name = f'{node.__class__.__name__}'
  337. node_type = [method_name]
  338. # judge the ast main type
  339. if isinstance(node, ast.stmt):
  340. node_type.append(AST_MAIN_TYPE_STMT)
  341. elif isinstance(node, (ast.expr, ast.slice)) or node is None:
  342. # ast.slice and ast.expr should be expr
  343. node_type.append(AST_MAIN_TYPE_EXPR)
  344. else:
  345. node_type.append(AST_MAIN_TYPE_UNKNOWN)
  346. return node_type
  347. def get_ast_type(self, node):
  348. """Get the ast type."""
  349. ast_type = AST_SUB_TYPE_UNKNOWN
  350. if isinstance(node, ast.And):
  351. ast_type = AST_SUB_TYPE_AND
  352. elif isinstance(node, ast.Or):
  353. ast_type = AST_SUB_TYPE_OR
  354. elif isinstance(node, ast.Name):
  355. ast_type = AST_SUB_TYPE_NAME
  356. elif isinstance(node, ast.Tuple):
  357. ast_type = AST_SUB_TYPE_TUPLE
  358. elif isinstance(node, ast.Subscript):
  359. ast_type = AST_SUB_TYPE_SUBSCRIPT
  360. elif isinstance(node, ast.Starred):
  361. ast_type = AST_SUB_TYPE_STARRED
  362. else:
  363. ast_type = AST_SUB_TYPE_UNKNOWN
  364. return ast_type
  365. def get_namespace_symbol(self, var: str):
  366. """Get symbol type and namespace and symbol."""
  367. if var in self.closure_namespace:
  368. ops_info = (self.closure_namespace, var)
  369. logger.debug("in closure_namespace")
  370. elif var in self.global_namespace:
  371. ops_info = (self.global_namespace, var)
  372. logger.debug("in global_namespace")
  373. else:
  374. ops_info = parse_object_map.get(SYMBOL_UNDEFINE)
  375. ops_info = [ops_info[0], var]
  376. return ops_info
  377. def get_operation_namespace_symbol(self, var: str):
  378. """Get operation namespace and symbol."""
  379. ops_info = (trope_ns, var)
  380. logger.debug("get operation ops info = %r", ops_info)
  381. return ops_info
  382. def get_ast_namespace_symbol(self, obj):
  383. """Get obj type and namespace and symbol."""
  384. # step 1:get symbol from object map
  385. ops_info = parse_object_map.get(type(obj), SYMBOL_UNDEFINE)
  386. logger.debug("ops info = %r", ops_info)
  387. return ops_info
  388. def get_location(self, node):
  389. """
  390. Get location of node start and end line no.
  391. Args:
  392. node: AST op node or tuple or List. This is a node in the ANF diagram,
  393. here is the code location to get this node.
  394. Returns:
  395. List, [fileName, linestart, colstart, lineend, colend].
  396. """
  397. ret = [self.filename]
  398. err_exit = 0
  399. if isinstance(node, (list, tuple)):
  400. node_size = len(node)
  401. if node_size == 0:
  402. err_exit = 1
  403. else:
  404. start_node = node[0]
  405. end_node = node[-1]
  406. else:
  407. start_node = node
  408. end_node = node
  409. if err_exit == 0:
  410. if hasattr(start_node, "lineno") and \
  411. hasattr(end_node, "col_offset"):
  412. start_lineno, start_colno = start_node.first_token.start
  413. end_lineno, end_colno = end_node.last_token.end
  414. start_lineno += self.line_offset - 1
  415. start_colno += self.col_offset
  416. end_lineno += self.line_offset - 1
  417. end_colno += self.col_offset
  418. ret = ret + [start_lineno, start_colno, end_lineno, end_colno]
  419. else:
  420. ret = ret + [0, 0, 0, 0]
  421. return ret
  422. def expand_expr_statement(self, node):
  423. """
  424. Process the expr statement and expand it.
  425. Returns:
  426. tuple, (True, expr.value, x)/(False, None, None).
  427. """
  428. if isinstance(node, ast.Expr) and hasattr(node, "value"):
  429. expr_value = node.value
  430. if isinstance(expr_value, ast.Call):
  431. func = expr_value.func
  432. if isinstance(func, ast.Attribute) and \
  433. hasattr(func, "attr") and \
  434. hasattr(func, "value"):
  435. method = func.attr
  436. target = func.value
  437. if method in parse_expr_statement_white_list:
  438. logger.debug("Expand expr, target:%s, method:%s", target, method)
  439. return True, expr_value, target
  440. return True, expr_value
  441. return False, None, None