From dde197ceab43db73b8592b050e7b36a7055c457f Mon Sep 17 00:00:00 2001 From: ggpolar Date: Fri, 19 Jun 2020 14:18:15 +0800 Subject: [PATCH] Modify the prompt message and parse more statements. 1. More detailed reports are added to the conversion report. 2. The conversion prompt is provided for the 'isinstance' statement in the conversion script. --- mindinsight/mindconverter/ast_edits.py | 147 ++++++++++++++++-- .../mindconverter/common/exceptions.py | 10 ++ mindinsight/mindconverter/config.py | 19 +-- 3 files changed, 146 insertions(+), 30 deletions(-) diff --git a/mindinsight/mindconverter/ast_edits.py b/mindinsight/mindconverter/ast_edits.py index dc5a3913..74d08466 100644 --- a/mindinsight/mindconverter/ast_edits.py +++ b/mindinsight/mindconverter/ast_edits.py @@ -28,13 +28,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING from mindinsight.mindconverter.config import NN_LIST from mindinsight.mindconverter.config import ALL_TORCH_APIS from mindinsight.mindconverter.config import ALL_2P_LIST -from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS -from mindinsight.mindconverter.config import ALL_UNSUPPORTED +from mindinsight.mindconverter.config import get_prompt_info from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport from mindinsight.mindconverter.forward_call import ForwardCall LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'." +LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s" LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s" LOG_FMT_PROMPT_INFO = "[INFO] %s" LOG_SUGGESTION_MANUAL_CONVERT = "Please manual convert the code, along with the code associated with it." @@ -95,6 +95,7 @@ class _LineColEditVisitor(ast.NodeVisitor): class _NodeInfo: """NodeInfo class definition.""" + def __init__(self, node): self.node = node self.call_list = [] # Used to save all ast.Call node in self._node @@ -444,19 +445,25 @@ class AstEditVisitor(ast.NodeVisitor): is_include_call = False return is_include_call - def match_api(self, call_func_node, is_forward): + def match_api(self, call_func_node, is_forward, check_context=True): """ Check api name to convert, check api name ok with a is_forward condition. Args: call_func_node (ast.Attribute): The call.func node. is_forward (bool): whether api belong to forward. + check_context (boolean): If True, the code context will be checked. Default is True. Returns: str, the standard api name used to match. ApiMappingEnum, the match result. """ - api_name, match_case = self._infer_api_name(call_func_node) + match_case = ApiMatchingEnum.NOT_API + api_call_name = pasta.dump(call_func_node) + if api_call_name.startswith('self.'): + return api_call_name, match_case + + api_name, match_case = self._infer_api_name(call_func_node, check_context) api_call_name = pasta.dump(call_func_node) is_tensor_obj_call = False if api_name != api_call_name: @@ -466,15 +473,17 @@ class AstEditVisitor(ast.NodeVisitor): # rewritten external module name # e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script. - if not is_tensor_obj_call and not self._code_analyzer.is_standard_external_ref: - standard_api_call_name = self._mapping_standard_api_name(api_name) + if not is_tensor_obj_call: + standard_api_call_name = self._get_api_whole_name(call_func_node, check_context) if standard_api_call_name in ALL_TORCH_APIS: match_case = ApiMatchingEnum.API_FOUND if (not is_forward and standard_api_call_name in NN_LIST) or \ (is_forward and standard_api_call_name in ALL_2P_LIST): match_case = ApiMatchingEnum.API_MATCHED - + else: + if standard_api_call_name and standard_api_call_name.startswith('torch.nn.init'): + match_case = ApiMatchingEnum.API_MATCHED return standard_api_call_name, match_case @staticmethod @@ -502,6 +511,25 @@ class AstEditVisitor(ast.NodeVisitor): parameters_str = call_str[left_parenthesis_pos + 1:right_parenthesis_pos] return parameters_str + def _get_api_whole_name(self, call_func_node, check_context=True): + """ + Get the whole name for the call node. + + Args: + call_func_node (AST): The func attribute of ast.Call. + check_context (boolean): If True, the code context will be checked. Default is True. + + Returns: + str, the whole name. + """ + api_name, match_case = self._infer_api_name(call_func_node, check_context) + if match_case == ApiMatchingEnum.API_STANDARD: + api_name_splits = api_name.split('.') + api_name_splits[0] = self._get_external_ref_whole_name(api_name_splits[0]) + if api_name_splits[0]: + api_name = '.'.join(api_name_splits) + return api_name + def mapping_api(self, call_node, check_context=True): """ Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert. @@ -522,6 +550,26 @@ class AstEditVisitor(ast.NodeVisitor): if api_call_name.startswith('self.'): return code + new_code = self._mapping_api(call_node, check_context) + + return new_code + + def _mapping_api(self, call_node, check_context=True): + """ + Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert. + + If do not check context of the script, the code represented by the node must be written in the standard way. + + Args: + call_node (ast.Call): The ast node to convert. + check_context (boolean): If True, the code context will be checked. Default is True. + + Returns: + str, the converted code. + """ + code = pasta.dump(call_node) + api_call_name = pasta.dump(call_node.func) + # find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)" args_str = '(' + self._get_call_parameters_str(call_node) + ')' @@ -551,26 +599,37 @@ class AstEditVisitor(ast.NodeVisitor): code = pasta.dump(node) api_name = pasta.dump(node.func) - # parent node first call is equal to this node, skip when parent node is replaced. - for parent_node in self._stack[:-1]: + # The parent node first call is equal to this node, skip when parent node is replaced. + # This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to + # P.Reshape()(out, (out.size(0). -1)), will skip P.Reshape() in following visiting. + # Access from the penultimate element in reverse order. + for parent_node in self._stack[-2::-1]: if parent_node in self._new_call_nodes and pasta.dump(parent_node).startswith(api_name): return parent = self._stack[-2] new_node = None + new_code = code matched_api_name, match_case = self.match_api(node.func, self._is_forward_function) if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]: + warning_info = get_prompt_info(matched_api_name) + if warning_info is None: + warning_info = '' if matched_api_name in ALL_MAPPING: logger.info("Line %3d start converting API: %s", node.lineno, api_name) new_code = self.mapping_api(node) if new_code != code: - new_node = pasta.parse(new_code).body[0].value - # find the first call name - new_api_name = new_code[:new_code.find('(')] - self._process_log.info(node.lineno, node.col_offset, LOG_FMT_CONVERT % (api_name, new_api_name)) - if matched_api_name in ALL_UNSUPPORTED: - warn_info = UNSUPPORTED_WARN_INFOS.get(api_name, '') - logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warn_info) - self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warn_info)) + try: + new_node = pasta.parse(new_code).body[0].value + # find the first call name + new_api_name = new_code[:new_code.find('(')] + except AttributeError: + new_node = pasta.parse(new_code).body[0] + new_api_name = new_code + self._process_log.info(node.lineno, node.col_offset, + LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info)) + else: + logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info) + self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info)) elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]: self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, '')) @@ -602,3 +661,57 @@ class AstEditVisitor(ast.NodeVisitor): elif ref_name != 'F' and external_ref_info.name == 'torch.nn.functional': renames[ref_name] = 'F' return renames + + def _get_external_ref_whole_name(self, ref_name): + """ + Find out external reference whole name. + + For example: + In the parsed source code, there is import statement + import torch.nn as new_name + _get_external_ref_whole_name('new_name') will return 'torch.nn' string. + """ + external_refs = self._code_analyzer.external_references + for external_ref_name, ref_info in external_refs.items(): + external_ref_info = ref_info['external_ref_info'] + if external_ref_name == ref_name: + return external_ref_info.name + return None + + def _check_isinstance_parameter(self, node): + """Check whether the second parameter of isinstance function contains the torch type.""" + is_isinstance_arg = False + # Check whether node is the second parameter of the isinstance function call. + # Access from the penultimate element in reverse order. + for parent_node in self._stack[-2::-1]: + if isinstance(parent_node, ast.Call) and pasta.dump(parent_node.func) == 'isinstance': + isinstance_node = parent_node + seconde_arg_type_nodes = [] + if isinstance(isinstance_node.args[1], ast.Tuple): + seconde_arg_type_nodes.extend(isinstance_node.args[1].elts) + else: + seconde_arg_type_nodes.append(isinstance_node.args[1]) + if node in seconde_arg_type_nodes: + is_isinstance_arg = True + break + if not is_isinstance_arg: + return False + + isinstance_type_arg = pasta.dump(node) + check_torch_type = False + if isinstance_type_arg: + type_splits = isinstance_type_arg.split('.') + whole_name = self._get_external_ref_whole_name(type_splits[0]) + if whole_name and whole_name.startswith('torch'): + check_torch_type = True + if check_torch_type: + _, match_case = self.match_api(node, False) + if match_case != ApiMatchingEnum.NOT_API: + warn_info = 'Manually determine the conversion type.' + self._process_log.warning(node.lineno, node.col_offset, + LOG_FMT_NOT_CONVERT % (isinstance_type_arg, warn_info)) + return check_torch_type + + def visit_Attribute(self, node): + """Callback function when visit AST tree""" + self._check_isinstance_parameter(node) diff --git a/mindinsight/mindconverter/common/exceptions.py b/mindinsight/mindconverter/common/exceptions.py index 494b2dce..69407961 100644 --- a/mindinsight/mindconverter/common/exceptions.py +++ b/mindinsight/mindconverter/common/exceptions.py @@ -24,6 +24,7 @@ class ConverterErrors(ScriptConverterErrors): """Converter error codes.""" SCRIPT_NOT_SUPPORT = 1 NODE_TYPE_NOT_SUPPORT = 2 + CODE_SYNTAX_ERROR = 3 class ScriptNotSupport(MindInsightException): @@ -42,3 +43,12 @@ class NodeTypeNotSupport(MindInsightException): super(NodeTypeNotSupport, self).__init__(ConverterErrors.NODE_TYPE_NOT_SUPPORT, msg, http_code=400) + + +class CodeSyntaxError(MindInsightException): + """The CodeSyntaxError class definition.""" + + def __init__(self, msg): + super(CodeSyntaxError, self).__init__(ConverterErrors.CODE_SYNTAX_ERROR, + msg, + http_code=400) diff --git a/mindinsight/mindconverter/config.py b/mindinsight/mindconverter/config.py index 98c0a102..807561b4 100644 --- a/mindinsight/mindconverter/config.py +++ b/mindinsight/mindconverter/config.py @@ -22,7 +22,7 @@ import os import pasta from mindinsight.mindconverter.common.log import logger - +from mindinsight.mindconverter.common.exceptions import CodeSyntaxError REQUIRED = 'REQUIRED' UNREQUIRED = 'UNREQUIRED' @@ -31,6 +31,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs' class APIPt: """Base API for args parse, and API for one frame.""" + def __init__(self, name: str, params: OrderedDict): self.name = name self.params = OrderedDict() @@ -77,10 +78,8 @@ class APIPt: try: ast_node = ast.parse("whatever_call_name" + args_str) call_node = ast_node.body[0].value - if not isinstance(call_node, ast.Call): - raise ValueError('call name with args str [{}] not instance of ast.Call'.format(args_str)) - except: - raise ValueError("can't parse code:\n{}".format(args_str)) + except SyntaxError as parse_error: + raise CodeSyntaxError("can't parse code:\n{}".format(args_str)) from parse_error # regard all actual parameter as one parameter if len(self.params) == 1: @@ -118,6 +117,7 @@ class APIPt: class APIMs(APIPt): """API for MindSpore""" + def __init__(self, name: str, params: OrderedDict, p_attrs=None): self.is_primitive = name.startswith('P.') if self.is_primitive: @@ -167,6 +167,7 @@ class APIMs(APIPt): class MappingHelper: """Mapping from one frame to another frame""" + def __init__(self, ms_api: APIMs, pt_api: APIPt, **kwargs): ms2pt_mapping = kwargs.get('ms2pt_mapping') gen_explicit_map = kwargs.get('gen_explicit_map') @@ -392,7 +393,6 @@ TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH) ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING} - # ---------------------------- api list support or not support ---------------------------- NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'nn_list.json')) NN_LIST = load_json_file(NN_LIST_PATH) @@ -400,7 +400,6 @@ NN_LIST += ["torch." + name for name in NN_LIST] NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING] NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING] - F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'f_list.json')) F_LIST = load_json_file(F_LIST_PATH) F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \ @@ -408,29 +407,23 @@ F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \ F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING] F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING] - TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'torch_dot_list.json')) TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH) - TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING] TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING] - TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'tensor_dot_list.json')) TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH) - TENSOR_DOT_SUPPORTED = [x for x in TENSOR_DOT_LIST if x in ALL_MAPPING] TENSOR_DOT_UNSUPPORTED = [x for x in TENSOR_DOT_LIST if x not in ALL_MAPPING] - ALL_2P_LIST = F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST ALL_TORCH_APIS = NN_LIST + F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST ALL_SUPPORTED = NN_SUPPORTED + F_SUPPORTED + TORCH_DOT_SUPPORTED + TENSOR_DOT_SUPPORTED ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSOR_DOT_UNSUPPORTED - UNSUPPORTED_WARN_INFOS = { "nn.AdaptiveAvgPool2d": "Maybe could convert to mindspore.ops.operations.ReduceMean.", "nn.AvgPool1d": "Maybe could convert to mindspore.nn.AvgPool1d.",