You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parser.py 20 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2020 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """The module of parser python object, called by c++."""
  18. import ast
  19. import hashlib
  20. import inspect
  21. import types
  22. from dataclasses import is_dataclass
  23. from textwrap import dedent
  24. import asttokens
  25. from mindspore import Tensor as MsTensor
  26. from mindspore import context
  27. from mindspore import log as logger
  28. from mindspore import nn
  29. from mindspore import ops
  30. from mindspore.common.api import _MindSporeFunction
  31. from mindspore.common.dtype import pytype_to_dtype
  32. from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
  33. from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
  34. # define return value
  35. RET_SUCCESS = 0
  36. RET_FAILURE = 0xFF
  37. # define resolve type
  38. RESOLVE_TYPE_NONE = 0 # resolve None
  39. RESOLVE_TYPE_FUNCTION = 1 # resolve function
  40. RESOLVE_TYPE_METHOD = 2 # resolve class method
  41. RESOLVE_TYPE_CLASS_TYPE = 3 # resolve class type
  42. RESOLVE_TYPE_CLASS_INSTANCE = 4 # resolve the class instance of common class
  43. RESOLVE_TYPE_INVALID = 0xFF
  44. # define the class instance detail type
  45. # When the type is RESOLVE_TYPE_CLASS_INSTANCE
  46. CLASS_INSTANCE_TYPE_CELL = 0 # class instance type is Cell
  47. CLASS_INSTANCE_TYPE_PRIMITIVE = 1 # class instance type is Primitive
  48. CLASS_INSTANCE_TYPE_INVALID = 0xFF
  49. # Ast main type
  50. AST_MAIN_TYPE_STMT = 0 # ast.Stmt
  51. AST_MAIN_TYPE_EXPR = 1 # ast.Expr
  52. AST_MAIN_TYPE_SLICE = 2 # ast.Slice
  53. AST_MAIN_TYPE_UNKNOWN = 0xFF # unknown
  54. # Ast sub type
  55. AST_SUB_TYPE_AND = 3 # ast.And
  56. AST_SUB_TYPE_OR = 4 # ast.Or
  57. AST_SUB_TYPE_NAME = 5 # ast.Name
  58. AST_SUB_TYPE_TUPLE = 6 # ast.Tuple
  59. AST_SUB_TYPE_SUBSCRIPT = 7 # ast.Subscript
  60. AST_SUB_TYPE_STARRED = 8 # ast.Starred
  61. AST_SUB_TYPE_UNKNOWN = 0xFF # unknown
  62. # Process expr statement white list
  63. # add as needed, eg: "clear", "extend", "insert", "remove", "reverse"
  64. parse_expr_statement_white_list = (
  65. "append",
  66. )
  67. def create_slice_obj(start, end, step):
  68. """Create slice object"""
  69. return slice(start, end, step)
  70. def parse_cb(func, parse_method=None):
  71. """Implements the function of parse."""
  72. return Parser(func, parse_method)
  73. def get_parse_method_of_class(obj, parse_method=None):
  74. """
  75. Het parse method of class.
  76. Args:
  77. obj(Object): Instance of class.
  78. parse_method(str): Save the method name. Cell object has default method named 'construct'.
  79. Returns:
  80. Function, obj's method.
  81. """
  82. method = None
  83. method_name = None
  84. if parse_method is not None:
  85. method_name = parse_method
  86. else:
  87. if isinstance(obj, nn.Cell):
  88. if obj.enable_hook:
  89. if context.get_context("mode") == context.GRAPH_MODE:
  90. raise ValueError("The graph mode does not support hook function.")
  91. method_name = "_hook_construct"
  92. else:
  93. method_name = "construct"
  94. if method_name is not None:
  95. if hasattr(obj, method_name):
  96. method = getattr(obj, method_name)
  97. return method
  98. def get_bprop_method_of_class(obj, parse_method=None):
  99. """
  100. Get bprop method of class.
  101. Args:
  102. obj (Object): Instance of class.
  103. parse_method(str): Save the method name. Cell object has default method named 'bprop'.
  104. Returns:
  105. Function, obj's method.
  106. """
  107. method = None
  108. if isinstance(obj, nn.Cell):
  109. method_name = "bprop"
  110. if hasattr(obj, method_name):
  111. method = getattr(obj, method_name)
  112. return method
  113. def resolve_symbol(namespace, symbol):
  114. """
  115. Resolve a symbol.
  116. Note:
  117. Can't get function when use closure function. So save the fn on namespace.
  118. Args:
  119. namespace (Object): Symbol's namespace.
  120. symbol (str): Need resolve symbol.
  121. Returns:
  122. Object, resolve result of symbol.
  123. """
  124. # All exceptions need to be caught in this function
  125. try:
  126. resolve_ = namespace[symbol]
  127. # list and dict is not hashable ,it can not be key for the map, just return the result
  128. if isinstance(resolve_, (tuple, list, dict)):
  129. return resolve_
  130. # dataclass may not be hashable
  131. if getattr(resolve_, "__hash__") is None:
  132. return resolve_
  133. # If need trope the obj
  134. if resolve_ in convert_object_map:
  135. resolve_ = convert_object_map.get(resolve_)
  136. logger.debug("convert resolve = %r", resolve_)
  137. if resolve_ == NO_IMPLEMENT:
  138. raise NotImplementedError("not implemented for ", str(symbol))
  139. except Exception as e:
  140. if isinstance(e, NotImplementedError):
  141. raise e
  142. resolve_ = None
  143. logger.debug("resolve exception occurred, value = %r", e)
  144. logger.debug("resolve type is invalid, namespace = %s, symbol = %s",
  145. namespace.__str__(), symbol)
  146. if isinstance(resolve_, _MindSporeFunction):
  147. logger.debug("resolve class _MindSporeFunction, resolve fn instead.")
  148. resolve_ = resolve_.fn
  149. return resolve_
  150. def generate_scope(obj):
  151. """Generate the scope for every cell object in the network."""
  152. if isinstance(obj, nn.Cell):
  153. obj.generate_scope()
  154. def get_scope_name(obj):
  155. """Returns the scope of a cell object in one network."""
  156. if isinstance(obj, nn.Cell):
  157. return obj.get_scope()
  158. return None
  159. def get_object_key(obj):
  160. """Return the function key: module + name."""
  161. obj_key = ""
  162. if hasattr(obj, "__name__"):
  163. if hasattr(obj, "cell_init_args"):
  164. obj_key = "%s_ID" % (str(obj.__class__.__name__) + str(obj.__name__) + obj.cell_init_args)
  165. obj_id = "%s_ID%d" % (str(obj.__class__.__name__) + str(obj.__name__), id(obj))
  166. else:
  167. # `<class 'xxxxxxx'>`
  168. # -> `xxxxxxx`
  169. tag = str(obj.__class__)[8:-2]
  170. if hasattr(obj, "cell_init_args"):
  171. obj_key = "%s_ID" % (tag + obj.cell_init_args)
  172. obj_id = "%s_ID%d" % (tag, id(obj))
  173. logger.debug("obj_key %s obj_id = %s", obj_key, obj_id)
  174. # method has same id of different instance
  175. if isinstance(obj, types.MethodType):
  176. method_instance = obj.__self__
  177. instance_id = "%s_ID%d" % (str(method_instance.__class__.__name__), id(method_instance))
  178. obj_id = instance_id + obj_id + str(obj.__hash__())
  179. return obj_id, obj_key
  180. def is_class_member(node):
  181. """Check the attr is class member variable."""
  182. type_ = node.__class__.__name__
  183. if type_ == "Attribute":
  184. if not hasattr(node.value, "id"):
  185. return False
  186. id_ = node.value.id
  187. if id_ == "self":
  188. return True
  189. return False
  190. def get_obj_id(obj):
  191. """Get the obj id."""
  192. return str(id(obj))
  193. def get_obj_type(obj):
  194. """Get the obj type."""
  195. obj_type = RESOLVE_TYPE_INVALID
  196. if obj is None:
  197. obj_type = RESOLVE_TYPE_NONE
  198. elif isinstance(obj, types.FunctionType):
  199. obj_type = RESOLVE_TYPE_FUNCTION
  200. elif isinstance(obj, types.MethodType):
  201. obj_type = RESOLVE_TYPE_METHOD
  202. elif isinstance(obj, type):
  203. obj_type = RESOLVE_TYPE_CLASS_TYPE
  204. elif _is_class_instance(obj):
  205. obj_type = RESOLVE_TYPE_CLASS_INSTANCE
  206. else:
  207. # here for ndarray, just print its shape (in case of the array to large and print many data in screen)
  208. is_ndarray = type(obj).__name__ == 'ndarray' and hasattr(obj, 'shape')
  209. raise TypeError(f'Invalid object with type `{type(obj)}` and {"shape" if is_ndarray else "value"} '
  210. f'`{obj.shape if is_ndarray else obj}`.')
  211. return obj_type
  212. def get_class_instance_type(obj):
  213. """Get the class instance detail type."""
  214. # check the obj type
  215. logger.debug("Get the class type(%r)", obj)
  216. class_type = CLASS_INSTANCE_TYPE_INVALID
  217. if _is_class_instance(obj):
  218. if isinstance(obj, nn.Cell):
  219. class_type = CLASS_INSTANCE_TYPE_CELL
  220. elif isinstance(obj, ops.Primitive):
  221. class_type = CLASS_INSTANCE_TYPE_PRIMITIVE
  222. # Add the other type base requirement
  223. return class_type
  224. def _is_class_instance(obj):
  225. """Confirm the obj is class instance."""
  226. return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj)
  227. def _is_dataclass_instance(obj):
  228. """check whether a class is an instance of a dataclass (and not a dataclass itself)"""
  229. return is_dataclass(obj) and not isinstance(obj, type)
  230. def create_obj_instance(cls_type, args_tuple=None):
  231. """Create python instance."""
  232. obj = None
  233. if isinstance(cls_type, type):
  234. # check the type, now only support nn.Cell and Primitive
  235. if issubclass(cls_type, (nn.Cell, ops.Primitive)):
  236. if args_tuple is not None:
  237. obj = cls_type(*args_tuple)
  238. else:
  239. obj = cls_type()
  240. return obj
  241. def get_module_namespace(obj):
  242. """Get the module's namespace."""
  243. logger.debug("get module namespace, module = %r", obj)
  244. mod_namespace = None
  245. if isinstance(obj, types.ModuleType):
  246. mod_namespace = CellNamespace(obj.__name__)
  247. else:
  248. logger.warning("Module(%r) is invalid, get namespace failure!", obj)
  249. return mod_namespace
  250. def get_class_member_namespace_symbol(obj):
  251. """Get obj class member type."""
  252. logger.debug("get class instance namespace, object = %r", obj)
  253. class_namespace = ClassMemberNamespace(obj)
  254. logger.debug("class namesapce = %r", class_namespace)
  255. return class_namespace
  256. def get_dataclass_attributes(cls):
  257. """Get attributes of dataclass."""
  258. fields = cls.__dataclass_fields__
  259. attributes = {name: pytype_to_dtype(field.type)
  260. for name, field in fields.items()}
  261. return attributes
  262. def get_dataclass_methods(cls):
  263. """Get functions of dataclass."""
  264. methods = {name: getattr(cls, name)
  265. for name in dir(cls)
  266. if isinstance(getattr(cls, name), (types.FunctionType,))}
  267. return methods
  268. def convert_to_ms_tensor(data):
  269. """Convert C++ tensor to mindspore tensor."""
  270. return MsTensor(data)
  271. def get_object_description(obj, fname, fline):
  272. """return method or funcition description for error report, include location, class name, etc."""
  273. if isinstance(obj, types.MethodType):
  274. obj_cls = obj.__self__.__class__
  275. class_name = f'{obj_cls.__module__}.{obj_cls.__qualname__}'
  276. cls_fname = inspect.getfile(obj_cls)
  277. _, cls_fline = inspect.getsourcelines(obj_cls)
  278. class_loc = f'{cls_fname}:{cls_fline}'
  279. return f"bound method '{obj.__name__}' at {fname}:{fline} of <{class_name} at {class_loc} object>"
  280. if isinstance(obj, types.FunctionType):
  281. return f"function '{obj.__name__}' at {fname}:{fline}"
  282. if isinstance(obj, ast.FunctionDef):
  283. return f"function '{obj.name}' at {fname}:{fline}"
  284. return str(obj)
  285. class Parser:
  286. """
  287. Parser python code to ast tree.
  288. Args:
  289. fn(FunctionType/MethodType): Need parse object instance.
  290. parse_method(ExtendInfoOfParseObj): Extend information for parse the function.
  291. ast_cache: Dictionary for caching ast tree.
  292. """
  293. ast_cache = {}
  294. def __init__(self, fn: (types.FunctionType, types.MethodType), parse_method=None) -> None:
  295. self.fn = fn
  296. self.parse_method = parse_method
  297. self.line_offset = 0
  298. self.filename: str = inspect.getfile(self.fn)
  299. # Used to resolve the function's globals Namespace.
  300. self.global_namespace = CellNamespace(fn.__module__)
  301. self.function_module = fn.__module__
  302. # Used to resolve the function's nonlocals.
  303. self.closure_namespace = ClosureNamespace(fn)
  304. self.function_name = fn.__name__
  305. self.col_offset = 0
  306. def parse(self):
  307. """Parse the function or method."""
  308. logger.debug("fn = %r", self.fn)
  309. tree = None
  310. if isinstance(self.fn, (types.FunctionType, types.MethodType)):
  311. lines, self.line_offset = inspect.getsourcelines(self.fn)
  312. original_src = ''.join(lines)
  313. hexstr = hashlib.sha256(original_src.encode()).hexdigest()
  314. tree = Parser.ast_cache.get(hexstr)
  315. if not tree:
  316. src = dedent(original_src)
  317. self.col_offset = \
  318. len(original_src.split('\n')[0]) - len(src.split('\n')[0])
  319. logger.debug("get source = %s", src)
  320. tree = asttokens.ASTTokens(src, parse=True).tree
  321. Parser.ast_cache[hexstr] = tree
  322. else:
  323. logger.error("Fn type is invalid")
  324. return tree
  325. def get_args(self, node):
  326. """Get the arg of parse object."""
  327. args = []
  328. # process position args
  329. for arg in node.args.args:
  330. args.append(arg)
  331. # process kwonlyargs: kwonlyargs is append after position args
  332. if node.args.kwonlyargs:
  333. for kwarg in node.args.kwonlyargs:
  334. args.append(kwarg)
  335. # process vararg: vararg is append after kwonlyargs
  336. if node.args.vararg:
  337. args.append(node.args.vararg)
  338. # process kwarg: kwarg is append after vararg
  339. if node.args.kwarg:
  340. args.append(node.args.kwarg)
  341. return args
  342. def get_args_default_values(self, node):
  343. """get the args'default values of parse object."""
  344. nondefaults = [None] * (len(node.args.args) - len(node.args.defaults))
  345. defaults = nondefaults + node.args.defaults + node.args.kw_defaults
  346. if node.args.vararg:
  347. defaults.append(None)
  348. if node.args.kwarg:
  349. defaults.append(None)
  350. return defaults
  351. def get_node_type(self, node):
  352. """Process an ast node."""
  353. method_name = f'{node.__class__.__name__}'
  354. node_type = [method_name]
  355. # judge the ast main type
  356. if isinstance(node, ast.stmt):
  357. node_type.append(AST_MAIN_TYPE_STMT)
  358. elif isinstance(node, (ast.expr, ast.slice)) or node is None:
  359. # ast.slice and ast.expr should be expr
  360. node_type.append(AST_MAIN_TYPE_EXPR)
  361. else:
  362. node_type.append(AST_MAIN_TYPE_UNKNOWN)
  363. return node_type
  364. def get_ast_type(self, node):
  365. """Get the ast type."""
  366. ast_type = AST_SUB_TYPE_UNKNOWN
  367. if isinstance(node, ast.And):
  368. ast_type = AST_SUB_TYPE_AND
  369. elif isinstance(node, ast.Or):
  370. ast_type = AST_SUB_TYPE_OR
  371. elif isinstance(node, ast.Name):
  372. ast_type = AST_SUB_TYPE_NAME
  373. elif isinstance(node, ast.Tuple):
  374. ast_type = AST_SUB_TYPE_TUPLE
  375. elif isinstance(node, ast.Subscript):
  376. ast_type = AST_SUB_TYPE_SUBSCRIPT
  377. elif isinstance(node, ast.Starred):
  378. ast_type = AST_SUB_TYPE_STARRED
  379. else:
  380. ast_type = AST_SUB_TYPE_UNKNOWN
  381. return ast_type
  382. def get_namespace_symbol(self, var: str):
  383. """Get symbol type and namespace and symbol."""
  384. if var in self.closure_namespace:
  385. ops_info = (self.closure_namespace, var)
  386. logger.debug("in closure_namespace")
  387. elif var in self.global_namespace:
  388. ops_info = (self.global_namespace, var)
  389. logger.debug("in global_namespace")
  390. else:
  391. ops_info = parse_object_map.get(SYMBOL_UNDEFINE)
  392. ops_info = [ops_info[0], var]
  393. return ops_info
  394. def get_operation_namespace_symbol(self, var: str):
  395. """Get operation namespace and symbol."""
  396. ops_info = (trope_ns, var)
  397. logger.debug("get operation ops info = %r", ops_info)
  398. return ops_info
  399. def get_ast_namespace_symbol(self, obj):
  400. """Get obj type and namespace and symbol."""
  401. # step 1:get symbol from object map
  402. ops_info = parse_object_map.get(type(obj), SYMBOL_UNDEFINE)
  403. logger.debug("ops info = %r", ops_info)
  404. return ops_info
  405. def analyze_super(self, class_type_node, subclass_instance):
  406. """Analyze super and return a class instance."""
  407. sub_class = type(subclass_instance)
  408. if class_type_node is None:
  409. return super(sub_class, subclass_instance)
  410. if isinstance(class_type_node, ast.Name):
  411. class_name = getattr(class_type_node, 'id')
  412. elif isinstance(class_type_node, ast.Attribute):
  413. class_name = getattr(class_type_node, 'attr')
  414. else:
  415. raise ValueError(f"When call 'super', the first arg should be a class type, "
  416. f"but got {class_type_node.__class__.__name__}.")
  417. target_father_class = None
  418. for class_element in sub_class.mro():
  419. if class_element.__name__ == class_name:
  420. target_father_class = class_element
  421. break
  422. if target_father_class is None:
  423. raise ValueError("When call 'super', the second arg should be an instance of first arg.")
  424. return super(target_father_class, subclass_instance)
  425. def get_location(self, node):
  426. """
  427. Get location of node start and end line no.
  428. Args:
  429. node: AST op node or tuple or List. This is a node in the ANF diagram,
  430. here is the code location to get this node.
  431. Returns:
  432. List, [fileName, linestart, colstart, lineend, colend].
  433. """
  434. ret = [self.filename]
  435. err_exit = 0
  436. if isinstance(node, (list, tuple)):
  437. node_size = len(node)
  438. if node_size == 0:
  439. err_exit = 1
  440. else:
  441. start_node = node[0]
  442. end_node = node[-1]
  443. else:
  444. start_node = node
  445. end_node = node
  446. if err_exit == 0:
  447. if hasattr(start_node, "lineno") and \
  448. hasattr(end_node, "col_offset"):
  449. start_lineno, start_colno = start_node.first_token.start
  450. end_lineno, end_colno = end_node.last_token.end
  451. start_lineno += self.line_offset - 1
  452. start_colno += self.col_offset
  453. end_lineno += self.line_offset - 1
  454. end_colno += self.col_offset
  455. ret = ret + [start_lineno, start_colno, end_lineno, end_colno]
  456. else:
  457. ret = ret + [0, 0, 0, 0]
  458. return ret
  459. def expand_expr_statement(self, node):
  460. """
  461. Process the expr statement and expand it.
  462. Returns:
  463. tuple, (True, expr.value, x)/(False, None, None).
  464. """
  465. if isinstance(node, ast.Expr) and hasattr(node, "value"):
  466. expr_value = node.value
  467. if isinstance(expr_value, ast.Call):
  468. func = expr_value.func
  469. if isinstance(func, ast.Attribute) and \
  470. hasattr(func, "attr") and \
  471. hasattr(func, "value"):
  472. method = func.attr
  473. target = func.value
  474. if method in parse_expr_statement_white_list:
  475. logger.debug("Expand expr, target:%s, method:%s", target, method)
  476. return True, expr_value, target
  477. return True, expr_value
  478. return False, None, None