diff --git a/mindinsight/mindconverter/ast_edits.py b/mindinsight/mindconverter/ast_edits.py index dc5a3913..74d08466 100644 --- a/mindinsight/mindconverter/ast_edits.py +++ b/mindinsight/mindconverter/ast_edits.py @@ -28,13 +28,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING from mindinsight.mindconverter.config import NN_LIST from mindinsight.mindconverter.config import ALL_TORCH_APIS from mindinsight.mindconverter.config import ALL_2P_LIST -from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS -from mindinsight.mindconverter.config import ALL_UNSUPPORTED +from mindinsight.mindconverter.config import get_prompt_info from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport from mindinsight.mindconverter.forward_call import ForwardCall LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'." +LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s" LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s" LOG_FMT_PROMPT_INFO = "[INFO] %s" LOG_SUGGESTION_MANUAL_CONVERT = "Please manual convert the code, along with the code associated with it." @@ -95,6 +95,7 @@ class _LineColEditVisitor(ast.NodeVisitor): class _NodeInfo: """NodeInfo class definition.""" + def __init__(self, node): self.node = node self.call_list = [] # Used to save all ast.Call node in self._node @@ -444,19 +445,25 @@ class AstEditVisitor(ast.NodeVisitor): is_include_call = False return is_include_call - def match_api(self, call_func_node, is_forward): + def match_api(self, call_func_node, is_forward, check_context=True): """ Check api name to convert, check api name ok with a is_forward condition. Args: call_func_node (ast.Attribute): The call.func node. is_forward (bool): whether api belong to forward. + check_context (boolean): If True, the code context will be checked. Default is True. Returns: str, the standard api name used to match. ApiMappingEnum, the match result. """ - api_name, match_case = self._infer_api_name(call_func_node) + match_case = ApiMatchingEnum.NOT_API + api_call_name = pasta.dump(call_func_node) + if api_call_name.startswith('self.'): + return api_call_name, match_case + + api_name, match_case = self._infer_api_name(call_func_node, check_context) api_call_name = pasta.dump(call_func_node) is_tensor_obj_call = False if api_name != api_call_name: @@ -466,15 +473,17 @@ class AstEditVisitor(ast.NodeVisitor): # rewritten external module name # e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script. - if not is_tensor_obj_call and not self._code_analyzer.is_standard_external_ref: - standard_api_call_name = self._mapping_standard_api_name(api_name) + if not is_tensor_obj_call: + standard_api_call_name = self._get_api_whole_name(call_func_node, check_context) if standard_api_call_name in ALL_TORCH_APIS: match_case = ApiMatchingEnum.API_FOUND if (not is_forward and standard_api_call_name in NN_LIST) or \ (is_forward and standard_api_call_name in ALL_2P_LIST): match_case = ApiMatchingEnum.API_MATCHED - + else: + if standard_api_call_name and standard_api_call_name.startswith('torch.nn.init'): + match_case = ApiMatchingEnum.API_MATCHED return standard_api_call_name, match_case @staticmethod @@ -502,6 +511,25 @@ class AstEditVisitor(ast.NodeVisitor): parameters_str = call_str[left_parenthesis_pos + 1:right_parenthesis_pos] return parameters_str + def _get_api_whole_name(self, call_func_node, check_context=True): + """ + Get the whole name for the call node. + + Args: + call_func_node (AST): The func attribute of ast.Call. + check_context (boolean): If True, the code context will be checked. Default is True. + + Returns: + str, the whole name. + """ + api_name, match_case = self._infer_api_name(call_func_node, check_context) + if match_case == ApiMatchingEnum.API_STANDARD: + api_name_splits = api_name.split('.') + api_name_splits[0] = self._get_external_ref_whole_name(api_name_splits[0]) + if api_name_splits[0]: + api_name = '.'.join(api_name_splits) + return api_name + def mapping_api(self, call_node, check_context=True): """ Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert. @@ -522,6 +550,26 @@ class AstEditVisitor(ast.NodeVisitor): if api_call_name.startswith('self.'): return code + new_code = self._mapping_api(call_node, check_context) + + return new_code + + def _mapping_api(self, call_node, check_context=True): + """ + Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert. + + If do not check context of the script, the code represented by the node must be written in the standard way. + + Args: + call_node (ast.Call): The ast node to convert. + check_context (boolean): If True, the code context will be checked. Default is True. + + Returns: + str, the converted code. + """ + code = pasta.dump(call_node) + api_call_name = pasta.dump(call_node.func) + # find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)" args_str = '(' + self._get_call_parameters_str(call_node) + ')' @@ -551,26 +599,37 @@ class AstEditVisitor(ast.NodeVisitor): code = pasta.dump(node) api_name = pasta.dump(node.func) - # parent node first call is equal to this node, skip when parent node is replaced. - for parent_node in self._stack[:-1]: + # The parent node first call is equal to this node, skip when parent node is replaced. + # This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to + # P.Reshape()(out, (out.size(0). -1)), will skip P.Reshape() in following visiting. + # Access from the penultimate element in reverse order. + for parent_node in self._stack[-2::-1]: if parent_node in self._new_call_nodes and pasta.dump(parent_node).startswith(api_name): return parent = self._stack[-2] new_node = None + new_code = code matched_api_name, match_case = self.match_api(node.func, self._is_forward_function) if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]: + warning_info = get_prompt_info(matched_api_name) + if warning_info is None: + warning_info = '' if matched_api_name in ALL_MAPPING: logger.info("Line %3d start converting API: %s", node.lineno, api_name) new_code = self.mapping_api(node) if new_code != code: - new_node = pasta.parse(new_code).body[0].value - # find the first call name - new_api_name = new_code[:new_code.find('(')] - self._process_log.info(node.lineno, node.col_offset, LOG_FMT_CONVERT % (api_name, new_api_name)) - if matched_api_name in ALL_UNSUPPORTED: - warn_info = UNSUPPORTED_WARN_INFOS.get(api_name, '') - logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warn_info) - self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warn_info)) + try: + new_node = pasta.parse(new_code).body[0].value + # find the first call name + new_api_name = new_code[:new_code.find('(')] + except AttributeError: + new_node = pasta.parse(new_code).body[0] + new_api_name = new_code + self._process_log.info(node.lineno, node.col_offset, + LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info)) + else: + logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info) + self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info)) elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]: self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, '')) @@ -602,3 +661,57 @@ class AstEditVisitor(ast.NodeVisitor): elif ref_name != 'F' and external_ref_info.name == 'torch.nn.functional': renames[ref_name] = 'F' return renames + + def _get_external_ref_whole_name(self, ref_name): + """ + Find out external reference whole name. + + For example: + In the parsed source code, there is import statement + import torch.nn as new_name + _get_external_ref_whole_name('new_name') will return 'torch.nn' string. + """ + external_refs = self._code_analyzer.external_references + for external_ref_name, ref_info in external_refs.items(): + external_ref_info = ref_info['external_ref_info'] + if external_ref_name == ref_name: + return external_ref_info.name + return None + + def _check_isinstance_parameter(self, node): + """Check whether the second parameter of isinstance function contains the torch type.""" + is_isinstance_arg = False + # Check whether node is the second parameter of the isinstance function call. + # Access from the penultimate element in reverse order. + for parent_node in self._stack[-2::-1]: + if isinstance(parent_node, ast.Call) and pasta.dump(parent_node.func) == 'isinstance': + isinstance_node = parent_node + seconde_arg_type_nodes = [] + if isinstance(isinstance_node.args[1], ast.Tuple): + seconde_arg_type_nodes.extend(isinstance_node.args[1].elts) + else: + seconde_arg_type_nodes.append(isinstance_node.args[1]) + if node in seconde_arg_type_nodes: + is_isinstance_arg = True + break + if not is_isinstance_arg: + return False + + isinstance_type_arg = pasta.dump(node) + check_torch_type = False + if isinstance_type_arg: + type_splits = isinstance_type_arg.split('.') + whole_name = self._get_external_ref_whole_name(type_splits[0]) + if whole_name and whole_name.startswith('torch'): + check_torch_type = True + if check_torch_type: + _, match_case = self.match_api(node, False) + if match_case != ApiMatchingEnum.NOT_API: + warn_info = 'Manually determine the conversion type.' + self._process_log.warning(node.lineno, node.col_offset, + LOG_FMT_NOT_CONVERT % (isinstance_type_arg, warn_info)) + return check_torch_type + + def visit_Attribute(self, node): + """Callback function when visit AST tree""" + self._check_isinstance_parameter(node) diff --git a/mindinsight/mindconverter/common/exceptions.py b/mindinsight/mindconverter/common/exceptions.py index 494b2dce..69407961 100644 --- a/mindinsight/mindconverter/common/exceptions.py +++ b/mindinsight/mindconverter/common/exceptions.py @@ -24,6 +24,7 @@ class ConverterErrors(ScriptConverterErrors): """Converter error codes.""" SCRIPT_NOT_SUPPORT = 1 NODE_TYPE_NOT_SUPPORT = 2 + CODE_SYNTAX_ERROR = 3 class ScriptNotSupport(MindInsightException): @@ -42,3 +43,12 @@ class NodeTypeNotSupport(MindInsightException): super(NodeTypeNotSupport, self).__init__(ConverterErrors.NODE_TYPE_NOT_SUPPORT, msg, http_code=400) + + +class CodeSyntaxError(MindInsightException): + """The CodeSyntaxError class definition.""" + + def __init__(self, msg): + super(CodeSyntaxError, self).__init__(ConverterErrors.CODE_SYNTAX_ERROR, + msg, + http_code=400) diff --git a/mindinsight/mindconverter/config.py b/mindinsight/mindconverter/config.py index 98c0a102..807561b4 100644 --- a/mindinsight/mindconverter/config.py +++ b/mindinsight/mindconverter/config.py @@ -22,7 +22,7 @@ import os import pasta from mindinsight.mindconverter.common.log import logger - +from mindinsight.mindconverter.common.exceptions import CodeSyntaxError REQUIRED = 'REQUIRED' UNREQUIRED = 'UNREQUIRED' @@ -31,6 +31,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs' class APIPt: """Base API for args parse, and API for one frame.""" + def __init__(self, name: str, params: OrderedDict): self.name = name self.params = OrderedDict() @@ -77,10 +78,8 @@ class APIPt: try: ast_node = ast.parse("whatever_call_name" + args_str) call_node = ast_node.body[0].value - if not isinstance(call_node, ast.Call): - raise ValueError('call name with args str [{}] not instance of ast.Call'.format(args_str)) - except: - raise ValueError("can't parse code:\n{}".format(args_str)) + except SyntaxError as parse_error: + raise CodeSyntaxError("can't parse code:\n{}".format(args_str)) from parse_error # regard all actual parameter as one parameter if len(self.params) == 1: @@ -118,6 +117,7 @@ class APIPt: class APIMs(APIPt): """API for MindSpore""" + def __init__(self, name: str, params: OrderedDict, p_attrs=None): self.is_primitive = name.startswith('P.') if self.is_primitive: @@ -167,6 +167,7 @@ class APIMs(APIPt): class MappingHelper: """Mapping from one frame to another frame""" + def __init__(self, ms_api: APIMs, pt_api: APIPt, **kwargs): ms2pt_mapping = kwargs.get('ms2pt_mapping') gen_explicit_map = kwargs.get('gen_explicit_map') @@ -392,7 +393,6 @@ TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH) ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING} - # ---------------------------- api list support or not support ---------------------------- NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'nn_list.json')) NN_LIST = load_json_file(NN_LIST_PATH) @@ -400,7 +400,6 @@ NN_LIST += ["torch." + name for name in NN_LIST] NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING] NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING] - F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'f_list.json')) F_LIST = load_json_file(F_LIST_PATH) F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \ @@ -408,29 +407,23 @@ F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \ F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING] F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING] - TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'torch_dot_list.json')) TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH) - TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING] TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING] - TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'tensor_dot_list.json')) TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH) - TENSOR_DOT_SUPPORTED = [x for x in TENSOR_DOT_LIST if x in ALL_MAPPING] TENSOR_DOT_UNSUPPORTED = [x for x in TENSOR_DOT_LIST if x not in ALL_MAPPING] - ALL_2P_LIST = F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST ALL_TORCH_APIS = NN_LIST + F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST ALL_SUPPORTED = NN_SUPPORTED + F_SUPPORTED + TORCH_DOT_SUPPORTED + TENSOR_DOT_SUPPORTED ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSOR_DOT_UNSUPPORTED - UNSUPPORTED_WARN_INFOS = { "nn.AdaptiveAvgPool2d": "Maybe could convert to mindspore.ops.operations.ReduceMean.", "nn.AvgPool1d": "Maybe could convert to mindspore.nn.AvgPool1d.",