|
- # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
- #
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """The module of parser python object, called by c++."""
-
- import ast
- import types
- import inspect
- import hashlib
- from textwrap import dedent
- from dataclasses import is_dataclass
- import asttokens
- import mindspore.nn as nn
- from mindspore import log as logger
- from mindspore import Tensor as MsTensor
- from mindspore import ops
- from mindspore.common.dtype import pytype_to_dtype
- from mindspore.common.api import _MindSporeFunction
- from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
- from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
-
# Return codes reported back to the caller.
RET_SUCCESS = 0
RET_FAILURE = 0xFF

# What kind of Python object a resolved symbol turned out to be.
RESOLVE_TYPE_NONE = 0  # the value None
RESOLVE_TYPE_FUNCTION = 1  # a plain function
RESOLVE_TYPE_METHOD = 2  # a bound class method
RESOLVE_TYPE_CLASS_TYPE = 3  # a class object itself
RESOLVE_TYPE_CLASS_INSTANCE = 4  # an instance of a supported class
RESOLVE_TYPE_INVALID = 0xFF

# Finer classification used only when the resolve type is RESOLVE_TYPE_CLASS_INSTANCE.
CLASS_INSTANCE_TYPE_CELL = 0  # instance of Cell
CLASS_INSTANCE_TYPE_PRIMITIVE = 1  # instance of Primitive
CLASS_INSTANCE_TYPE_INVALID = 0xFF

# Top-level category of an AST node.
AST_MAIN_TYPE_STMT = 0  # statement (ast.stmt)
AST_MAIN_TYPE_EXPR = 1  # expression (ast.expr)
AST_MAIN_TYPE_SLICE = 2  # slice (ast.slice)
AST_MAIN_TYPE_UNKNOWN = 0xFF  # anything else

# Specific AST node kinds that need dedicated handling.
AST_SUB_TYPE_AND = 3  # ast.And
AST_SUB_TYPE_OR = 4  # ast.Or
AST_SUB_TYPE_NAME = 5  # ast.Name
AST_SUB_TYPE_TUPLE = 6  # ast.Tuple
AST_SUB_TYPE_SUBSCRIPT = 7  # ast.Subscript
AST_SUB_TYPE_STARRED = 8  # ast.Starred
AST_SUB_TYPE_UNKNOWN = 0xFF  # anything else

# Method names whose call-statements (e.g. `x.append(y)`) are expanded into an
# assignment to the call target. Extend when needed, e.g. with "clear", "extend",
# "insert", "remove", "reverse".
parse_expr_statement_white_list = ("append",)
-
-
def create_slice_obj(start, end, step):
    """Build and return the Python object ``slice(start, end, step)``."""
    slice_obj = slice(start, end, step)
    return slice_obj
-
-
def parse_cb(func, parse_method=None):
    """Create a Parser for `func`; `parse_method` is forwarded to it unchanged."""
    parser = Parser(func, parse_method)
    return parser
-
-
def get_parse_method_of_class(obj, parse_method=None):
    """
    Get parse method of class.

    Args:
        obj (Object): Instance of class.
        parse_method (str): Name of the method to parse. When None and `obj` is a Cell,
            'construct' is used ('_hook_construct' if the cell has `enable_hook` set).

    Returns:
        Function, obj's method; None if no method name applies to `obj` or `obj` has no
        attribute with that name.
    """
    if parse_method is not None:
        method_name = parse_method
    elif isinstance(obj, nn.Cell):
        # Cells with hooks enabled are parsed through the hook-aware entry point.
        method_name = "_hook_construct" if obj.enable_hook else "construct"
    else:
        return None
    # Single lookup; equivalent to hasattr() followed by getattr().
    return getattr(obj, method_name, None)
-
-
def get_bprop_method_of_class(obj, parse_method=None):
    """
    Get bprop method of class.

    Args:
        obj (Object): Instance of class.
        parse_method (str): Unused; accepted for interface compatibility. The method
            name looked up is always 'bprop'.

    Returns:
        Function, obj's 'bprop' attribute if `obj` is a Cell that has one, otherwise None.
    """
    if not isinstance(obj, nn.Cell):
        return None
    return getattr(obj, "bprop", None)
-
-
def resolve_symbol(namespace, symbol):
    """
    Resolve a symbol.

    Note:
        Can't get function when use closure function. So save the fn on namespace.

    Args:
        namespace (Object): Symbol's namespace.
        symbol (str): Need resolve symbol.

    Returns:
        Object, resolve result of symbol. None if looking up or converting the symbol
        raised any exception other than NotImplementedError. A `_MindSporeFunction`
        result is replaced by its wrapped `fn`.

    Raises:
        NotImplementedError: If the symbol converts to NO_IMPLEMENT via convert_object_map.
    """
    # All exceptions except NotImplementedError are caught in this function.
    try:
        resolve_ = namespace[symbol]

        # list and dict are not hashable, so they cannot be keys of the map; return as is.
        if isinstance(resolve_, (list, dict)):
            return resolve_

        # dataclass may not be hashable
        if getattr(resolve_, "__hash__") is None:
            return resolve_

        # If need trope the obj
        if resolve_ in convert_object_map:
            resolve_ = convert_object_map.get(resolve_)
            logger.debug("convert resolve = %r", resolve_)
            if resolve_ == NO_IMPLEMENT:
                # A single formatted message (passing two args made str(e) a tuple repr).
                raise NotImplementedError(f"not implemented for {symbol}")
    except NotImplementedError:
        raise
    except Exception as e:  # pylint: disable=broad-except
        resolve_ = None
        logger.debug("resolve exception occurred, value = %r", e)
        logger.debug("resolve type is invalid, namespace = %s, symbol = %s",
                     str(namespace), symbol)
    if isinstance(resolve_, _MindSporeFunction):
        logger.debug("resolve class _MindSporeFunction, resolve fn instead.")
        resolve_ = resolve_.fn
    return resolve_
-
-
def generate_scope(obj):
    """Call `obj.generate_scope()` when `obj` is a Cell; any other object is ignored."""
    if not isinstance(obj, nn.Cell):
        return
    obj.generate_scope()
-
-
def get_scope_name(obj):
    """Return `obj.get_scope()` for a Cell, or None for any other object."""
    return obj.get_scope() if isinstance(obj, nn.Cell) else None
-
-
def get_object_key(obj):
    """
    Build the identification strings of `obj`.

    Returns:
        tuple, (obj_id, obj_key).
        obj_id is "<class name><__name__ if any>_ID<id(obj)>". For a bound method it is
        prefixed with "<owner class name>_ID<id(owner)>" and suffixed with the
        method's hash.
        obj_key is "<class name><__name__ if any><cell_init_args>_ID" when `obj` has
        `cell_init_args`, otherwise an empty string.
    """
    name_prefix = str(obj.__class__.__name__)
    if hasattr(obj, "__name__"):
        name_prefix += str(obj.__name__)
    obj_key = f"{name_prefix + obj.cell_init_args}_ID" if hasattr(obj, "cell_init_args") else ""
    obj_id = f"{name_prefix}_ID{id(obj)}"
    logger.debug("obj_key %s obj_id = %s", obj_key, obj_id)

    # The id of a bound method does not tell instances apart, so qualify it with the owner.
    if isinstance(obj, types.MethodType):
        owner = obj.__self__
        owner_id = f"{owner.__class__.__name__}_ID{id(owner)}"
        obj_id = owner_id + obj_id + str(obj.__hash__())
    return obj_id, obj_key
-
-
def is_class_member(node):
    """Return True if `node` is an attribute access on the name `self` (e.g. `self.x`)."""
    if node.__class__.__name__ != "Attribute":
        return False
    owner_name = getattr(node.value, "id", None)
    return owner_name == "self"
-
-
def get_obj_id(obj):
    """Return the decimal text of `id(obj)`."""
    return f"{id(obj)}"
-
-
def get_obj_type(obj):
    """
    Classify `obj` as one of the RESOLVE_TYPE_* codes.

    Returns:
        int, RESOLVE_TYPE_NONE / _FUNCTION / _METHOD / _CLASS_TYPE / _CLASS_INSTANCE.

    Raises:
        TypeError: If `obj` fits none of those categories.
    """
    if obj is None:
        return RESOLVE_TYPE_NONE
    if isinstance(obj, types.FunctionType):
        return RESOLVE_TYPE_FUNCTION
    if isinstance(obj, types.MethodType):
        return RESOLVE_TYPE_METHOD
    if isinstance(obj, type):
        return RESOLVE_TYPE_CLASS_TYPE
    if _is_class_instance(obj):
        return RESOLVE_TYPE_CLASS_INSTANCE
    # For an ndarray report only its shape, so a large array does not flood the message.
    is_ndarray = type(obj).__name__ == 'ndarray' and hasattr(obj, 'shape')
    detail_kind = "shape" if is_ndarray else "value"
    detail = obj.shape if is_ndarray else obj
    raise TypeError(f'Invalid object with type `{type(obj)}` and {detail_kind} `{detail}`.')
-
-
def get_class_instance_type(obj):
    """Return the CLASS_INSTANCE_TYPE_* code of `obj`, or CLASS_INSTANCE_TYPE_INVALID."""
    logger.debug("Get the class type(%r)", obj)
    if not _is_class_instance(obj):
        return CLASS_INSTANCE_TYPE_INVALID
    if isinstance(obj, nn.Cell):
        return CLASS_INSTANCE_TYPE_CELL
    if isinstance(obj, ops.Primitive):
        return CLASS_INSTANCE_TYPE_PRIMITIVE
    # Other accepted instances (e.g. dataclass instances) have no dedicated code yet.
    return CLASS_INSTANCE_TYPE_INVALID
-
-
def _is_class_instance(obj):
    """Return True if `obj` is a Cell, a Primitive, or a dataclass instance."""
    if isinstance(obj, (nn.Cell, ops.Primitive)):
        return True
    return _is_dataclass_instance(obj)
-
-
- def _is_dataclass_instance(obj):
- """check whether a class is an instance of a dataclass (and not a dataclass itself)"""
- return is_dataclass(obj) and not isinstance(obj, type)
-
-
def create_obj_instance(cls_type, args_tuple=None):
    """
    Instantiate `cls_type` when it is a subclass of nn.Cell or Primitive.

    Args:
        cls_type: The class to instantiate.
        args_tuple (tuple): Positional constructor arguments; None means no arguments.

    Returns:
        The new instance, or None if `cls_type` is not a class or not a supported subclass.
    """
    # Only nn.Cell and Primitive subclasses are supported for now.
    if not isinstance(cls_type, type) or not issubclass(cls_type, (nn.Cell, ops.Primitive)):
        return None
    ctor_args = () if args_tuple is None else args_tuple
    return cls_type(*ctor_args)
-
-
def get_module_namespace(obj):
    """Return a CellNamespace for module `obj`, or None (with a warning) for a non-module."""
    logger.debug("get module namespace, module = %r", obj)
    if not isinstance(obj, types.ModuleType):
        logger.warning("Module(%r) is invalid, get namespace failure!", obj)
        return None
    return CellNamespace(obj.__name__)
-
-
def get_class_member_namespace_symbol(obj):
    """
    Get the class member namespace of an object.

    Args:
        obj (Object): Class instance whose members are to be resolved.

    Returns:
        ClassMemberNamespace, the namespace built from `obj`.
    """
    logger.debug("get class instance namespace, object = %r", obj)
    class_namespace = ClassMemberNamespace(obj)
    logger.debug("class namespace = %r", class_namespace)
    return class_namespace
-
-
def get_dataclass_attributes(cls):
    """
    Map each entry of `cls.__dataclass_fields__` to the dtype of its annotated type.

    Returns:
        dict, field name -> `pytype_to_dtype(field.type)`, in field declaration order.
    """
    attributes = {}
    for field_name, field_info in cls.__dataclass_fields__.items():
        attributes[field_name] = pytype_to_dtype(field_info.type)
    return attributes
-
-
def get_dataclass_methods(cls):
    """
    Get functions of dataclass.

    Args:
        cls (type): The dataclass type to inspect.

    Returns:
        dict, maps every name in `dir(cls)` whose class attribute is a plain Python
        function (`types.FunctionType`) to that function, in `dir` order.
    """
    methods = {}
    for name in dir(cls):
        # Fetch the attribute once instead of once for the test and once for the value.
        attr = getattr(cls, name)
        if isinstance(attr, types.FunctionType):
            methods[name] = attr
    return methods
-
-
def convert_to_ms_tensor(data):
    """Wrap `data` (a tensor handed over from C++) into a MindSpore Tensor."""
    ms_tensor = MsTensor(data)
    return ms_tensor
-
-
def get_object_description(obj, fname, fline):
    """
    Return a method or function description for error reports (location, class name, etc.).

    Args:
        obj: A bound method, a Python function, an `ast.FunctionDef`, or any other object.
        fname (str): File name to report.
        fline (int): Line number to report.

    Returns:
        str, the description; `str(obj)` for objects of any other kind.
    """
    if isinstance(obj, types.MethodType):
        obj_cls = obj.__self__.__class__
        class_name = f'{obj_cls.__module__}.{obj_cls.__qualname__}'
        cls_fname = inspect.getfile(obj_cls)
        _, cls_fline = inspect.getsourcelines(obj_cls)
        class_loc = f'{cls_fname}:{cls_fline}'
        return f"bound method '{obj.__name__}' at {fname}:{fline} of <{class_name} at {class_loc} object>"
    if isinstance(obj, (types.FunctionType, ast.FunctionDef)):
        # ast.FunctionDef stores its name in `.name`; Python functions only have `__name__`.
        func_name = obj.name if isinstance(obj, ast.FunctionDef) else obj.__name__
        return f"function '{func_name}' at {fname}:{fline}"
    return str(obj)
-
-
class Parser:
    """
    Parse Python code into an AST tree.

    Args:
        fn (FunctionType/MethodType): The function or method instance to parse.
        parse_method (ExtendInfoOfParseObj): Extend information for parsing the function.

    Attributes:
        ast_cache (dict): Class-level cache shared by all instances, mapping the SHA-256
            hex digest of a function's source text to its parsed tree.
    """
    ast_cache = {}

    def __init__(self, fn: (types.FunctionType, types.MethodType), parse_method=None) -> None:
        self.fn = fn
        self.parse_method = parse_method
        # Start line of fn's source in its file; set by parse(), used by get_location().
        self.line_offset = 0
        self.filename: str = inspect.getfile(self.fn)

        # Used to resolve the function's globals Namespace.
        self.global_namespace = CellNamespace(fn.__module__)
        self.function_module = fn.__module__
        # Used to resolve the function's nonlocals.
        self.closure_namespace = ClosureNamespace(fn)
        self.function_name = fn.__name__
        # Indentation width removed by dedent in parse(); used by get_location().
        self.col_offset = 0

    def parse(self):
        """
        Parse the function or method.

        Sets ``self.line_offset`` and ``self.col_offset`` as a side effect.

        Returns:
            The asttokens-annotated AST of ``self.fn``'s dedented source, or None if
            ``self.fn`` is neither a function nor a method.
        """
        logger.debug("fn = %r", self.fn)
        if not isinstance(self.fn, (types.FunctionType, types.MethodType)):
            logger.error("Fn type is invalid")
            return None

        lines, self.line_offset = inspect.getsourcelines(self.fn)
        original_src = ''.join(lines)
        src = dedent(original_src)
        # Always record the dedent width. It is per-instance state that get_location()
        # relies on, so it must be set even when the tree comes from the cache.
        self.col_offset = \
            len(original_src.split('\n')[0]) - len(src.split('\n')[0])
        hexstr = hashlib.sha256(original_src.encode()).hexdigest()
        tree = Parser.ast_cache.get(hexstr)
        if tree is None:
            logger.debug("get source = %s", src)
            tree = asttokens.ASTTokens(src, parse=True).tree
            Parser.ast_cache[hexstr] = tree
        return tree

    def get_args(self, node):
        """
        Get the argument nodes of a function definition.

        Args:
            node (ast.FunctionDef): The function definition node.

        Returns:
            list, positional args, then keyword-only args, then ``*args``, then ``**kwargs``.
        """
        args = list(node.args.args)
        # kwonlyargs are appended after positional args
        args.extend(node.args.kwonlyargs)
        # vararg is appended after kwonlyargs
        if node.args.vararg:
            args.append(node.args.vararg)
        # kwarg is appended after vararg
        if node.args.kwarg:
            args.append(node.args.kwarg)
        return args

    def get_args_default_values(self, node):
        """
        Get the default values of the arguments, aligned with the order of get_args().

        One entry per positional arg (None when it has no default), then the keyword-only
        defaults (None when absent), then None for ``*args`` and for ``**kwargs`` if present.

        Args:
            node (ast.FunctionDef): The function definition node.

        Returns:
            list, default value nodes or None.
        """
        nondefaults = [None] * (len(node.args.args) - len(node.args.defaults))
        defaults = nondefaults + node.args.defaults + node.args.kw_defaults
        if node.args.vararg:
            defaults.append(None)
        if node.args.kwarg:
            defaults.append(None)
        return defaults

    def get_node_type(self, node):
        """
        Get the type information of an ast node.

        Returns:
            list, [class name of `node`, AST main type]. The main type is
            AST_MAIN_TYPE_STMT, AST_MAIN_TYPE_EXPR (also used for slices and None) or
            AST_MAIN_TYPE_UNKNOWN.
        """
        node_type = [node.__class__.__name__]
        # judge the ast main type
        if isinstance(node, ast.stmt):
            node_type.append(AST_MAIN_TYPE_STMT)
        elif isinstance(node, (ast.expr, ast.slice)) or node is None:
            # ast.slice and ast.expr should be expr
            node_type.append(AST_MAIN_TYPE_EXPR)
        else:
            node_type.append(AST_MAIN_TYPE_UNKNOWN)
        return node_type

    def get_ast_type(self, node):
        """Get the AST_SUB_TYPE_* code of `node`, or AST_SUB_TYPE_UNKNOWN."""
        if isinstance(node, ast.And):
            return AST_SUB_TYPE_AND
        if isinstance(node, ast.Or):
            return AST_SUB_TYPE_OR
        if isinstance(node, ast.Name):
            return AST_SUB_TYPE_NAME
        if isinstance(node, ast.Tuple):
            return AST_SUB_TYPE_TUPLE
        if isinstance(node, ast.Subscript):
            return AST_SUB_TYPE_SUBSCRIPT
        if isinstance(node, ast.Starred):
            return AST_SUB_TYPE_STARRED
        return AST_SUB_TYPE_UNKNOWN

    def get_namespace_symbol(self, var: str):
        """
        Get the namespace that defines `var`, together with the symbol.

        The closure namespace is checked before the global namespace. If neither contains
        `var`, the namespace registered for SYMBOL_UNDEFINE in parse_object_map is used.

        Returns:
            tuple, (namespace, var).
        """
        if var in self.closure_namespace:
            logger.debug("in closure_namespace")
            return self.closure_namespace, var
        if var in self.global_namespace:
            logger.debug("in global_namespace")
            return self.global_namespace, var
        undefined_info = parse_object_map.get(SYMBOL_UNDEFINE)
        # Return a tuple, like the other branches.
        return undefined_info[0], var

    def get_operation_namespace_symbol(self, var: str):
        """Get operation namespace and symbol."""
        ops_info = (trope_ns, var)
        logger.debug("get operation ops info = %r", ops_info)
        return ops_info

    def get_ast_namespace_symbol(self, obj):
        """Get the parse_object_map entry for type(obj), or SYMBOL_UNDEFINE if none."""
        # step 1:get symbol from object map
        ops_info = parse_object_map.get(type(obj), SYMBOL_UNDEFINE)
        logger.debug("ops info = %r", ops_info)
        return ops_info

    def analyze_super(self, class_type_node, subclass_instance):
        """
        Analyze super and return a class instance.

        Args:
            class_type_node: None for a bare ``super()``; otherwise the ast.Name or
                ast.Attribute node of the first ``super`` argument.
            subclass_instance: The instance ``super`` is called on.

        Returns:
            The ``super`` proxy for the named class (matched by ``__name__`` in the MRO of
            ``type(subclass_instance)``), or ``super(type(subclass_instance), ...)`` when
            `class_type_node` is None.

        Raises:
            ValueError: If the first argument is not a name/attribute node, or no class
                with that name is in the MRO.
        """
        sub_class = type(subclass_instance)
        if class_type_node is None:
            return super(sub_class, subclass_instance)
        if isinstance(class_type_node, ast.Name):
            class_name = getattr(class_type_node, 'id')
        elif isinstance(class_type_node, ast.Attribute):
            class_name = getattr(class_type_node, 'attr')
        else:
            raise ValueError(f"When call 'super', the first arg should be a class type, "
                             f"but got {class_type_node.__class__.__name__}.")

        target_father_class = None
        for class_element in sub_class.mro():
            if class_element.__name__ == class_name:
                target_father_class = class_element
                break
        if target_father_class is None:
            raise ValueError("When call 'super', the second arg should be an instance of first arg.")
        return super(target_father_class, subclass_instance)

    def get_location(self, node):
        """
        Get location of node start and end line no.

        Args:
            node: AST op node or tuple or list. This is a node in the ANF diagram,
                here is the code location to get this node.

        Returns:
            List, [fileName, linestart, colstart, lineend, colend]. The four positions are
            0 when `node` is an empty sequence or lacks location information.
        """
        ret = [self.filename]
        if isinstance(node, (list, tuple)):
            if not node:
                # Keep the documented 5-element shape for an empty sequence.
                return ret + [0, 0, 0, 0]
            start_node, end_node = node[0], node[-1]
        else:
            start_node = end_node = node

        if not (hasattr(start_node, "lineno") and hasattr(end_node, "col_offset")):
            return ret + [0, 0, 0, 0]
        # first_token/last_token are attached by asttokens in parse(); shift the positions
        # from the dedented snippet back to the original file.
        start_lineno, start_colno = start_node.first_token.start
        end_lineno, end_colno = end_node.last_token.end
        start_lineno += self.line_offset - 1
        start_colno += self.col_offset
        end_lineno += self.line_offset - 1
        end_colno += self.col_offset
        return ret + [start_lineno, start_colno, end_lineno, end_colno]

    def expand_expr_statement(self, node):
        """
        Process the expr statement and expand it.

        Returns:
            tuple, one of:
                (True, expr.value, target) for a call of a white-listed method on a
                    target, e.g. ``x.append(y)``;
                (True, expr.value) for any other expression statement;
                (False, None, None) if `node` is not an expression statement.
        """
        if not (isinstance(node, ast.Expr) and hasattr(node, "value")):
            return False, None, None
        expr_value = node.value
        if isinstance(expr_value, ast.Call):
            func = expr_value.func
            if isinstance(func, ast.Attribute) and \
                    hasattr(func, "attr") and \
                    hasattr(func, "value"):
                method = func.attr
                target = func.value
                if method in parse_expr_statement_white_list:
                    logger.debug("Expand expr, target:%s, method:%s", target, method)
                    return True, expr_value, target
        return True, expr_value
|