Use the AST module instead of the importlib/inspect modules to analyze and modify the network definition script. The importlib/inspect approach must load (execute) the Python script in order to analyze it, whereas AST analysis is static code analysis and is therefore more secure. (tag: v0.5.0-beta)
| @@ -0,0 +1,579 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless REQUIRED by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Convert for Python scripts according API mapping information.""" | |||
| import ast | |||
| import logging | |||
| import re | |||
| from enum import Enum | |||
| import pasta | |||
| from pasta.augment import import_utils | |||
| from mindinsight.mindconverter.code_analysis import CodeAnalyzer | |||
| from mindinsight.mindconverter.code_analysis import APIAnalysisSpec | |||
| from mindinsight.mindconverter.config import ALL_MAPPING | |||
| from mindinsight.mindconverter.config import NN_LIST | |||
| from mindinsight.mindconverter.config import ALL_TORCH_APIS | |||
| from mindinsight.mindconverter.config import ALL_2P_LIST | |||
| from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS | |||
| from mindinsight.mindconverter.config import ALL_UNSUPPORTED | |||
| from mindinsight.mindconverter.common.log import logger | |||
| from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport | |||
| from mindinsight.mindconverter.forward_call import ForwardCall | |||
| LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'." | |||
| LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s" | |||
| LOG_FMT_PROMPT_INFO = "[INFO] %s" | |||
| LOG_SUGGESTION_MANUAL_CONVERT = "Please manual convert the code, along with the code associated with it." | |||
| class ApiMatchingEnum(Enum): | |||
| """Node edge type enum.""" | |||
| NOT_API = 'not an api name' | |||
| API_INFER = 'infer api name to map' | |||
| API_STANDARD = 'api name in the correct format' | |||
| API_FOUND = 'found an api name in api list' | |||
| API_MATCHED = 'api is matched to map' | |||
| class _ConvertReport: | |||
| """Report log of converting source code.""" | |||
| def __init__(self, is_stub=False): | |||
| self._is_stub = is_stub | |||
| self._max_line = 0 | |||
| self._log = [] # report log, type is (severity, line, col, msg) | |||
| def _add_log(self, severity, line, col, msg): | |||
| """Add log.""" | |||
| if self._is_stub: | |||
| return | |||
| if isinstance(line, int) and isinstance(col, int): | |||
| self._log.append((severity, line, col, msg)) | |||
| if self._max_line < line: | |||
| self._max_line = line | |||
| def info(self, line, col, msg): | |||
| """Interface to add infer log""" | |||
| self._add_log(logging.INFO, line, col, msg) | |||
| def warning(self, line, col, msg): | |||
| """Interface to add warning log""" | |||
| self._add_log(logging.WARNING, line, col, msg) | |||
| def get_logs(self): | |||
| """Get convert logs""" | |||
| logs = [] | |||
| # sort rule: line * self._max_line + col | |||
| self._log.sort(key=lambda log: log[1] * self._max_line + log[2]) | |||
| for log_info in self._log: | |||
| log_info = "line %d:%d: %s" % (log_info[1], log_info[2], log_info[3]) | |||
| logs.append(log_info) | |||
| return logs | |||
| class _LineColEditVisitor(ast.NodeVisitor): | |||
| """ | |||
| Update line number and col offset of ast node. | |||
| Use the line and column number of the original code to update | |||
| the line and column number of the new code replaced with the original code. | |||
| """ | |||
| class _NodeInfo: | |||
| """NodeInfo class definition.""" | |||
| def __init__(self, node): | |||
| self.node = node | |||
| self.call_list = [] # Used to save all ast.Call node in self._node | |||
| def __init__(self): | |||
| self._dst_node_info = None | |||
| self._src_node_info = None | |||
| self._visiting = self._src_node_info # Used to point to the visiting node | |||
| def update(self, replace_with_node, src_node): | |||
| """Update the line and column number of the new code replaced with the original code.""" | |||
| replace_with_node.lineno = src_node.lineno | |||
| replace_with_node.col_offset = src_node.col_offset | |||
| self._dst_node_info = self._NodeInfo(replace_with_node) | |||
| self._src_node_info = self._NodeInfo(src_node) | |||
| self._visiting = self._src_node_info | |||
| self.visit(self._visiting.node) | |||
| self._visiting = self._dst_node_info | |||
| self.visit(self._visiting.node) | |||
| self._update_line_col() | |||
| def visit_Call(self, node): | |||
| """Callback function when visit AST tree""" | |||
| self._visiting.call_list.append(node) | |||
| self.generic_visit(node) | |||
| def _update_line_col(self): | |||
| """Update the line and column number information for all ast.Call node.""" | |||
| dst_call_list = list(self._dst_node_info.call_list) | |||
| src_call_list = list(self._src_node_info.call_list) | |||
| len_diff = len(dst_call_list) - len(src_call_list) | |||
| # After MindSpore api replaces Torch api, more calls are generated. | |||
| # For example, out.view() is replaced with P.Reshape()(out). | |||
| # out.view() has only one call, but P.Reshape()(out) has two calls. | |||
| # To match the replaced calls, the calls of out.view is padded to the same quantity. | |||
| if len_diff > 0: | |||
| src_call_list = [src_call_list[0]] * len_diff + src_call_list | |||
| for dst_call, src_call in zip(dst_call_list, src_call_list): | |||
| dst_call.lineno = src_call.lineno | |||
| dst_call.col_offset = src_call.col_offset | |||
| if not dst_call.args: | |||
| continue | |||
| # When out.size().view(1, ...) transforms to P.Reshape()(out.size(), 1, ...), | |||
| # in this case, the column of parameter out.size() will be bigger than the following parameters. | |||
| # To ensure the sequence of parameters, adjust the column of the second parameter. | |||
| args = [] | |||
| for arg in dst_call.args: | |||
| if self._check_arg2update(arg): | |||
| args.append(arg) | |||
| for arg in args: | |||
| arg.lineno = dst_call.lineno | |||
| arg.col_offset += dst_call.col_offset | |||
| @staticmethod | |||
| def _check_arg2update(arg): | |||
| # Only the col_offset of the first line code is re-counted, needs to be corrected. | |||
| # When the arg is a function call, its col_offset is handled separately. | |||
| if not isinstance(arg, ast.Call) and arg.lineno == 1: | |||
| return True | |||
| return False | |||
| class AstEditVisitor(ast.NodeVisitor): | |||
| """AST Visitor that process function calls. | |||
| Converts function calls from torch api to MindSpore api using api mapping information. | |||
| """ | |||
| def __init__(self): | |||
| self._process_log = _ConvertReport() | |||
| self._tree = None | |||
| self._code_analyzer = None | |||
| self._stack = [] # Used to easily access the parent node | |||
| self._forward_list = {} | |||
| self._is_forward_function = False # Used to allow access the visiting function forward attribute | |||
| self._new_call_nodes = [] # Used to save new ast.call nodes | |||
| def process(self, ast_tree): | |||
| """ | |||
| Convert source code to MindSpore code. | |||
| Args: | |||
| ast_tree (AST): The root node of the source code. | |||
| """ | |||
| self.__init__() | |||
| self._tree = ast_tree | |||
| self._code_analyzer = CodeAnalyzer() | |||
| self._code_analyzer.process(self._tree) | |||
| self._forward_list = ForwardCall(self._tree).calls | |||
| # replace python function under nn.Module | |||
| self._convert_api() | |||
| # replace external reference statements | |||
| self._convert_external_reference() | |||
| def get_logs(self): | |||
| """Get conversion report.""" | |||
| return self._process_log.get_logs() | |||
| def _convert_cell(self, cell_scope): | |||
| """ | |||
| Convert a PyTorch Module class into MindSpore Cell class. | |||
| Args: | |||
| cell_scope (pasta.base.Scope): The network class definition node inherits from torch.nn.Module. | |||
| """ | |||
| cell_ast_node = cell_scope.node | |||
| line_no = cell_ast_node.lineno | |||
| logger.info("Line %3d: start converting nn.Module %s", line_no, self._code_analyzer.get_name(cell_ast_node)) | |||
| class_elements = self._code_analyzer.network_definitions()['cell'] | |||
| # step1. update function definition | |||
| for func_scope in class_elements.get(cell_scope, []): | |||
| self._update_function_def(func_scope) | |||
| # step2. update base name of class | |||
| self._update_base_name(cell_scope) | |||
| def _update_base_name(self, class_def_scope): | |||
| """ | |||
| Update base name of class. | |||
| Args: | |||
| class_def_scope (ast.ClassDef): Class definition node. | |||
| """ | |||
| base_name_mapping = APIAnalysisSpec.base_name_mapping | |||
| class_def_node = class_def_scope.node | |||
| base_class_nodes = class_def_scope.node.bases | |||
| # update base class name | |||
| for base_class_node in base_class_nodes: | |||
| base_name = base_class_node.attr | |||
| if base_name in APIAnalysisSpec.get_network_base_class_names(): | |||
| old_code = pasta.dump(base_class_node) | |||
| if base_name in base_name_mapping: | |||
| new_code = 'nn.' + base_name_mapping[base_class_node.attr] | |||
| new_node = pasta.parse(new_code) | |||
| pasta.ast_utils.replace_child(class_def_node, base_class_node, new_node) | |||
| self._process_log.info(base_class_node.lineno, base_class_node.col_offset, LOG_FMT_CONVERT % | |||
| (old_code, new_code)) | |||
| else: | |||
| self._process_log.info(base_class_node.lineno, base_class_node.col_offset, LOG_FMT_NOT_CONVERT % | |||
| (old_code, '')) | |||
| def _update_function_def(self, func_scope): | |||
| """ | |||
| Convert a PyTorch function into MindSpore function. | |||
| Args: | |||
| func_scope (pasta.base.scope.Scope): The node scope of function definition. | |||
| """ | |||
| is_forward = self._judge_forward(func_scope) | |||
| # step1. convert the content of the function. | |||
| self._convert_function(func_scope, is_forward) | |||
| # step2. replace function name if name is forward | |||
| func_ast_node = func_scope.node | |||
| old_func_name = 'forward' | |||
| new_func_name = 'construct' | |||
| if func_ast_node.name == old_func_name: | |||
| func_ast_node.name = new_func_name | |||
| self._process_log.info(func_ast_node.lineno, func_ast_node.col_offset, | |||
| LOG_FMT_CONVERT % (old_func_name, new_func_name)) | |||
| def _convert_api(self): | |||
| """Convert PyTorch api call to MindSpore api call in a function.""" | |||
| tasks = [] | |||
| convert_elements = self._code_analyzer.network_definitions() | |||
| for func_node_scope in convert_elements.get("functions", []): | |||
| is_forward = self._judge_forward(func_node_scope) | |||
| tasks.append((self._convert_function, (func_node_scope, is_forward))) | |||
| for class_scope in convert_elements.get("cell", []).keys(): | |||
| tasks.append((self._convert_cell, (class_scope,))) | |||
| for convert_fun, args in tasks: | |||
| convert_fun(*args) | |||
| def _convert_external_reference(self): | |||
| """Convert import statements.""" | |||
| name_replace = APIAnalysisSpec.import_name_mapping | |||
| replace_imports = list(name_replace.values()) | |||
| for ref_info in self._code_analyzer.external_references.values(): | |||
| external_ref_info = ref_info['external_ref_info'] | |||
| parent_node = ref_info['parent_node'] | |||
| if parent_node is None: | |||
| continue | |||
| code = pasta.dump(parent_node) | |||
| if external_ref_info.name in APIAnalysisSpec.get_convertible_external_names(): | |||
| external_ref_info = ref_info['external_ref_info'] | |||
| if external_ref_info.name in name_replace.keys(): | |||
| import_utils.remove_import_alias_node(self._code_analyzer.root_scope, external_ref_info.node) | |||
| replace_info = name_replace[external_ref_info.name] | |||
| new_ref_name = replace_info[1] | |||
| new_external_name = replace_info[0] | |||
| if new_ref_name: | |||
| new_code = f'import {new_external_name} as {new_ref_name}' | |||
| else: | |||
| new_code = f'import {new_external_name}' | |||
| self._process_log.info(parent_node.lineno, parent_node.col_offset, LOG_FMT_CONVERT % | |||
| (code.strip(), new_code.strip())) | |||
| elif external_ref_info.name.startswith('torch.'): | |||
| self._process_log.warning(parent_node.lineno, parent_node.col_offset, LOG_FMT_NOT_CONVERT % | |||
| (code.strip(), LOG_SUGGESTION_MANUAL_CONVERT)) | |||
| else: | |||
| pass | |||
| # Insert import in reverse order, display in forward order. | |||
| for idx in range(len(replace_imports) - 1, -1, -1): | |||
| replace_import = replace_imports[idx] | |||
| if replace_import[1]: | |||
| self._add_import(name_to_import=replace_import[0], as_name=replace_import[1]) | |||
| else: | |||
| self._add_import(name_to_import=replace_import[0]) | |||
| def _add_import(self, name_to_import, as_name=None): | |||
| """ | |||
| Adds an import to the ast tree. | |||
| Args: | |||
| name_to_import: (string) The absolute name to import. | |||
| as_name: (string) The alias for the import ("import name_to_import as asname") | |||
| """ | |||
| new_alias = ast.alias(name=name_to_import, asname=as_name) | |||
| import_node = ast.Import(names=[new_alias]) | |||
| # Insert the node at the top of the module | |||
| self._tree.body.insert(1 if pasta.base.ast_utils.has_docstring(self._tree) else 0, import_node) | |||
| def _convert_function(self, func_scope, is_forward): | |||
| """ | |||
| Convert a PyTorch function into MindSpore function. | |||
| Args: | |||
| func_scope (pasta.base.scope.Scope): The node scope of function definition. | |||
| is_forward (boolean): If the function is defined in forward function in nn.Module in torch. | |||
| """ | |||
| func_ast_node = func_scope.node | |||
| line_no = func_ast_node.lineno | |||
| logger.info("Line %3d: start converting function %s()", line_no, func_ast_node.name) | |||
| parent = func_scope.parent_scope.node | |||
| self._stack.clear() | |||
| self._new_call_nodes.clear() | |||
| if parent: | |||
| self._stack.append(parent) | |||
| self._is_forward_function = is_forward | |||
| self.visit(func_scope.node) | |||
| def _judge_forward(self, func_scope): | |||
| """ | |||
| Check if function is a forward function. | |||
| Args: | |||
| func_scope (pasta.base.scope.Scope): The node scope of function definition. | |||
| Returns: | |||
| boolean, True or False | |||
| """ | |||
| is_forward = func_scope.node in self._forward_list.values() | |||
| if is_forward: | |||
| logger.debug("%s is a forward function", self._code_analyzer.get_name(func_scope)) | |||
| return is_forward | |||
| # Overridden to maintain stack information to access parent node | |||
| def visit(self, node): | |||
| """Visit a ast tree.""" | |||
| self._stack.append(node) | |||
| super(AstEditVisitor, self).visit(node) | |||
| self._stack.pop() | |||
| def _mapping_standard_api_name(self, api_name): | |||
| """Get mapping from external reference name to standard external reference name""" | |||
| standard_name = api_name | |||
| if not self._code_analyzer.is_standard_external_ref: | |||
| # key is real ref name, value is standard ref name. | |||
| mapping_names = self._mapping_standard_external_ref() | |||
| api_name_parts = api_name.split('.') | |||
| api_name_parts[0] = mapping_names.get(api_name_parts[0], api_name_parts[0]) | |||
| standard_name = '.'.join(api_name_parts) | |||
| return standard_name | |||
| def _infer_api_name(self, call_func_node, check_context=True): | |||
| """Infer the call name. | |||
| Examples: | |||
| 1. nn.Sequential inferred to nn.Sequential | |||
| 2. mmm.size inferred to .size if import torch.nn as nn | |||
| 3. mmm.size inferred to mmm.size if import torch.nn as mmm | |||
| """ | |||
| match_case = ApiMatchingEnum.NOT_API | |||
| api_name = None | |||
| call_name = pasta.dump(call_func_node) | |||
| is_include_sub_call = self._is_include_sub_call(call_func_node) | |||
| if is_include_sub_call: | |||
| name_attributes = call_name.rsplit('.', 1) | |||
| else: | |||
| name_attributes = call_name.split('.') | |||
| # rewritten external module name | |||
| # e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script. | |||
| if check_context and not self._code_analyzer.is_standard_external_ref: | |||
| standard_name = self._mapping_standard_api_name(name_attributes[0]) | |||
| else: | |||
| standard_name = name_attributes[0] | |||
| if standard_name in ["nn", "F", "torch"]: | |||
| match_case = ApiMatchingEnum.API_STANDARD | |||
| api_name = call_name | |||
| else: | |||
| # only infer function for tensor object. | |||
| # e.g., api_call_name is out.view, .view is an api name for out which is maybe a tensor object. | |||
| # e.g., 'xxxx'.size can be not inferred to .size, because string is not a tensor object. | |||
| first_name = standard_name.split('.')[0] | |||
| if not re.search(r'\W', first_name) and len(name_attributes) > 1: | |||
| api_name = '.' + name_attributes[-1] | |||
| match_case = ApiMatchingEnum.API_INFER | |||
| return api_name, match_case | |||
| @staticmethod | |||
| def _is_include_sub_call(call_func_node): | |||
| """"Inspect a sub call in call expression. | |||
| Examples: | |||
| 1. nn.functional.relu() return False | |||
| 2. nn.functional.relu(out).size() return True. nn.functional.relu(out) is sub call. | |||
| 3. nn.functional.relu(out=out.size()).size() return False. out.size() is not sub call of argument. | |||
| """ | |||
| is_include_call = False | |||
| try: | |||
| sub_node = call_func_node | |||
| while sub_node and not isinstance(sub_node, ast.Call): | |||
| sub_node = sub_node.value | |||
| if isinstance(sub_node, ast.Call): | |||
| is_include_call = True | |||
| except AttributeError: | |||
| is_include_call = False | |||
| return is_include_call | |||
| def match_api(self, call_func_node, is_forward): | |||
| """ | |||
| Check api name to convert, check api name ok with a is_forward condition. | |||
| Args: | |||
| call_func_node (ast.Attribute): The call.func node. | |||
| is_forward (bool): whether api belong to forward. | |||
| Returns: | |||
| str, the standard api name used to match. | |||
| ApiMappingEnum, the match result. | |||
| """ | |||
| api_name, match_case = self._infer_api_name(call_func_node) | |||
| api_call_name = pasta.dump(call_func_node) | |||
| is_tensor_obj_call = False | |||
| if api_name != api_call_name: | |||
| is_tensor_obj_call = True | |||
| standard_api_call_name = api_name | |||
| # rewritten external module name | |||
| # e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script. | |||
| if not is_tensor_obj_call and not self._code_analyzer.is_standard_external_ref: | |||
| standard_api_call_name = self._mapping_standard_api_name(api_name) | |||
| if standard_api_call_name in ALL_TORCH_APIS: | |||
| match_case = ApiMatchingEnum.API_FOUND | |||
| if (not is_forward and standard_api_call_name in NN_LIST) or \ | |||
| (is_forward and standard_api_call_name in ALL_2P_LIST): | |||
| match_case = ApiMatchingEnum.API_MATCHED | |||
| return standard_api_call_name, match_case | |||
| def mapping_api(self, call_node, check_context=True): | |||
| """ | |||
| Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert. | |||
| If do not check context of the script, the code represented by the node must be written in the standard way. | |||
| Args: | |||
| call_node (ast.Call): The ast node to convert. | |||
| check_context (boolean): If True, the code context will be checked. Default is True. | |||
| Returns: | |||
| str, the converted code. | |||
| """ | |||
| if not isinstance(call_node, ast.Call): | |||
| raise NodeTypeNotSupport("It is not ast.Call node.") | |||
| code = pasta.dump(call_node) | |||
| api_call_name = pasta.dump(call_node.func) | |||
| if api_call_name.startswith('self.'): | |||
| return code | |||
| # find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)" | |||
| args_str = code[len(api_call_name):].strip() | |||
| try: | |||
| api_name, _ = self._infer_api_name(call_node.func, check_context) | |||
| standard_api_call_name = api_call_name | |||
| if api_name != api_call_name: | |||
| # api name .view inferred from out.view, split tensor object name is out | |||
| tensor_obj_name = api_call_name[:-len(api_name)] | |||
| map_helper = ALL_MAPPING[api_name] | |||
| new_code = map_helper.convert(tensor_obj_name, args_str) | |||
| else: | |||
| # change to external ref name | |||
| # e.g., mm.ReLU will be changed to nn.ReLU if 'import torch.nn as mm' in script. | |||
| if check_context and not self._code_analyzer.is_standard_external_ref: | |||
| standard_api_call_name = self._mapping_standard_api_name(api_name) | |||
| map_helper = ALL_MAPPING[standard_api_call_name] | |||
| new_code = map_helper.convert(standard_api_call_name, args_str) | |||
| except KeyError: | |||
| return code | |||
| return new_code | |||
| def visit_Call(self, node): | |||
| """Callback function when visit AST tree""" | |||
| code = pasta.dump(node) | |||
| api_name = pasta.dump(node.func) | |||
| # parent node first call is equal to this node, skip when parent node is replaced. | |||
| for parent_node in self._stack[:-1]: | |||
| if parent_node in self._new_call_nodes and pasta.dump(parent_node).startswith(api_name): | |||
| return | |||
| parent = self._stack[-2] | |||
| new_node = None | |||
| matched_api_name, match_case = self.match_api(node.func, self._is_forward_function) | |||
| if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]: | |||
| if matched_api_name in ALL_MAPPING: | |||
| logger.info("Line %3d start converting API: %s", node.lineno, api_name) | |||
| new_code = self.mapping_api(node) | |||
| if new_code != code: | |||
| new_node = pasta.parse(new_code).body[0].value | |||
| # find the first call name | |||
| new_api_name = new_code[:new_code.find('(')] | |||
| self._process_log.info(node.lineno, node.col_offset, LOG_FMT_CONVERT % (api_name, new_api_name)) | |||
| if matched_api_name in ALL_UNSUPPORTED: | |||
| warn_info = UNSUPPORTED_WARN_INFOS.get(api_name, '') | |||
| logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warn_info) | |||
| self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warn_info)) | |||
| elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]: | |||
| self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, '')) | |||
| else: | |||
| pass | |||
| if parent and new_node: | |||
| update_line_col = _LineColEditVisitor() | |||
| update_line_col.update(new_node, node) | |||
| pasta.ast_utils.replace_child(parent, node, new_node) | |||
| self._new_call_nodes.append(new_node) | |||
| node = new_node | |||
| self._stack[-1] = node | |||
| try: | |||
| self.generic_visit(node) | |||
| except Exception: | |||
| logger.error('original code:%s, new code:%s', code, new_code, exc_info=True) | |||
| raise | |||
| def _mapping_standard_external_ref(self): | |||
| """Obtain the mapping dict of mapping the external references to standard external references.""" | |||
| renames = {} | |||
| external_refs = self._code_analyzer.external_references | |||
| for ref_name, ref_info in external_refs.items(): | |||
| external_ref_info = ref_info['external_ref_info'] | |||
| if ref_name != 'nn' and external_ref_info.name == 'torch.nn': | |||
| renames[ref_name] = 'nn' | |||
| elif ref_name != 'F' and external_ref_info.name == 'torch.nn.functional': | |||
| renames[ref_name] = 'F' | |||
| return renames | |||
| @@ -186,25 +186,23 @@ def cli_entry(): | |||
| mode = permissions << 6 | |||
| os.makedirs(args.output, mode=mode, exist_ok=True) | |||
| os.makedirs(args.report, mode=mode, exist_ok=True) | |||
| _run(args.in_file, args.output, '', args.report) | |||
| _run(args.in_file, args.output, args.report) | |||
| def _run(in_files, out_dir, in_module, report): | |||
| def _run(in_files, out_dir, report): | |||
| """ | |||
| Run converter command. | |||
| Args: | |||
| in_files (str): The file path or directory to convert. | |||
| out_dir (str): The output directory to save converted file. | |||
| in_module (str): The module name to convert. | |||
| report (str): The report file path. | |||
| """ | |||
| files_config = { | |||
| 'root_path': in_files if in_files else '', | |||
| 'in_files': [], | |||
| 'outfile_dir': out_dir, | |||
| 'report_dir': report, | |||
| 'in_module': in_module | |||
| 'report_dir': report | |||
| } | |||
| if os.path.isfile(in_files): | |||
| files_config['root_path'] = os.path.dirname(in_files) | |||
| @@ -0,0 +1,399 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless REQUIRED by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """code analysis module""" | |||
| import ast | |||
| import pasta | |||
| from pasta.base import scope | |||
| from mindinsight.mindconverter.common.exceptions import ScriptNotSupport | |||
| class APIAnalysisSpec: | |||
| """API analysis specifications""" | |||
| import_name_mapping = {'torch': ['mindspore', None], | |||
| 'torch.nn': ['mindspore.nn', 'nn'], | |||
| 'torch.nn.functional': ['mindspore.ops.operations', 'P']} | |||
| base_name_mapping = {'Module': 'Cell', | |||
| 'Sequential': 'SequentialCell' | |||
| } | |||
| @classmethod | |||
| def get_convertible_external_names(cls): | |||
| """ | |||
| Obtain the convertible external names. | |||
| The external name is the full dotted name being referenced. | |||
| """ | |||
| return cls.import_name_mapping.keys() | |||
| @staticmethod | |||
| def get_network_base_class_names(): | |||
| """Obtain the base names which network class base from""" | |||
| return ['Module', | |||
| 'Sequential', | |||
| 'ModuleList', | |||
| 'ModuleDict', | |||
| 'ParameterList', | |||
| 'ParameterDict'] | |||
| @staticmethod | |||
| def check_external_alias_ref(ref_name, external_name): | |||
| """ | |||
| Check 'import as' is standard. | |||
| Standard references are follow: | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| Args: | |||
| ref_name (str): The name that refers to the external_name. | |||
| external_name (str): The full dotted name being referenced. For examples: | |||
| 1. 'import torch.nn as nn', torch.nn is external_name, nn is ref_name. | |||
| 2. 'from torch import nn as mm, torch.nn is external_name, mm is ref_name which is not a standard name. | |||
| Returns: | |||
| boolean, True if ref_name is standard else False. | |||
| """ | |||
| if ref_name != 'nn' and external_name == 'torch.nn': | |||
| is_standard = False | |||
| elif ref_name != 'F' and external_name == 'torch.nn.functional': | |||
| is_standard = False | |||
| else: | |||
| is_standard = True | |||
| return is_standard | |||
| class CodeAnalyzer(ast.NodeVisitor): | |||
| """Code analyzer that analyzes PyTorch python script by AST Visitor. | |||
| CodeAnalyzer find the codes that need to be converted to MindSpore, | |||
| and provides the attributes related to the codes. | |||
| """ | |||
| def __init__(self): | |||
| self._stack = [] # Used to easily access the parent node | |||
| self._external_references = {} | |||
| self._is_standard_external_ref = True | |||
| self._root_scope = None | |||
| # Used to save functions that need to be converted, value type is pasta.base.scope.Scope | |||
| self._network_functions = [] | |||
| # Used to easily trace the function node | |||
| self._functions_stack = [] | |||
| # key type is pasta.base.scope.Scope, value type is list | |||
| self._network_classes = {} | |||
| @property | |||
| def root_scope(self): | |||
| """The root scope of the python script code.""" | |||
| return self._root_scope | |||
| @property | |||
| def is_standard_external_ref(self): | |||
| """Obtain whether the result is a standard external reference.""" | |||
| return self._is_standard_external_ref | |||
| @property | |||
| def external_references(self): | |||
| """Obtain all external references in the analyzed code.""" | |||
| return self._external_references | |||
| def network_definitions(self): | |||
| """Obtain the network definitions which need to be converted.""" | |||
| return {"functions": self._network_functions, | |||
| "cell": self._network_classes} | |||
| def process(self, ast_tree): | |||
| """ | |||
| Start to analyze the code. | |||
| Args: | |||
| ast_tree (AST): The root node of the source code. | |||
| """ | |||
| self.__init__() | |||
| self._root_scope = scope.analyze(ast_tree) | |||
| self._pre_process() | |||
| self.visit(ast_tree) | |||
| if not self._network_classes: | |||
| msg = "model definition not be found." | |||
| raise ScriptNotSupport(msg) | |||
| @staticmethod | |||
| def _check_external_standard(external_refs): | |||
| """Check whether all external references are standard.""" | |||
| is_standard = True | |||
| for external_name, external_ref_info in external_refs.items(): | |||
| is_standard = APIAnalysisSpec.check_external_alias_ref(external_name, external_ref_info.name) | |||
| if not is_standard: | |||
| break | |||
| return is_standard | |||
| def _is_base_from_cell(self, node): | |||
| """ | |||
| Check whether the node bases from cell classes which are defined in APIAnalysisSpec. | |||
| Args: | |||
| node (ast.ClassDef): The node which is a class definition. | |||
| Returns: | |||
| boolean, True if the check result is Passed else False. | |||
| """ | |||
| if self._is_ref_convertible_imports(node): | |||
| whole_name = self._get_whole_name(node) | |||
| if whole_name.split('.')[-1] in APIAnalysisSpec.get_network_base_class_names(): | |||
| return True | |||
| return False | |||
| def _pre_process(self): | |||
| """Preprocessor checks the code before analyzing.""" | |||
| is_torch = False | |||
| # check whether the code imports torch. | |||
| for ref_name in self._root_scope.external_references.keys(): | |||
| if ref_name.split('.')[0] in APIAnalysisSpec.get_convertible_external_names(): | |||
| is_torch = True | |||
| break | |||
| if not is_torch: | |||
| msg = "The source code does not import torch, model definition can not be found." | |||
| raise ScriptNotSupport(msg) | |||
| # Find out external reference in the code and save it. | |||
| external_refs = self._analyze_import_references(self._root_scope) | |||
| self._is_standard_external_ref = self._check_external_standard(external_refs) | |||
| self._check_external_standard(external_refs) | |||
| for external_name, external_ref_info in external_refs.items(): | |||
| self._external_references.update({ | |||
| external_name: { | |||
| 'external_ref_info': external_ref_info, | |||
| 'parent_node': None | |||
| } | |||
| }) | |||
@staticmethod
def _analyze_import_references(root_scope):
    """
    Collect all references introduced by import statements.

    Args:
        root_scope (Scope): Root scope of the parsed module.

    Returns:
        dict, maps each wanted reference name to its node reference.
    """
    external_name_ref = {}
    node_refs = []
    for references in root_scope.external_references.values():
        node_refs.extend(references)
    for node_ref in node_refs:
        name_ref = node_ref.name_ref
        if not name_ref:
            continue
        definition = name_ref.definition
        # (from)import with alias: name_ref.id is the alias name.
        # import without alias: asname is None; for `import a.b.c` the scope
        # records references a, a.b and a.b.c, and only the full dotted name
        # (matching definition.name) is wanted.
        if name_ref.id in (definition.asname, definition.name):
            external_name_ref[name_ref.id] = node_ref
    return external_name_ref
def visit(self, node):
    """
    Overridden visit of the base class.

    Pushes the node onto an internal stack before delegating so the
    per-node callbacks can inspect ancestor nodes, then pops it.
    """
    self._stack.append(node)
    super(CodeAnalyzer, self).visit(node)
    self._stack.pop()
@staticmethod
def _get_full_name(node):
    """Return the dotted source text of an Attribute/Name node, else None."""
    if isinstance(node, (ast.Attribute, ast.Name)):
        return pasta.dump(node)
    return None
def _get_whole_name(self, node):
    """
    Get the whole dotted name the node belongs to.

    For example, nn.Module is spliced from two nodes, the nn node and the
    Module node. When visiting, the Module node is reached first and its
    full name equals the whole name (nn.Module); the nn node is reached
    next with full name nn while the whole name is still nn.Module.
    """
    full_name = self._get_full_name(node)
    if not full_name:
        return None
    if node is not self._stack[-1]:
        return full_name
    # Node is on top of the stack: climb past the chain of Attribute
    # parents to reach the outermost attribute expression.
    index = -1
    while isinstance(self._stack[index], ast.Attribute):
        index -= 1
    return self._get_full_name(self._stack[index])
def _is_ref_convertible_imports(self, node):
    """Check whether the node references one of the convertible imports."""
    whole_name = self._get_whole_name(node)
    if not whole_name:
        return False
    module_name = whole_name.split('.')[0]
    convertible_names = APIAnalysisSpec.get_convertible_external_names()
    for ref_name, ref_info in self._external_references.items():
        external_ref = ref_info['external_ref_info']
        # The reference must be a convertible module and share the same
        # top-level name with the node being checked.
        if external_ref.name in convertible_names and module_name == ref_name.split('.')[0]:
            return True
    return False
@staticmethod
def _get_external_node(external_references):
    """Map every external reference node to its reference name."""
    return {
        ref_info['external_ref_info'].node: ref_name
        for ref_name, ref_info in external_references.items()
    }
@staticmethod
def _get_convertible_external_node(external_name_ref):
    """Get all external reference nodes that belong to convertible modules."""
    convertible_names = APIAnalysisSpec.get_convertible_external_names()
    convertible_external_nodes = {}
    for ref_name, ref_info in external_name_ref.items():
        external_ref = ref_info['external_ref_info']
        if external_ref.name in convertible_names:
            convertible_external_nodes[external_ref.node] = ref_name
    return convertible_external_nodes
def _update_external_ref_parent(self, node):
    """
    Record the import statement node as parent of each external reference.

    Raises:
        ScriptNotSupport: If one import line brings in several torch names
            (they cannot be annotated separately later).
    """
    external_nodes = self._get_external_node(self._external_references)
    convertible_nodes = self._get_convertible_external_node(self._external_references)
    for name_node in node.names:
        if name_node in convertible_nodes:
            if len(node.names) > 1:
                msg = """\
Not support multiple imports of torch on one line in your script. line:%s: %s
""" % (node.lineno, pasta.dump(node))
                raise ScriptNotSupport(msg)
        if name_node in external_nodes:
            ref_name = external_nodes[name_node]
            self._external_references[ref_name]['parent_node'] = node
@staticmethod
def _get_class_scope(node_scope):
    """Walk up the scope chain and return the nearest enclosing class scope."""
    scope = node_scope.parent_scope
    while scope:
        if isinstance(scope.node, ast.ClassDef):
            return scope
        scope = scope.parent_scope
    return None
def _update_convertible_functions(self, node):
    """
    Register a function scope as convertible.

    Functions defined inside a network class are grouped under that class
    scope; free functions go into the module-level function list.
    """
    node_scope = self._root_scope.lookup_scope(node)
    class_scope = self._get_class_scope(node_scope)
    if class_scope:
        network_classes = self._network_classes.get(class_scope, [])
        # NOTE(review): if class_scope was never registered by
        # visit_ClassDef, get() returns a fresh list that is not stored
        # back, so the append below is silently discarded — confirm that
        # ignoring functions of non-network classes is intended.
        if node_scope not in network_classes:
            network_classes.append(node_scope)
    else:
        if node_scope not in self._network_functions:
            self._network_functions.append(node_scope)
def visit_ClassDef(self, node):
    """Callback function when visiting a class definition node."""
    # Only handle the node when it is on top of the visiting stack.
    if self._stack[-1] is not node:
        return
    # A class deriving from a convertible import is a network class.
    if any(self._is_ref_convertible_imports(base) for base in node.bases):
        self._network_classes[self._root_scope.lookup_scope(node)] = []
    self.generic_visit(node)
def visit_Import(self, node):
    """Callback for `import x` statements: link references to this node."""
    self._update_external_ref_parent(node)
    self.generic_visit(node)
def visit_ImportFrom(self, node):
    """Callback for `from x import y` statements: link references to this node."""
    self._update_external_ref_parent(node)
    self.generic_visit(node)
def visit_Call(self, node):
    """Callback function when visiting a call node."""
    if self._stack[-1] is not node:
        return
    # A torch call inside a function marks the innermost enclosing
    # function as part of the network definition; its children need no
    # further visiting in that case.
    if self._functions_stack and self._is_ref_convertible_imports(node.func):
        self._update_convertible_functions(self._functions_stack[-1])
        return
    self.generic_visit(node)
def visit_FunctionDef(self, node):
    """Callback function when visiting a function definition node."""
    if self._stack[-1] is not node:
        return
    # forward() is always part of the network definition.
    if node.name == "forward":
        self._update_convertible_functions(node)
    # Track the enclosing function while its body is visited.
    self._functions_stack.append(node)
    self.generic_visit(node)
    self._functions_stack.pop()
def get_name(self, node):
    """
    Get the dotted name of a node or scope.

    Args:
        node (AST): An ast node of the source code, or a pasta Scope.

    Returns:
        str, the name of the node.
    """
    if isinstance(node, pasta.base.scope.Scope):
        # Collect names from the scope up to (but excluding) the module.
        parts = [self.get_name(node.node)]
        scope = node.parent_scope
        while scope:
            if not isinstance(scope.node, ast.Module):
                parts.append(self.get_name(scope.node))
            scope = scope.parent_scope
        parts.reverse()
        return '.'.join(parts)
    if isinstance(node, (ast.ClassDef, ast.FunctionDef)):
        return node.name
    if isinstance(node, (ast.Name, ast.Attribute)):
        return self._get_full_name(node)
    return str(node)
def lookup_scope(self, node):
    """
    Search the scope of the node.

    Args:
        node (AST): An ast node of the source code, or an already-resolved
            pasta Scope (returned unchanged).

    Returns:
        scope, the scope of the node.
    """
    already_scope = isinstance(node, pasta.base.scope.Scope)
    return node if already_scope else self._root_scope.lookup_scope(node)
| @@ -0,0 +1,44 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define custom exception.""" | |||
| from enum import unique | |||
| from mindinsight.utils.constant import ScriptConverterErrors | |||
| from mindinsight.utils.exceptions import MindInsightException | |||
@unique
class ConverterErrors(ScriptConverterErrors):
    """Converter error codes carried by the converter's exceptions."""
    SCRIPT_NOT_SUPPORT = 1
    NODE_TYPE_NOT_SUPPORT = 2
class ScriptNotSupport(MindInsightException):
    """The script can not support to process."""
    def __init__(self, msg):
        """Wrap `msg` with the SCRIPT_NOT_SUPPORT error code and HTTP 400."""
        super(ScriptNotSupport, self).__init__(ConverterErrors.SCRIPT_NOT_SUPPORT,
                                               msg,
                                               http_code=400)
class NodeTypeNotSupport(MindInsightException):
    """The astNode can not support to process."""
    def __init__(self, msg):
        """Wrap `msg` with the NODE_TYPE_NOT_SUPPORT error code and HTTP 400."""
        super(NodeTypeNotSupport, self).__init__(ConverterErrors.NODE_TYPE_NOT_SUPPORT,
                                                 msg,
                                                 http_code=400)
| @@ -13,463 +13,88 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """converter module""" | |||
| import copy | |||
| import importlib | |||
| import inspect | |||
| import os | |||
| import stat | |||
| from mindinsight.mindconverter.config import ALL_MAPPING | |||
| from mindinsight.mindconverter.config import NN_LIST | |||
| from mindinsight.mindconverter.config import ALL_TORCH_APIS | |||
| from mindinsight.mindconverter.config import ALL_2P_LIST | |||
| from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS | |||
| from mindinsight.mindconverter.config import ALL_UNSUPPORTED | |||
| from mindinsight.mindconverter.common.log import logger | |||
| from mindinsight.mindconverter.forward_call import ForwardCall | |||
| import pasta | |||
| LINE_NO_INDEX_DIFF = 1 | |||
| from mindinsight.mindconverter.common.exceptions import ScriptNotSupport | |||
| from mindinsight.mindconverter.common.log import logger | |||
| from mindinsight.mindconverter.ast_edits import AstEditVisitor | |||
| class Converter: | |||
| """Convert class""" | |||
| convert_info = '' | |||
| flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | |||
| modes = stat.S_IWUSR | stat.S_IRUSR | |||
@staticmethod
def is_local_defined(obj, member):
    """
    Check if obj and member are both defined in the same source file.

    Args:
        obj (Union[object, module]): A module or a class.
        member (func): A function of obj.

    Returns:
        bool, True or False.
    """
    return inspect.getsourcefile(member) == inspect.getsourcefile(obj)
def __init__(self):
    """Initialize the converter with empty parse/edit state."""
    self._tree = None           # parsed AST of the script being converted
    self._infile = None         # path of the input script
    self._code_analyzer = None  # code analyzer instance, set during convert
    self._ast_editor = None     # AST edit visitor, set during convert
    self._report = []           # accumulated report lines
@classmethod
def is_valid_module(cls, obj, member):
    """
    Check if member is a class defined in the same source file as obj and
    directly inheriting from a torch.nn module/container base class.

    Args:
        obj (Union[object, module]): A module or a class.
        member (func): A candidate member.

    Returns:
        bool, True or False.
    """
    if not inspect.isclass(member):
        return False
    torch_bases = ('Module', 'Sequential', 'ModuleList', 'ModuleDict',
                   'ParameterList', 'ParameterDict')
    if member.__base__.__name__ not in torch_bases:
        return False
    return cls.is_local_defined(obj, member)
@classmethod
def is_valid_function(cls, obj, member):
    """
    Check if member is a function defined in the same source file as obj.

    Args:
        obj (Union[object, module]): The obj.
        member (func): The candidate member.

    Returns:
        bool, True or False.
    """
    if not inspect.isfunction(member):
        return False
    return cls.is_local_defined(obj, member)
@staticmethod
def find_left_parentheses(string, right):
    """
    Find the index of the '(' matching the ')' at index `right`.

    Args:
        string (str): A line of code.
        right (int): Index of the ')' to match backwards from.

    Returns:
        int, index of the matching left parenthesis.

    Raises:
        ValueError: If string[right] is not ')' or parentheses are unbalanced.
    """
    if string[right] != ')':
        raise ValueError('code [{}] at index {} not ")".'.format(string, right))
    depth = 0
    for pos in range(right, -1, -1):
        char = string[pos]
        if char == ')':
            depth += 1
        elif char == '(':
            depth -= 1
            # Depth returning to zero means this '(' closes the chain.
            if depth == 0:
                return pos
    raise ValueError("{} should contain ()".format(string))
| @staticmethod | |||
| def find_right_parentheses(string, left): | |||
| def convert(self, infile, output_dir, report_dir): | |||
| """ | |||
| Find first index of right parenthesis which make all left parenthesis make sense. | |||
| Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir. | |||
| Args: | |||
| string (str): A line of code. | |||
| left (int): Start index of string to find from. | |||
| Returns: | |||
| int, index of the found right parenthesis. | |||
| Raises: | |||
| ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired. | |||
| infile (str): The script to convert. | |||
| output_dir (str): The path to save converted file. | |||
| report_dir (str): The path to save report file. | |||
| """ | |||
| stack = [] | |||
| for i in range(left, len(string)): | |||
| if string[i] == '(': | |||
| stack.append('(') | |||
| elif string[i] == ')': | |||
| stack.pop() | |||
| if not stack: | |||
| return i | |||
| raise ValueError("{} should contain ()".format(string)) | |||
| in_file_split = _path_split(infile) | |||
| in_file_split[-1], _ = _get_name_ext(in_file_split[-1]) | |||
| module_name = '.'.join(in_file_split) | |||
| with open(infile, 'r') as file: | |||
| content = ''.join(file.readlines()) | |||
| self._infile = infile | |||
| self._tree = pasta.parse(content) | |||
| self._report.clear() | |||
| try: | |||
| logger.info("Script file is %s", infile) | |||
| logger.info("Start converting %s", module_name) | |||
| self._report.append('[Start Convert]') | |||
| self._ast_editor = AstEditVisitor() | |||
| self._ast_editor.process(self._tree) | |||
| self._report.extend(self._ast_editor.get_logs()) | |||
| self._report.append('[Convert Over]') | |||
| dest_file = os.path.join(output_dir, os.path.basename(infile)) | |||
| with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file: | |||
| file.write(pasta.dump(self._tree)) | |||
| logger.info("Convert success. Result is wrote to %s.", dest_file) | |||
| except ScriptNotSupport as error: | |||
| self._report.append('[ScriptNotSupport] ' + error.message) | |||
| self._report.append('[Convert failed]') | |||
| raise error | |||
| except Exception as error: | |||
| self._report.clear() | |||
| raise error | |||
| finally: | |||
| if self._report: | |||
| dest_report_file = os.path.join(report_dir, | |||
| '_'.join(os.path.basename(infile).split('.')[:-1]) + '_report.txt') | |||
| with os.fdopen(os.open(dest_report_file, self.flags, self.modes), 'a') as file: | |||
| file.write('\n'.join(self._report)) | |||
| logger.info("Convert report is saved in %s", dest_report_file) | |||
| @staticmethod | |||
| def get_call_name(code, end): | |||
| def convert_api(source_code): | |||
| """ | |||
| Traverse code in a reversed function from index end and get the call name and start index of the call name, | |||
| if call name not found, return a null character string and -1 | |||
| Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert. | |||
| Args: | |||
| code (str): The str of code to find from. | |||
| end (int): Start index to find. | |||
| Returns: | |||
| tuple(str, int), one is founded api name if found, else a null character string, the other is start index | |||
| of founded api name, -1 if api name not found | |||
| """ | |||
| stack = [] | |||
| for i in range(end - 1, -1, -1): | |||
| if code[i] in ["(", "[", "{"]: | |||
| if stack: | |||
| stack.pop() | |||
| else: | |||
| return code[i + 1:end], i + 1 | |||
| elif code[i] in [")", "]", "}"]: | |||
| stack.append(code[i]) | |||
| elif stack: | |||
| continue | |||
| elif not (code[i].isalpha() or code[i].isdigit() or code[i] == '_' or code[i] == '.'): | |||
| return code[i + 1:end], i + 1 | |||
| return "", -1 | |||
| def convert_api(self, code, start, api_name=""): | |||
| """ | |||
| Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api, | |||
| code will not convert. | |||
| Args: | |||
| code (str): The str code to convert. | |||
| start (int): The index of code to start convert from. | |||
| api_name (str): The api name to convert. | |||
| source_code (ast.Call): The ast node to convert. | |||
| Returns: | |||
| str, the converted code. | |||
| int, index of converted api_name in code. | |||
| """ | |||
| # handle format like .shape( | |||
| if api_name.startswith('.'): | |||
| call_name, new_start = self.get_call_name(code, start) | |||
| if start == -1 or call_name == "self": | |||
| return code, start + 1 | |||
| else: | |||
| call_name = api_name | |||
| new_start = start | |||
| # find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)" | |||
| left = code.find("(", start) | |||
| if left == -1: | |||
| raise ValueError('"(" not found, {} should work with "("'.format(call_name)) | |||
| right = self.find_right_parentheses(code, left) | |||
| end = right | |||
| expr = code[start:end + 1] | |||
| args_str = code[left:right + 1] | |||
| map_helper = ALL_MAPPING[api_name] | |||
| new_expr = map_helper.convert(call_name, args_str) | |||
| next_newline = code.find("\n", end + 1) | |||
| fill_num = (expr.count("\n") - new_expr.count("\n")) | |||
| if next_newline != -1: | |||
| code = code[:new_start] + new_expr + code[end + 1:next_newline] + ("\n" * fill_num) + code[next_newline:] | |||
| else: | |||
| code = code[:new_start] + new_expr + ")" + ("\n" * fill_num) + code[end + 2:] | |||
| return code, start + len(map_helper.ms_api.name) | |||
@staticmethod
def find_api(code, i, is_forward):
    """
    Find a torch API name in `code` starting at index i.

    Args:
        code (str): The code from which to find the api name.
        i (int): The start index to find.
        is_forward (bool): Whether the surrounding function is a forward one.

    Returns:
        str, the api name if found and valid for the context, else "".
    """
    # Only spellings that can begin a torch API reference are considered.
    if not code[i:].startswith(("nn.", "F.", "torch.", ".")):
        return ""
    paren = code.find('(', i)
    if paren == -1:
        return ""
    api_name = code[i:paren]
    if api_name not in ALL_TORCH_APIS:
        return ""
    # Outside forward(): only layer constructors count; inside forward():
    # both layer and functional APIs count.
    if (not is_forward and api_name in NN_LIST) or (is_forward and api_name in ALL_2P_LIST):
        return api_name
    return ""
def convert_function(self, fun_name, fun, is_forward):
    """
    Convert a PyTorch function into a MindSpore function.

    Args:
        fun_name (str): The str of function name.
        fun (func): The function to convert.
        is_forward (bool): If the function is defined in forward function
            in nn.Module in torch.

    Returns:
        dict, {original source: converted source} if a conversion happened,
        else {}.
    """
    # Start line of the function in its file, used for report line numbers.
    _, line_no = inspect.getsourcelines(fun)
    logger.info("Line %3d: start converting function %s()", line_no, fun_name)
    code = inspect.getsource(fun)
    code_saved = copy.copy(code)
    i = 0
    # Scan the source character by character; convert_api() rewrites the
    # code in place and returns the next index to continue from.
    while i < len(code):
        api_name = self.find_api(code, i, is_forward)
        if api_name:
            # Line number of the API call within the original file.
            line_no1 = line_no + code[:i].count('\n')
            if api_name in ALL_MAPPING:
                logger.info("Line %3d start converting API: %s", line_no1, api_name)
                code, i = self.convert_api(code, i, api_name)
                self.convert_info += "[Convert][Line{:3d}] {} is converted.\n".format(line_no1, api_name)
                continue
            if api_name in ALL_UNSUPPORTED:
                # Known but unsupported API: warn and leave the code as-is.
                warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name] if api_name in UNSUPPORTED_WARN_INFOS else ""
                logger.warning("Line %3d: found unsupported API: %s%s", line_no1, api_name, warn_info)
                self.convert_info += "[Unconvert][Line{:3d}] {} didn't convert{}\n".format(line_no1,
                                                                                          api_name, warn_info)
        i += 1
    return {code_saved: code} if code_saved != code else {}
@staticmethod
def judge_forward(name, forward_list):
    """
    Check if a function is a forward function.

    Args:
        name (str): The (possibly dotted) function name.
        forward_list (set): A set of known forward functions.

    Returns:
        bool, True or False.
    """
    last_part = name.split(".")[-1]
    is_forward = name in forward_list or last_part == "forward"
    if is_forward:
        logger.debug("%s is a forward function", name)
    return is_forward
def convert_module(self, module_name, module, forward_list):
    """
    Convert a PyTorch module code into MindSpore module code.

    Args:
        module_name (str): The module's name.
        module (module): The module to convert.
        forward_list (set): A set of forward function names.

    Returns:
        dict, map of old code and converted code.
    """
    _, line_no = inspect.getsourcelines(module)
    # Fix: use lazy %-style logging arguments for consistency with every
    # other log call in this class (avoids formatting when disabled).
    logger.info("Line %3d: start converting nn.Module %s", line_no, module_name)
    mapped = {}
    for name, member in inspect.getmembers(module):
        if self.is_valid_function(module, member):
            is_forward = self.judge_forward("{}.{}".format(module_name, name), forward_list)
            mapped.update(self.convert_function(name, member, is_forward))
    return mapped
def get_mapping(self, import_mod, forward_list):
    """
    Convert code of a module and get mapping of old code and converted code.

    Args:
        import_mod (module): The module to convert.
        forward_list (set): A set of forward function names.

    Returns:
        dict, mapping from old code to converted code of the module.
    """
    mapping = {}
    tasks = []
    for name, member in inspect.getmembers(import_mod):
        if self.is_valid_module(import_mod, member):
            _, line_no = inspect.getsourcelines(member)
            tasks.append((line_no, self.convert_module, (name, member, forward_list)))
        elif self.is_valid_function(import_mod, member):
            _, line_no = inspect.getsourcelines(member)
            # NOTE(review): this formats the module object's repr, not its
            # name — confirm whether import_mod.__name__ was intended.
            is_forward = self.judge_forward("{}.{}".format(import_mod, name), forward_list)
            tasks.append((line_no, self.convert_function, (name, member, is_forward)))
    # Fix: sort by line number only. A plain tuple sort falls back to
    # comparing the bound methods when two members share a line number and
    # raises TypeError (methods are unorderable).
    tasks.sort(key=lambda task: task[0])
    for _, convert_fun, args in tasks:
        mapping.update(convert_fun(*args))
    return mapping
@staticmethod
def get_code_start_line_num(source_lines):
    """
    Get the first line index that holds real code, skipping leading
    comments, blank lines and module docstrings.

    Args:
        source_lines (list[str]): Split results of the original code.

    Returns:
        int, the start line index; 0 if no code line is found.
    """
    open_quote = None  # delimiter of the docstring currently open, if any
    index = 0
    for i, line in enumerate(source_lines):
        stripped = line.strip()
        if open_quote:
            # Inside a docstring: only look for its closing delimiter.
            # Fix: the original routed a line consisting solely of the
            # closing quotes through the "start of docstring" branch and
            # never popped it, so multi-line docstrings left the stack
            # stuck and the function wrongly returned 0.
            if stripped.endswith(open_quote):
                open_quote = None
            continue
        if not stripped or stripped.startswith('#'):
            continue
        for quote in ('"""', "'''"):
            if stripped.startswith(quote):
                # A one-line docstring opens and closes on the same line.
                if not (len(stripped) >= 2 * len(quote) and stripped.endswith(quote)):
                    open_quote = quote
                break
        else:
            index = i
            break
    return index
def update_code_and_convert_info(self, code, mapping):
    """
    Replace code according to mapping, and update convert info.

    Args:
        code (str): The code to replace.
        mapping (dict): Mapping for original code and the replaced code.

    Returns:
        str, the replaced code.
    """
    # Apply the snippet-level replacements produced by get_mapping().
    for key, value in mapping.items():
        code = code.replace(key, value)
    source_lines = code.splitlines(keepends=True)
    # Insert the MindSpore imports right after leading comments/docstring.
    start_line_number = self.get_code_start_line_num(source_lines)
    add_import_infos = ['import mindspore\n',
                        'import mindspore.nn as nn\n',
                        'import mindspore.ops.operations as P\n']
    for i, add_import_info in enumerate(add_import_infos):
        source_lines.insert(start_line_number + i, add_import_info)
        self.convert_info += '[Add Import] {}.\n'.format(add_import_info.strip())
    insert_count = len(add_import_infos)
    # Offset turning a post-insertion 0-based index back into the
    # original file's 1-based line number for the report.
    line_diff = insert_count - LINE_NO_INDEX_DIFF
    # Line-level textual conversions for constructs not covered by the
    # API mapping (torch imports, base class, forward method, renames).
    # NOTE(review): each branch rebuilds new_line from the stale `line`,
    # so if two rules match the same line the earlier replacement is
    # lost — confirm the rules are mutually exclusive in practice.
    for i in range(start_line_number + insert_count, len(source_lines)):
        line = source_lines[i]
        if (line.startswith('from torch') and 'import' in line) or line.startswith('import torch'):
            new_line = '# ' + line
            source_lines[i] = new_line
            self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, line.strip())
        if line.strip().startswith('class') and '(nn.Module)' in line:
            new_line = line.replace('nn.Module', 'nn.Cell')
            source_lines[i] = new_line
            self.convert_info += '[Convert][Line{:3d}] nn.Module is converted.\n'.format(i - line_diff)
        if line.strip().startswith('def forward('):
            new_line = line.replace('forward', 'construct')
            source_lines[i] = new_line
            self.convert_info += '[Convert][Line{:3d}] forward is converted.\n'.format(i - line_diff)
        if 'nn.Linear' in line:
            new_line = line.replace('nn.Linear', 'nn.Dense')
            source_lines[i] = new_line
            self.convert_info += '[Convert][Line{:3d}] nn.Linear is converted.\n'.format(i - line_diff)
        if '(nn.Sequential)' in line:
            new_line = line.replace('nn.Sequential', 'nn.SequentialCell')
            source_lines[i] = new_line
            self.convert_info += '[Convert][Line{:3d}] nn.Sequential is converted.\n'.format(i - line_diff)
        if 'nn.init.' in line:
            new_line = line.replace('nn.init', 'pass # nn.init')
            source_lines[i] = new_line
            self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, 'nn.init')
    code = ''.join(source_lines)
    return code
def convert(self, import_name, output_dir, report_dir):
    """
    Convert a module's code; the converted code is saved in output_dir
    and a report is saved in report_dir.

    Args:
        import_name (str): The module from which to import the module to convert.
        output_dir (str): The path to save converted file.
        report_dir (str): The path to save report file.
    """
    logger.info("Start converting %s", import_name)
    start_info = '[Start Convert]\n'
    module_info = 'The module is {}.\n'.format(import_name)
    # Import the target module to introspect its live members and source.
    import_mod = importlib.import_module(import_name)
    srcfile = inspect.getsourcefile(import_mod)
    logger.info("Script file is %s", srcfile)
    # Collect the set of functions reachable from forward() definitions.
    forward_list = set(ForwardCall(srcfile).calls)
    logger.debug("Forward_list: %s", forward_list)
    # replace python function under nn.Module
    mapping = self.get_mapping(import_mod, forward_list)
    code = inspect.getsource(import_mod)
    code = self.update_code_and_convert_info(code, mapping)
    # Sort the report lines, then frame them with start/end markers.
    convert_info_split = self.convert_info.splitlines(keepends=True)
    convert_info_split = sorted(convert_info_split)
    convert_info_split.insert(0, start_info)
    convert_info_split.insert(1, module_info)
    convert_info_split.append('[Convert Over]')
    self.convert_info = ''.join(convert_info_split)
    # O_EXCL in self.flags: refuses to overwrite an existing output file.
    dest_file = os.path.join(output_dir, os.path.basename(srcfile))
    with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file:
        file.write(code)
    logger.info("Convert success. Result is wrote to %s.", dest_file)
    dest_report_file = os.path.join(report_dir,
                                    '_'.join(os.path.basename(srcfile).split('.')[:-1]) + '_report.txt')
    with os.fdopen(os.open(dest_report_file, self.flags, self.modes), 'a') as file:
        file.write(self.convert_info)
    logger.info("Convert report is saved in %s", dest_report_file)
| ast_node = pasta.parse(source_code).body[0].value | |||
| check_context = False | |||
| replaced_code = AstEditVisitor().mapping_api(ast_node, check_context) | |||
| return replaced_code | |||
| def _get_name_ext(file): | |||
| @@ -514,14 +139,6 @@ def main(files_config): | |||
| files_config (dict): The config of files which to convert. | |||
| """ | |||
| convert_ins = Converter() | |||
| root_path = files_config['root_path'] | |||
| in_files = files_config['in_files'] | |||
| for in_file in in_files: | |||
| in_file_split = _path_split(in_file[len(root_path):]) | |||
| in_file_split[-1], _ = _get_name_ext(in_file_split[-1]) | |||
| module_name = '.'.join(in_file_split) | |||
| convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir']) | |||
| in_module = files_config.get('in_module') | |||
| if in_module: | |||
| convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir']) | |||
| convert_ins.convert(in_file, files_config['outfile_dir'], files_config['report_dir']) | |||
| @@ -14,7 +14,8 @@ | |||
| # ============================================================================ | |||
| """Find out forward functions of script file""" | |||
| import ast | |||
| import os | |||
| import pasta | |||
| class ForwardCall(ast.NodeVisitor): | |||
| @@ -24,73 +25,80 @@ class ForwardCall(ast.NodeVisitor): | |||
| Find the sub functions called by the forward function in the script file. | |||
| """ | |||
| def __init__(self, filename): | |||
| self.filename = filename | |||
| self.module_name = os.path.basename(filename).replace('.py', '') | |||
| self.name_stack = [] | |||
| self.forward_stack = [] | |||
| self.calls = set() | |||
| def __init__(self, ast_tree): | |||
| self._tree = ast_tree | |||
| self._name_stack = [] | |||
| self._forward_stack = [] | |||
| self.calls = {} # key is function name, value is forward function ast node. | |||
| self._function_list = {} # key is function name, value is function ast node. | |||
| self.process() | |||
| def process(self): | |||
| """Parse the python source file to find the forward functions.""" | |||
| with open(self.filename, 'rt', encoding='utf-8') as file: | |||
| content = file.read() | |||
| self.visit(ast.parse(content, self.filename)) | |||
| """visit ast tree to find the forward functions.""" | |||
| self.visit(self._tree) | |||
| # first visit to find out all functions, so restores all variables except _function_list | |||
| self._name_stack.clear() | |||
| self._forward_stack.clear() | |||
| self.calls.clear() | |||
| self.visit(self._tree) | |||
| def get_current_namespace(self): | |||
| """Get the namespace when visit the AST node""" | |||
| namespace = '.'.join(self.name_stack) | |||
| namespace = '.'.join(self._name_stack) | |||
| return namespace | |||
| @classmethod | |||
| def get_ast_node_name(cls, node): | |||
| """Get AST node name.""" | |||
| if isinstance(node, ast.Attribute): | |||
| return f'{cls.get_ast_node_name(node.value)}.{node.attr}' | |||
| if isinstance(node, ast.Name): | |||
| return node.id | |||
| def get_call_name(cls, node): | |||
| """Get functional call name.""" | |||
| if not isinstance(node, ast.Call): | |||
| return None | |||
| return node | |||
| return pasta.dump(node.func) | |||
| def visit_ClassDef(self, node): | |||
| """Callback function when visit AST tree""" | |||
| self.name_stack.append(node.name) | |||
| self._name_stack.append(node.name) | |||
| self.generic_visit(node) | |||
| self.name_stack.pop() | |||
| self._name_stack.pop() | |||
| def visit_FunctionDef(self, node): | |||
| """Callback function when visit AST tree""" | |||
| namespace = self.get_current_namespace() | |||
| if namespace: | |||
| func_name = f'{namespace}.{node.name}' | |||
| else: | |||
| func_name = node.name | |||
| func_name = f'{self.get_current_namespace()}.{node.name}' | |||
| is_in_chain = func_name in self.calls or node.name == 'forward' | |||
| if is_in_chain: | |||
| self.forward_stack.append(func_name) | |||
| self._forward_stack.append(func_name) | |||
| if node.name == 'forward': | |||
| self.calls.add(func_name) | |||
| self.calls.update({func_name: node}) | |||
| self._function_list.update({func_name: node}) | |||
| self.generic_visit(node) | |||
| if is_in_chain: | |||
| self.forward_stack.pop() | |||
| self._forward_stack.pop() | |||
| def visit_Call(self, node): | |||
| """Callback function when visit AST tree""" | |||
| for arg in node.args: | |||
| self.visit(arg) | |||
| for kw in node.keywords: | |||
| self.visit(kw.value) | |||
| func_name = self.get_ast_node_name(node.func) | |||
| for keyword in node.keywords: | |||
| self.visit(keyword.value) | |||
| func_name = self.get_call_name(node) | |||
| if isinstance(node.func, ast.Name): | |||
| if func_name not in ['super', 'str', 'repr']: | |||
| if self.forward_stack: | |||
| self.calls.add(func_name) | |||
| if self._forward_stack: | |||
| self.calls.update({func_name: self._function_list.get(func_name)}) | |||
| self.visit(node.func) | |||
| else: | |||
| if self.forward_stack: | |||
| if 'self' in func_name: | |||
| self.calls.add(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}') | |||
| if self._forward_stack: | |||
| if func_name.startswith('self.'): | |||
| whole_name = f'{self.get_current_namespace()}.{func_name.split(".")[-1]}' | |||
| self.calls.update({whole_name: self._function_list.get(whole_name)}) | |||
| else: | |||
| self.calls.add(func_name) | |||
| self.calls.update({func_name: self._function_list.get(func_name)}) | |||
| self.visit(node.func) | |||
| @@ -30,6 +30,7 @@ class MindInsightModules(Enum): | |||
| LINEAGEMGR = 2 | |||
| DATAVISUAL = 5 | |||
| PROFILERMGR = 6 | |||
| SCRIPTCONVERTER = 7 | |||
| class GeneralErrors(Enum): | |||
| @@ -69,3 +70,7 @@ class DataVisualErrors(Enum): | |||
| SCALAR_NOT_EXIST = 14 | |||
| HISTOGRAM_NOT_EXIST = 15 | |||
| TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16 | |||
| class ScriptConverterErrors(Enum): | |||
| """Enum definition for mindconverter errors.""" | |||
| @@ -22,380 +22,201 @@ class TestConverter: | |||
| converter_ins = Converter() | |||
| def test_judge_forward(self): | |||
| """test judge_forward""" | |||
| name1 = 'conv1' | |||
| forward_list = {'conv1', 'relu'} | |||
| result1 = self.converter_ins.judge_forward(name1, forward_list) | |||
| assert result1 is True | |||
| name2 = 'self.forward' | |||
| result2 = self.converter_ins.judge_forward(name2, forward_list) | |||
| assert result2 is True | |||
| def test_find_left_parentheses(self): | |||
| """test find_left_parentheses""" | |||
| code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ), | |||
| nn.ReLU(), | |||
| nn.ReLU(True), | |||
| nn.MaxPool2d(2, 2), | |||
| nn.Conv2d(6, 16, 5, stride=1, padding=0), | |||
| nn.ReLU(inplace=False), | |||
| nn.MaxPool2d(2, 2))''' | |||
| right_index = len(code) - 1 | |||
| left_index = code.index('nn.Conv2d') | |||
| result = self.converter_ins.find_left_parentheses(code, right_index) | |||
| assert result == left_index - 1 | |||
| def test_find_api(self): | |||
| """test find_api""" | |||
| code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ), | |||
| nn.ReLU(), | |||
| nn.ReLU(True), | |||
| nn.MaxPool2d(2, 2), # TODO padding | |||
| nn.Conv2d(6, 16, 5, stride=1, padding=0), | |||
| nn.ReLU(inplace=False), | |||
| nn.MaxPool2d(2, 2))''' | |||
| index = 0 | |||
| is_forward = False | |||
| result = self.converter_ins.find_api(code, index, is_forward) | |||
| assert result == 'nn.Sequential' | |||
| def test_get_call_name(self): | |||
| """test get_call_name""" | |||
| code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0))''' | |||
| end = len(code) | |||
| call_name, index = self.converter_ins.get_call_name(code, end) | |||
| assert call_name == '' | |||
| assert index == -1 | |||
| def test_find_right_parentheses(self): | |||
| """test find_right_parentheses""" | |||
| code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ), | |||
| nn.ReLU(), | |||
| nn.ReLU(True), | |||
| nn.MaxPool2d(2, 2), # TODO padding | |||
| nn.Conv2d(6, 16, 5, stride=1, padding=0), | |||
| nn.ReLU(inplace=False), | |||
| nn.MaxPool2d(2, 2))''' | |||
| left_index = 0 | |||
| result = self.converter_ins.find_right_parentheses(code, left_index) | |||
| assert_index = len(code) - 1 | |||
| assert result == assert_index | |||
| # test convert_api with nn ops | |||
| def test_convert_api_nn_layernorm(self): | |||
| """Test convert_api function work ok when convert api nn.LayerNorm""" | |||
| code = """ | |||
| def __init__(self, num_classes=1000): | |||
| self.features = nn.SequentialCell([ | |||
| nn.LayerNorm((5, 10, 10), elementwise_affine=False), | |||
| nn.ReLU(inplace=False) | |||
| ]) | |||
| """ | |||
| code = "nn.LayerNorm((5, 10, 10), elementwise_affine=False)" | |||
| api_name = 'nn.LayerNorm' | |||
| start = code.find(api_name) | |||
| layer_norm_info = NN_MAPPING.get(api_name) | |||
| expected_ms_api_name = 'nn.LayerNorm' | |||
| epsilon = layer_norm_info.pt_api.params.get('eps') | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('nn.LayerNorm((5, 10, 10), elementwise_affine=False)', | |||
| '{}(normalized_shape=(5, 10, 10), epsilon={})'.format( | |||
| expected_ms_api_name, epsilon)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_nn_leaky_relu(self): | |||
| """Test convert_api function work ok when convert api nn.LeakyReLU""" | |||
| code = """ | |||
| def __init__(self, num_classes=1000): | |||
| self.features = nn.SequentialCell([ | |||
| nn.LayerNorm((5, 10, 10), elementwise_affine=False), | |||
| nn.LeakyReLU(0.3)]) | |||
| """ | |||
| api_name = 'nn.LeakyReLU' | |||
| start = code.find(api_name) | |||
| code = "nn.LeakyReLU(0.3)" | |||
| expected_ms_api_name = 'nn.LeakyReLU' | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('nn.LeakyReLU(0.3)', | |||
| '{}(alpha=0.3)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_nn_prelu(self): | |||
| """Test convert_api function work ok when convert api nn.PReLU""" | |||
| code = """ | |||
| input = torch.randn(2, 3, 5) | |||
| nn.PReLU()(input) | |||
| """ | |||
| api_name = 'nn.PReLU' | |||
| start = code.find(api_name) | |||
| code = "nn.PReLU()(input)" | |||
| expected_ms_api_name = 'nn.PReLU' | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('nn.PReLU()(input)', | |||
| '{}()(input)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_nn_softmax(self): | |||
| """Test convert_api function work ok when convert api nn.Softmax""" | |||
| code = """ | |||
| nn.Softmax(dim=1)(input) | |||
| """ | |||
| api_name = 'nn.Softmax' | |||
| code = "nn.Softmax(dim=1)" | |||
| expected_ms_api_name = 'nn.Softmax' | |||
| start = code.find(api_name) | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| assert replaced_code == code.replace('nn.Softmax(dim=1)(input)', | |||
| '{}(axis=1)(input)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('nn.Softmax(dim=1)', | |||
| '{}(axis=1)'.format(expected_ms_api_name)) | |||
| # test convert_api with torch dot ops | |||
| def test_convert_api_torch_dot_abs(self): | |||
| """Test convert_api function work ok when convert api torch.abs""" | |||
| code = """ | |||
| torch.abs(input) | |||
| """ | |||
| api_name = 'torch.abs' | |||
| start = code.find(api_name) | |||
| code = "torch.abs(input)" | |||
| expected_ms_api_name = 'P.Abs' | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('torch.abs(input)', | |||
| '{}()(input)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_torch_dot_acos(self): | |||
| """Test convert_api function work ok when convert api torch.acos""" | |||
| code = """ | |||
| torch.acos(input) | |||
| """ | |||
| api_name = 'torch.acos' | |||
| start = code.find(api_name) | |||
| code = "torch.acos(input)" | |||
| expected_ms_api_name = 'P.ACos' | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('torch.acos(input)', | |||
| '{}()(input)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_torch_dot_cos(self): | |||
| """Test convert_api function work ok when convert api torch.cos""" | |||
| code = """ | |||
| torch.cos(input) | |||
| """ | |||
| api_name = 'torch.cos' | |||
| code = "torch.cos(input)" | |||
| expected_ms_api_name = 'P.Cos' | |||
| start = code.find(api_name) | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('torch.cos(input)', | |||
| '{}()(input)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_torch_dot_exp(self): | |||
| """Test convert_api function work ok when convert api torch.exp""" | |||
| code = """ | |||
| torch.exp(input) | |||
| """ | |||
| api_name = 'torch.exp' | |||
| code = "torch.exp(input)" | |||
| expected_ms_api_name = 'P.Exp' | |||
| start = code.find(api_name) | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('torch.exp(input)', | |||
| '{}()(input)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_torch_dot_log(self): | |||
| """Test convert_api function work ok when convert api torch.log""" | |||
| code = """ | |||
| torch.log(input) | |||
| """ | |||
| api_name = 'torch.log' | |||
| code = "torch.log(input)" | |||
| expected_ms_api_name = 'P.Log' | |||
| start = code.find(api_name) | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('torch.log(input)', | |||
| '{}()(input)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_torch_dot_pow(self): | |||
| """Test convert_api function work ok when convert api torch.pow""" | |||
| code = """ | |||
| torch.pow(a, exp) | |||
| """ | |||
| api_name = 'torch.pow' | |||
| code = "torch.pow(a, exp)" | |||
| expected_ms_api_name = 'P.Pow' | |||
| start = code.find(api_name) | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('torch.pow(a, exp)', | |||
| '{}()(a, exp)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_torch_dot_div(self): | |||
| """Test convert_api function work ok when convert api torch.div""" | |||
| code = """ | |||
| input = torch.randn(5) | |||
| other = torch.randn(5) | |||
| torch.div(input, other) | |||
| """ | |||
| api_name = 'torch.div' | |||
| code = "torch.div(input, other)" | |||
| expected_ms_api_name = 'P.Div' | |||
| start = code.find(api_name) | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('torch.div(input, other)', | |||
| '{}()(input, other)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_torch_dot_sin(self): | |||
| """Test convert_api function work ok when convert api torch.sin""" | |||
| code = """ | |||
| torch.sin(input) | |||
| """ | |||
| api_name = 'torch.sin' | |||
| code = "torch.sin(input)" | |||
| expected_ms_api_name = 'P.Sin' | |||
| start = code.find(api_name) | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('torch.sin(input)', | |||
| '{}()(input)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_torch_dot_sqrt(self): | |||
| """Test convert_api function work ok when convert api torch.sqrt""" | |||
| code = """ | |||
| torch.sqrt(input) | |||
| """ | |||
| api_name = 'torch.sqrt' | |||
| code = "torch.sqrt(input)" | |||
| expected_ms_api_name = 'P.Sqrt' | |||
| start = code.find(api_name) | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('torch.sqrt(input)', | |||
| '{}()(input)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_torch_dot_eye_with_n(self): | |||
| """Test convert_api function work ok when convert api torch.eye""" | |||
| code = """ | |||
| torch.eye(3) | |||
| """ | |||
| api_name = 'torch.eye' | |||
| code = "torch.eye(3)" | |||
| expected_ms_api_name = 'P.Eye' | |||
| start = code.find(api_name) | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('torch.eye(3)', | |||
| '{}()(3, 3, mindspore.int32)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_torch_dot_eye_with_m(self): | |||
| """Test convert_api function work ok when convert api torch.eye""" | |||
| code = """ | |||
| torch.eye(3, 4) | |||
| """ | |||
| api_name = 'torch.eye' | |||
| code = "torch.eye(3, 4)" | |||
| expected_ms_api_name = 'P.Eye' | |||
| start = code.find(api_name) | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('torch.eye(3, 4)', | |||
| '{}()(3, 4, mindspore.int32)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_torch_dot_add_with_alpha_default(self): | |||
| """Test convert_api function work ok when convert api torch.add""" | |||
| code = """ | |||
| torch.add(input, value) | |||
| """ | |||
| api_name = 'torch.add' | |||
| code = "torch.add(input, value)" | |||
| expected_ms_api_name = 'P.TensorAdd' | |||
| start = code.find(api_name) | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('torch.add(input, value)', | |||
| '{}()(input, value)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_torch_dot_add_with_alpha_not_default(self): | |||
| """Test convert_api function work ok when convert api torch.add""" | |||
| code = """ | |||
| torch.add(input, value, 3) | |||
| """ | |||
| api_name = 'torch.add' | |||
| code = "torch.add(input, value, 3)" | |||
| expected_ms_api_name = 'P.TensorAdd' | |||
| start = code.find(api_name) | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('torch.add(input, value, 3)', | |||
| '{}()(input, value*3)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| # test convert_api with F ops | |||
| def test_convert_api_f_normalize(self): | |||
| """Test convert_api function work ok when convert api F.normalize""" | |||
| code = """ | |||
| input = torch.randn(2, 3, 5) | |||
| F.normalize(input) | |||
| """ | |||
| api_name = 'F.normalize' | |||
| start = code.find(api_name) | |||
| code = "F.normalize(input)" | |||
| expected_ms_api_name = 'P.L2Normalize' | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('F.normalize(input)', | |||
| '{}(1, 1e-12)(input)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_f_sigmoid(self): | |||
| """Test convert_api function work ok when convert api F.sigmoid""" | |||
| code = """ | |||
| input = torch.randn(2, 3, 5) | |||
| F.sigmoid(input) | |||
| """ | |||
| api_name = 'F.sigmoid' | |||
| start = code.find(api_name) | |||
| code = "F.sigmoid(input)" | |||
| expected_ms_api_name = 'P.Sigmoid' | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('F.sigmoid(input)', | |||
| '{}()(input)'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| # test convert_api with tensor dot ops | |||
| def test_convert_api_tensor_dot_repeat(self): | |||
| """Test convert_api function work ok when convert api .repeat""" | |||
| code = """ | |||
| x.repeat(4, 2) | |||
| """ | |||
| api_name = '.repeat' | |||
| start = code.find(api_name) | |||
| code = "x.repeat(4, 2)" | |||
| expected_ms_api_name = 'P.Tile' | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('x.repeat(4, 2)', | |||
| '{}()(x, {})'.format(expected_ms_api_name, '(4, 2,)')) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| def test_convert_api_tensor_dot_permute(self): | |||
| """Test convert_api function work ok when convert api .permute""" | |||
| code = """ | |||
| x.permute(2, 0, 1) | |||
| """ | |||
| api_name = '.permute' | |||
| start = code.find(api_name) | |||
| code = "x.permute(2, 0, 1)" | |||
| expected_ms_api_name = 'P.Transpose' | |||
| replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('x.permute(2, 0, 1)', | |||
| '{}()(x, (2, 0, 1,))'.format(expected_ms_api_name)) | |||
| assert new_start == start + len(expected_ms_api_name) | |||
| @@ -15,7 +15,6 @@ | |||
| """Test forward_call module.""" | |||
| import ast | |||
| import textwrap | |||
| from unittest.mock import patch | |||
| from mindinsight.mindconverter.forward_call import ForwardCall | |||
| @@ -50,12 +49,10 @@ class TestForwardCall: | |||
| return out | |||
| """) | |||
| @patch.object(ForwardCall, 'process') | |||
| def test_process(self, mock_process): | |||
| def test_process(self): | |||
| """Test the function of visit ast tree to find out forward functions.""" | |||
| mock_process.return_value = None | |||
| forward_call = ForwardCall("mock") | |||
| forward_call.visit(ast.parse(self.source)) | |||
| ast_tree = ast.parse(self.source) | |||
| forward_call = ForwardCall(ast_tree) | |||
| expect_calls = ['TestNet.forward', | |||
| 'TestNet.forward1', | |||
| @@ -70,6 +67,6 @@ class TestForwardCall: | |||
| 'TestNet.fc3', | |||
| ] | |||
| expect_calls.sort() | |||
| real_calls = list(forward_call.calls) | |||
| real_calls = list(forward_call.calls.keys()) | |||
| real_calls.sort() | |||
| assert real_calls == expect_calls | |||