You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

parser.py 30 kB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """The module of parser python object, called by c++."""
  18. import os
  19. import ast
  20. import hashlib
  21. import inspect
  22. import types
  23. from dataclasses import is_dataclass
  24. from textwrap import dedent
  25. import asttokens
  26. from mindspore import Tensor
  27. from mindspore import log as logger
  28. from mindspore import nn
  29. from mindspore import ops
  30. from mindspore.common.api import _MindsporeFunctionExecutor
  31. from mindspore.common.dtype import pytype_to_dtype
  32. from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
  33. from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
# Return codes.
RET_SUCCESS = 0
RET_FAILURE = 0xFF

# Resolve types, as returned by get_obj_type().
RESOLVE_TYPE_NONE = 0  # the object is None
RESOLVE_TYPE_FUNCTION = 1  # a plain Python function (types.FunctionType)
RESOLVE_TYPE_METHOD = 2  # a bound method (types.MethodType)
RESOLVE_TYPE_CLASS_TYPE = 3  # a class (an instance of `type`)
RESOLVE_TYPE_CLASS_INSTANCE = 4  # an instance of nn.Cell, ops.Primitive or a dataclass
RESOLVE_TYPE_INVALID = 0xFF  # anything else

# Detailed class-instance types, as returned by get_class_instance_type();
# meaningful when the resolve type is RESOLVE_TYPE_CLASS_INSTANCE.
CLASS_INSTANCE_TYPE_CELL = 0  # class instance type is Cell
CLASS_INSTANCE_TYPE_PRIMITIVE = 1  # class instance type is Primitive
CLASS_INSTANCE_TYPE_INVALID = 0xFF

# AST main types, as reported by get_node_type().
AST_MAIN_TYPE_STMT = 0  # ast.stmt
AST_MAIN_TYPE_EXPR = 1  # ast.expr; get_node_type also reports ast.slice nodes and None as EXPR
AST_MAIN_TYPE_SLICE = 2  # ast.slice; NOTE(review): get_node_type in this file never returns this value
AST_MAIN_TYPE_UNKNOWN = 0xFF  # unknown

# AST sub types, as reported by get_ast_type().
AST_SUB_TYPE_AND = 3  # ast.And
AST_SUB_TYPE_OR = 4  # ast.Or
AST_SUB_TYPE_NAME = 5  # ast.Name
AST_SUB_TYPE_TUPLE = 6  # ast.Tuple
AST_SUB_TYPE_SUBSCRIPT = 7  # ast.Subscript
AST_SUB_TYPE_STARRED = 8  # ast.Starred
AST_SUB_TYPE_ATTRIBUTE = 9  # ast.Attribute
AST_SUB_TYPE_UNKNOWN = 0xFF  # unknown

# Method names for which expand_expr_statement() expands an expression statement of the form
# `target.method(...)` into (True, call, target).
# Add as needed, e.g. "clear", "extend", "insert", "remove", "reverse".
parse_expr_statement_white_list = (
    "append",
)

# The type of builtin functions such as `abs`; used to detect Python builtins that are not supported.
_builtin_function_or_method_type = type(abs)
  69. def create_slice_obj(start, end, step):
  70. """Create slice object"""
  71. return slice(start, end, step)
  72. def parse_cb(func, parse_method=None):
  73. """Implements the function of parse."""
  74. return Parser(func, parse_method)
  75. def get_parse_method_of_class(obj, parse_method=None):
  76. """
  77. Het parse method of class.
  78. Args:
  79. obj(Object): Instance of class.
  80. parse_method(str): Save the method name. Cell object has default method named 'construct'.
  81. Returns:
  82. Function, obj's method.
  83. """
  84. method = None
  85. method_name = None
  86. if parse_method is not None:
  87. method_name = parse_method
  88. elif isinstance(obj, nn.Cell):
  89. if obj.enable_hook:
  90. method_name = "_hook_construct"
  91. else:
  92. method_name = "construct"
  93. if method_name is not None:
  94. if hasattr(obj, method_name):
  95. method = getattr(obj, method_name)
  96. return method
  97. def get_bprop_method_of_class(obj, parse_method=None):
  98. """
  99. Get bprop method of class.
  100. Args:
  101. obj (Object): Instance of class.
  102. parse_method(str): Save the method name. Cell object has default method named 'bprop'.
  103. Returns:
  104. Function, obj's method.
  105. """
  106. method = None
  107. if isinstance(obj, nn.Cell):
  108. method_name = "bprop"
  109. if hasattr(obj, method_name):
  110. method = getattr(obj, method_name)
  111. return method
# The Fallback feature is enabled by default: the checks in this module only treat it as disabled when the
# environment variable DEV_ENV_ENABLE_FALLBACK is exactly '0' (an unset variable yields None here).
# The value is read once when this module is imported, so changing the variable while the process is alive
# is not supported.
support_fallback_ = os.getenv('DEV_ENV_ENABLE_FALLBACK')
  115. def resolve_symbol(namespace, symbol):
  116. """
  117. Resolve a symbol.
  118. Note:
  119. Can't get function when use closure function. So save the fn on namespace.
  120. Args:
  121. namespace (Object): Symbol's namespace.
  122. symbol (str): Need resolve symbol.
  123. Returns:
  124. Object, resolve result of symbol.
  125. """
  126. # All exceptions need to be caught in this function
  127. try:
  128. resolve_ = namespace[symbol]
  129. # list and dict is not hashable ,it can not be key for the map, just return the result
  130. if isinstance(resolve_, (tuple, list, dict)):
  131. return resolve_
  132. # dataclass may not be hashable
  133. if getattr(resolve_, "__hash__") is None:
  134. return resolve_
  135. # Raise a proper error if not using Fallback feature.
  136. if support_fallback_ == '0':
  137. # Raise NotImplementedError when parsing the numpy methods, but not the numpy constant.
  138. if namespace.name == "numpy" and \
  139. isinstance(resolve_, (types.FunctionType, types.MethodType, types.ModuleType)):
  140. raise NotImplementedError("Mindspore does not support to use the numpy methods " \
  141. "within the construct() or @ms_function decorated function in graph mode.")
  142. # If need trope the obj
  143. if resolve_ in convert_object_map:
  144. resolve_ = convert_object_map.get(resolve_)
  145. logger.debug("Convert resolve = %r", resolve_)
  146. if resolve_ == NO_IMPLEMENT:
  147. raise NotImplementedError(f"Not support for '{symbol}'.")
  148. except Exception as e:
  149. if isinstance(e, NotImplementedError):
  150. raise e
  151. resolve_ = None
  152. logger.debug("Resolve exception occurred, value = %r", e)
  153. logger.debug("Resolve type is invalid, namespace = %s, symbol = %s",
  154. namespace.__str__(), symbol)
  155. if isinstance(resolve_, _MindsporeFunctionExecutor):
  156. logger.debug("Resolve class _MindsporeFunctionExecutor, resolve fn instead.")
  157. resolve_ = resolve_.fn
  158. logger.debug(f"Found '{symbol}' in {namespace.__str__()}, resolved: {resolve_} / {type(resolve_)}")
  159. return resolve_
  160. def generate_scope(obj):
  161. """Generate the scope for every cell object in the network."""
  162. if isinstance(obj, nn.Cell):
  163. obj.generate_scope()
  164. def get_scope_name(obj):
  165. """Returns the scope of a cell object in one network."""
  166. if isinstance(obj, nn.Cell):
  167. return obj.get_scope()
  168. return None
  169. def get_object_key(obj):
  170. """Return the function key: module + name."""
  171. obj_key = ""
  172. if hasattr(obj, "__name__"):
  173. if hasattr(obj, "cell_init_args"):
  174. obj_key = "%s_ID" % (str(obj.__class__.__name__) + str(obj.__name__) + obj.cell_init_args)
  175. obj_id = "%s_ID%d" % (str(obj.__class__.__name__) + str(obj.__name__), id(obj))
  176. else:
  177. # `<class 'xxxxxxx'>`
  178. # -> `xxxxxxx`
  179. tag = str(obj.__class__)[8:-2]
  180. if hasattr(obj, "cell_init_args"):
  181. obj_key = "%s_ID" % (tag + obj.cell_init_args)
  182. obj_id = "%s_ID%d" % (tag, id(obj))
  183. logger.debug("obj_key %s obj_id = %s", obj_key, obj_id)
  184. # method has same id of different instance
  185. if isinstance(obj, types.MethodType):
  186. method_instance = obj.__self__
  187. instance_id = "%s_ID%d" % (str(method_instance.__class__.__name__), id(method_instance))
  188. obj_id = instance_id + obj_id + str(obj.__hash__())
  189. return obj_id, obj_key
  190. def is_class_member(node):
  191. """Check the attr is class member variable."""
  192. type_ = node.__class__.__name__
  193. if type_ == "Attribute":
  194. if not hasattr(node.value, "id"):
  195. return False
  196. id_ = node.value.id
  197. if id_ == "self":
  198. return True
  199. return False
  200. def get_obj_id(obj):
  201. """Get the obj id."""
  202. return str(id(obj))
  203. def get_obj_type(obj):
  204. """Get the obj type."""
  205. logger.debug("Get object type: %r", obj)
  206. obj_type = RESOLVE_TYPE_INVALID
  207. if obj is None:
  208. obj_type = RESOLVE_TYPE_NONE
  209. elif isinstance(obj, types.FunctionType):
  210. obj_type = RESOLVE_TYPE_FUNCTION
  211. elif isinstance(obj, types.MethodType):
  212. obj_type = RESOLVE_TYPE_METHOD
  213. elif isinstance(obj, type):
  214. obj_type = RESOLVE_TYPE_CLASS_TYPE
  215. elif _is_class_instance(obj):
  216. obj_type = RESOLVE_TYPE_CLASS_INSTANCE
  217. else:
  218. # Raise a proper error if not using Fallback feature.
  219. if support_fallback_ != '0':
  220. obj_type = RESOLVE_TYPE_INVALID
  221. else:
  222. # here for ndarray, just print its shape (in case of the array to large and print many data in screen)
  223. is_ndarray = type(obj).__name__ == 'ndarray' and hasattr(obj, 'shape')
  224. raise TypeError(f"Not support for this object with type '{type(obj)}' and "
  225. f"{'shape' if is_ndarray else 'value'} '{obj.shape if is_ndarray else obj}'.")
  226. return obj_type
  227. def get_class_instance_type(obj):
  228. """Get the class instance detail type."""
  229. # check the obj type
  230. logger.debug("Get the class type(%r)", obj)
  231. class_type = CLASS_INSTANCE_TYPE_INVALID
  232. if _is_class_instance(obj):
  233. if isinstance(obj, nn.Cell):
  234. class_type = CLASS_INSTANCE_TYPE_CELL
  235. elif isinstance(obj, ops.Primitive):
  236. class_type = CLASS_INSTANCE_TYPE_PRIMITIVE
  237. # Add the other type base requirement
  238. return class_type
  239. def _is_class_instance(obj):
  240. """Confirm the obj is class instance."""
  241. return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj)
  242. def _is_dataclass_instance(obj):
  243. """check whether a class is an instance of a dataclass (and not a dataclass itself)"""
  244. return is_dataclass(obj) and not isinstance(obj, type)
  245. def _convert_tuple_to_args_kwargs(params):
  246. args = tuple()
  247. kwargs = dict()
  248. for param in params:
  249. if isinstance(param, dict):
  250. kwargs.update(param)
  251. else:
  252. args += (param,)
  253. return (args, kwargs)
  254. def is_supported_create_instance_type(cls_type):
  255. return issubclass(cls_type, (nn.Cell, ops.Primitive))
  256. def create_instance(cls_type, params=None):
  257. """Create python instance."""
  258. if not isinstance(cls_type, type):
  259. logger.warning(f"create_instance(), cls_type is not a type, cls_type: {cls_type}")
  260. return None
  261. # Check the type, now only support nn.Cell and Primitive.
  262. obj = None
  263. if is_supported_create_instance_type(cls_type):
  264. # Check arguments, only support *args or **kwargs.
  265. if params is None:
  266. obj = cls_type()
  267. elif isinstance(params, tuple):
  268. args, kwargs = _convert_tuple_to_args_kwargs(params)
  269. logger.debug(f"create_instance(), args: {args}, kwargs: {kwargs}")
  270. if args and kwargs:
  271. obj = cls_type(*args, **kwargs)
  272. elif args:
  273. obj = cls_type(*args)
  274. elif kwargs:
  275. obj = cls_type(**kwargs)
  276. # If invalid parameters.
  277. if obj is None:
  278. raise ValueError(f"When call 'create_instance', the parameter should be *args or **kwargs, "
  279. f"but got {params.__class__.__name__}, params: {params}")
  280. return obj
  281. def get_module_namespace(obj):
  282. """Get the module's namespace."""
  283. logger.debug("get module namespace, module = %r", obj)
  284. mod_namespace = None
  285. if isinstance(obj, types.ModuleType):
  286. mod_namespace = CellNamespace(obj.__name__)
  287. else:
  288. logger.warning("Module(%r) is invalid, get namespace failure!", obj)
  289. return mod_namespace
  290. def get_class_member_namespace_symbol(obj):
  291. """Get obj class member type."""
  292. logger.debug("get class instance namespace, object = %r", obj)
  293. class_namespace = ClassMemberNamespace(obj)
  294. logger.debug("class namesapce = %r", class_namespace)
  295. return class_namespace
  296. def get_dataclass_attributes(cls):
  297. """Get attributes of dataclass."""
  298. fields = cls.__dataclass_fields__
  299. attributes = {name: pytype_to_dtype(field.type)
  300. for name, field in fields.items()}
  301. return attributes
  302. def get_dataclass_methods(cls):
  303. """Get functions of dataclass."""
  304. methods = {name: getattr(cls, name)
  305. for name in dir(cls)
  306. if isinstance(getattr(cls, name), (types.FunctionType,))}
  307. return methods
  308. def convert_to_ms_tensor(data):
  309. """Convert C++ tensor to mindspore tensor."""
  310. return Tensor(data)
  311. def get_object_description(obj, fname, fline):
  312. """return method or funcition description for error report, include location, class name, etc."""
  313. if isinstance(obj, types.MethodType):
  314. obj_cls = obj.__self__.__class__
  315. class_name = f"{obj_cls.__module__}.{obj_cls.__qualname__}"
  316. cls_fname = inspect.getfile(obj_cls)
  317. _, cls_fline = inspect.getsourcelines(obj_cls)
  318. class_loc = f"{cls_fname}:{cls_fline}"
  319. return f"bound method '{obj.__name__}' at {fname}:{fline} of <{class_name} at {class_loc} object>"
  320. if isinstance(obj, types.FunctionType):
  321. return f"function '{obj.__name__}' at {fname}:{fline}"
  322. if isinstance(obj, ast.FunctionDef):
  323. return f"function '{obj.name}' at {fname}:{fline}"
  324. if isinstance(obj, ast.Attribute):
  325. return f"attribute "
  326. return str(obj)
  327. def expand_expr_statement(node):
  328. """
  329. Process the expr statement and expand it.
  330. Returns:
  331. tuple, (True, expr.value, x)/(False, None, None).
  332. """
  333. if isinstance(node, ast.Expr):
  334. expr_value = node.value
  335. if isinstance(expr_value, ast.Call):
  336. func = expr_value.func
  337. if isinstance(func, ast.Attribute) and \
  338. hasattr(func, "attr") and \
  339. hasattr(func, "value"):
  340. method = func.attr
  341. target = func.value
  342. if method in parse_expr_statement_white_list:
  343. logger.debug("Expand expr, target:%s, method:%s", target, method)
  344. return True, expr_value, target
  345. if not isinstance(expr_value, ast.Str):
  346. return True, expr_value
  347. return (False,)
  348. def get_ast_namespace_symbol(obj):
  349. """Get obj type and namespace and symbol."""
  350. # step 1:get symbol from object map
  351. ops_info = parse_object_map.get(type(obj), SYMBOL_UNDEFINE)
  352. logger.debug("ops info = %r", ops_info)
  353. return ops_info
  354. def get_operation_namespace_symbol(var: str):
  355. """Get operation namespace and symbol."""
  356. ops_info = (trope_ns, var)
  357. logger.debug("get operation ops info = %r", ops_info)
  358. return ops_info
  359. def get_ast_type(node):
  360. """Get the ast type."""
  361. ast_type = AST_SUB_TYPE_UNKNOWN
  362. if isinstance(node, ast.And):
  363. ast_type = AST_SUB_TYPE_AND
  364. elif isinstance(node, ast.Or):
  365. ast_type = AST_SUB_TYPE_OR
  366. elif isinstance(node, ast.Name):
  367. ast_type = AST_SUB_TYPE_NAME
  368. elif isinstance(node, ast.Tuple):
  369. ast_type = AST_SUB_TYPE_TUPLE
  370. elif isinstance(node, ast.Subscript):
  371. ast_type = AST_SUB_TYPE_SUBSCRIPT
  372. elif isinstance(node, ast.Starred):
  373. ast_type = AST_SUB_TYPE_STARRED
  374. elif isinstance(node, ast.Attribute):
  375. ast_type = AST_SUB_TYPE_ATTRIBUTE
  376. else:
  377. ast_type = AST_SUB_TYPE_UNKNOWN
  378. return ast_type
  379. def get_node_type(node):
  380. """Process an ast node."""
  381. method_name = f"{node.__class__.__name__}"
  382. node_type = [method_name]
  383. # judge the ast main type
  384. if isinstance(node, ast.stmt):
  385. node_type.append(AST_MAIN_TYPE_STMT)
  386. elif isinstance(node, (ast.expr, ast.slice)) or node is None:
  387. # ast.slice and ast.expr should be expr
  388. node_type.append(AST_MAIN_TYPE_EXPR)
  389. else:
  390. node_type.append(AST_MAIN_TYPE_UNKNOWN)
  391. return node_type
  392. def get_args_default_values(node):
  393. """get the args'default values of parse object."""
  394. nondefaults = [None] * (len(node.args.args) - len(node.args.defaults))
  395. defaults = nondefaults + node.args.defaults + node.args.kw_defaults
  396. if node.args.vararg:
  397. defaults.append(None)
  398. if node.args.kwarg:
  399. defaults.append(None)
  400. return defaults
  401. def get_args(node):
  402. """Get the arg of parse object."""
  403. args = []
  404. # process position args
  405. for arg in node.args.args:
  406. args.append(arg)
  407. # process kwonlyargs: kwonlyargs is append after position args
  408. if node.args.kwonlyargs:
  409. for kwarg in node.args.kwonlyargs:
  410. args.append(kwarg)
  411. # process vararg: vararg is append after kwonlyargs
  412. if node.args.vararg:
  413. args.append(node.args.vararg)
  414. # process kwarg: kwarg is append after vararg
  415. if node.args.kwarg:
  416. args.append(node.args.kwarg)
  417. return args
  418. def eval_script(exp_str, params):
  419. """Evaluate a python expression."""
  420. if not isinstance(params, tuple):
  421. raise ValueError(f"eval_script(), params is not a tuple, params: {params}")
  422. if len(params) != 2:
  423. raise ValueError(f"eval_script(), params tuple length is wrong, params: {params}")
  424. # Eval function parses the expression argument and evaluates it as a python expression.
  425. logger.debug(f"exp_str: '{exp_str}', params: '{params}'")
  426. global_params = params[0]
  427. local_params = params[1]
  428. try:
  429. obj = eval(exp_str, global_params, local_params)
  430. except Exception as e:
  431. error_info = f"When eval '{exp_str}' by using Fallback feature, an error occurred: " + str(e) + \
  432. ". You can try to turn off the Fallback feature by 'export DEV_ENV_ENABLE_FALLBACK=0'."
  433. logger.error(error_info)
  434. raise e
  435. # Check the result of eval.
  436. if obj is None:
  437. raise ValueError(f"When call 'eval', the result is none. exp_str: '{exp_str}'")
  438. # Convert set to tuple.
  439. if isinstance(obj, set):
  440. obj = tuple(obj)
  441. return obj
class Parser:
    """
    Parse the source code of a Python function or method into an AST (using asttokens).

    Args:
        fn (FunctionType/MethodType): The function or method whose source is parsed.
        parse_method (ExtendInfoOfParseObj): Extend information for parse the function; stored on the
            instance as `parse_method`.

    Attributes:
        ast_cache (dict): Class-level cache shared by all Parser instances. `parse` keys it by the SHA-256
            hex digest of the function's source text and stores `(asttokens.ASTTokens, col_offset)` tuples.
    """
    # NOTE(review): no code visible here removes entries from ast_cache, so it grows with every distinct
    # source parsed during the process lifetime — confirm this is acceptable.
    ast_cache = {}
  451. def __init__(self, fn: (types.FunctionType, types.MethodType), parse_method=None) -> None:
  452. self.fn = fn
  453. self.parse_method = parse_method
  454. self.line_offset = 0
  455. self.filename: str = inspect.getfile(inspect.unwrap(self.fn))
  456. # Used to resolve mindspore builtin ops namespace.
  457. self.ms_common_ns = CellNamespace('mindspore.common')
  458. self.ms_nn_ns = CellNamespace('mindspore.nn')
  459. self.ms_ops_ns = CellNamespace('mindspore.ops')
  460. self.ms_ops_c_ns = CellNamespace('mindspore.ops.composite')
  461. self.ms_ops_c_multitype_ns = CellNamespace('mindspore.ops.composite.multitype_ops')
  462. self.ms_ops_p_ns = CellNamespace('mindspore.ops.operations')
  463. # Used to resolve the function's globals namespace.
  464. self.global_namespace = CellNamespace(fn.__module__)
  465. self.function_module = fn.__module__
  466. # Used to resolve the function's nonlocals.
  467. self.closure_namespace = ClosureNamespace(inspect.unwrap(self.fn))
  468. self.function_name = fn.__name__
  469. self.col_offset = 0
  470. def parse(self):
  471. """Parse the function or method."""
  472. logger.debug("fn = %r", self.fn)
  473. if isinstance(self.fn, (types.FunctionType, types.MethodType)):
  474. try:
  475. lines, self.line_offset = inspect.getsourcelines(self.fn)
  476. except OSError as e:
  477. if e.__str__() == "could not get source code":
  478. raise OSError(f"Mindspore can not compile temporary source code in terminal. "
  479. f"Please write source code to a python file and run the file.")
  480. raise e
  481. original_src = ''.join(lines)
  482. hexstr = hashlib.sha256(original_src.encode()).hexdigest()
  483. ast_tokens_cache = Parser.ast_cache.get(hexstr)
  484. if not ast_tokens_cache:
  485. src = dedent(original_src)
  486. self.col_offset = \
  487. len(original_src.split('\n')[0]) - len(src.split('\n')[0])
  488. logger.debug("Get source = %s", src)
  489. try:
  490. ast_tokens = asttokens.ASTTokens(src, parse=True)
  491. except IndentationError as idt_err:
  492. idt_err.filename = self.filename
  493. idt_err.lineno = self.line_offset
  494. idt_err.msg = f"There are incorrect indentations in definition or comment of function: " \
  495. f"'{self.fn.__qualname__}'."
  496. raise idt_err
  497. ast_tokens_cache = (ast_tokens, self.col_offset)
  498. Parser.ast_cache[hexstr] = ast_tokens_cache
  499. else:
  500. self.col_offset = ast_tokens_cache[1]
  501. return ast_tokens_cache[0], ast_tokens_cache[0].tree
  502. logger.error("Fn type is invalid")
  503. return None, None
  504. def is_unsupported_namespace(self, value):
  505. unsupported = isinstance(value, _builtin_function_or_method_type) and value not in convert_object_map
  506. logger.debug(f"'{value}' unsupported: {unsupported}.")
  507. return unsupported
  508. def get_namespace_symbol(self, var: str):
  509. """Get symbol type and namespace and symbol."""
  510. if var in self.closure_namespace:
  511. logger.debug(f"Found '{var}' in closure_namespace {self.closure_namespace.__str__()}")
  512. return self.closure_namespace, var
  513. if var in self.global_namespace:
  514. logger.debug(f"Found '{var}' in global_namespace {self.global_namespace.__str__()}")
  515. value = self.global_namespace[var]
  516. if self.is_unsupported_namespace(value):
  517. error_info = f"The builtin function '{var}' of python is not supported in graph mode."
  518. return None, error_info
  519. return self.global_namespace, var
  520. error_info = f"The name '{var}' is not defined in function '{self.function_name}'."
  521. return None, error_info
  522. def is_unsupported_builtin_type(self, value_type):
  523. """To check if not supported builtin type"""
  524. unsupported_builtin_type = (list, tuple, set, dict, slice, bool, int, float, str, complex, reversed)
  525. is_unsupported = value_type in unsupported_builtin_type
  526. logger.debug(f"value_type: {value_type}, unsupported builtin type: {is_unsupported}.")
  527. return is_unsupported
  528. def is_supported_namespace_module(self, value):
  529. """To check if the module is allowed to support."""
  530. # Check `mindspore` namespace.
  531. if not hasattr(value, '__name__'):
  532. logger.debug(f"'{str(value)}' has no '__name__' attribute, we suppose it's supported.")
  533. return True
  534. name = value.__name__
  535. if name == 'mindspore':
  536. logger.debug(f"Found 'mindspore' root namespace.")
  537. return True
  538. if name == 'mindspore.ops':
  539. logger.debug(f"Found 'mindspore.ops' namespace.")
  540. return True
  541. if name == 'mindspore.nn':
  542. logger.debug(f"Found 'mindspore.nn' namespace.")
  543. return True
  544. if name == 'mindspore.numpy':
  545. logger.debug(f"Found 'mindspore.numpy' namespace.")
  546. return True
  547. # Check `Tensor` namespace.
  548. if value == Tensor:
  549. logger.debug(f"Not support '{name}'.")
  550. return False
  551. # Check `builtins` namespace.
  552. if hasattr(value, '__module__'): # Not types.ModuleType
  553. mod = value.__module__
  554. if mod == 'builtins':
  555. logger.debug(f"Found '{name}' in 'builtins' namespace.")
  556. return True
  557. # We suppose it's supported if not a Module.
  558. if not isinstance(value, types.ModuleType):
  559. logger.debug(f"Found '{name}', not a module.")
  560. return True
  561. # Check supported Module namespace.
  562. rightmost_name = name.split('.')[-1]
  563. if rightmost_name in self.ms_ops_ns:
  564. logger.debug(f"Found '{name}'({rightmost_name}) in ops namespace: {str(self.ms_ops_ns)}.")
  565. return True
  566. if rightmost_name in self.ms_ops_c_ns:
  567. logger.debug(f"Found '{name}'({rightmost_name}) in C namespace: {str(self.ms_ops_c_ns)}.")
  568. return True
  569. if rightmost_name in self.ms_ops_c_multitype_ns:
  570. logger.debug(
  571. f"Found '{name}'({rightmost_name}) in C.multitype namespace: {str(self.ms_ops_c_multitype_ns)}.")
  572. return True
  573. if rightmost_name in self.ms_ops_p_ns:
  574. logger.debug(f"Found '{name}'({rightmost_name}) in P namespace: {str(self.ms_ops_p_ns)}.")
  575. return True
  576. if rightmost_name in self.ms_common_ns:
  577. logger.debug(f"Found '{name}'({rightmost_name}) in common namespace: {str(self.ms_common_ns)}.")
  578. return True
  579. # Support nn.layer. To check if exclude other module.
  580. if rightmost_name in self.ms_nn_ns:
  581. logger.debug(f"Found '{name}'({rightmost_name}) in nn namespace: {str(self.ms_nn_ns)}.")
  582. return True
  583. if rightmost_name in trope_ns:
  584. logger.debug(f"Found '{name}'({rightmost_name}) in trope namespace: {str(trope_ns)}.")
  585. return True
  586. logger.info(f"Not found '{name}' in mindspore supported namespace.")
  587. return False
  588. def get_builtin_namespace_symbol(self, var: str):
  589. """Get mindspore builtin namespace and symbol."""
  590. if var in self.closure_namespace:
  591. logger.debug(f"Found '{var}' in closure_namespace {self.closure_namespace.__str__()}.")
  592. return self.closure_namespace, var
  593. if var in self.global_namespace:
  594. logger.debug(f"Found '{var}' in global_namespace {self.global_namespace.__str__()}.")
  595. value = self.global_namespace[var]
  596. value_str = value.__name__ if hasattr(value, '__name__') else str(value)
  597. logger.debug(f"value: {type(value)}, '{value_str}', hasattr(__name__): {hasattr(value, '__name__')}.")
  598. # To check if allowed to support.
  599. if self.is_unsupported_namespace(value):
  600. return self.global_namespace, var, value
  601. if self.is_unsupported_builtin_type(value):
  602. return self.global_namespace, var, value
  603. if not self.is_supported_namespace_module(value): # Check if support including instance of types.ModuleType
  604. return self.global_namespace, var, value
  605. supported = True
  606. return self.global_namespace, var, value, supported
  607. error_info = f"The name '{var}' is not defined, or not supported in graph mode."
  608. logger.debug(f"error_info: {error_info}")
  609. return None, error_info
  610. def analyze_super(self, class_type_node, subclass_instance):
  611. """Analyze super and return a class instance."""
  612. sub_class = type(subclass_instance)
  613. if class_type_node is None:
  614. return super(sub_class, subclass_instance)
  615. if isinstance(class_type_node, ast.Name):
  616. class_name = getattr(class_type_node, 'id')
  617. elif isinstance(class_type_node, ast.Attribute):
  618. class_name = getattr(class_type_node, 'attr')
  619. else:
  620. raise ValueError(f"The first argument of 'super()' must be a class type, "
  621. f"but got {class_type_node.__class__.__name__}.")
  622. target_father_class = None
  623. for class_element in sub_class.mro():
  624. if class_element.__name__ == class_name:
  625. target_father_class = class_element
  626. break
  627. if target_father_class is None:
  628. raise ValueError(f"The second argument of 'super()' must be 'self', "
  629. f"but got {subclass_instance}.")
  630. return super(target_father_class, subclass_instance)
  631. def get_location(self, node):
  632. """
  633. Get location of node start and end line no.
  634. Args:
  635. node: AST op node or tuple or List. This is a node in the ANF diagram,
  636. here is the code location to get this node.
  637. Returns:
  638. List, [fileName, linestart, colstart, lineend, colend].
  639. """
  640. ret = [self.filename]
  641. err_exit = 0
  642. if isinstance(node, (list, tuple)):
  643. node_size = len(node)
  644. if node_size == 0:
  645. err_exit = 1
  646. else:
  647. start_node = node[0]
  648. end_node = node[-1]
  649. else:
  650. start_node = node
  651. end_node = node
  652. if err_exit == 0:
  653. if hasattr(start_node, "lineno") and \
  654. hasattr(end_node, "col_offset"):
  655. start_lineno, start_colno = start_node.first_token.start
  656. end_lineno, end_colno = end_node.last_token.end
  657. start_lineno += self.line_offset - 1
  658. start_colno += self.col_offset
  659. end_lineno += self.line_offset - 1
  660. end_colno += self.col_offset
  661. ret = ret + [start_lineno, start_colno, end_lineno, end_colno]
  662. else:
  663. ret = ret + [0, 0, 0, 0]
  664. return ret