@@ -28,13 +28,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING
from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS
from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.config import get_prompt_info
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport
from mindinsight.mindconverter.forward_call import ForwardCall
LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'."
LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s"
LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s"
LOG_FMT_PROMPT_INFO = "[INFO] %s"
LOG_SUGGESTION_MANUAL_CONVERT = "Please manual convert the code, along with the code associated with it."
@@ -95,6 +95,7 @@ class _LineColEditVisitor(ast.NodeVisitor):
class _NodeInfo:
"""NodeInfo class definition."""
def __init__(self, node):
self.node = node
self.call_list = [] # Used to save all ast.Call node in self._node
@@ -444,19 +445,25 @@ class AstEditVisitor(ast.NodeVisitor):
is_include_call = False
return is_include_call
def match_api(self, call_func_node, is_forward):
def match_api(self, call_func_node, is_forward, check_context=True ):
"""
Check api name to convert, check api name ok with a is_forward condition.
Args:
call_func_node (ast.Attribute): The call.func node.
is_forward (bool): whether api belong to forward.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the standard api name used to match.
ApiMappingEnum, the match result.
"""
api_name, match_case = self._infer_api_name(call_func_node)
match_case = ApiMatchingEnum.NOT_API
api_call_name = pasta.dump(call_func_node)
if api_call_name.startswith('self.'):
return api_call_name, match_case
api_name, match_case = self._infer_api_name(call_func_node, check_context)
api_call_name = pasta.dump(call_func_node)
is_tensor_obj_call = False
if api_name != api_call_name:
@@ -466,15 +473,17 @@ class AstEditVisitor(ast.NodeVisitor):
# rewritten external module name
# e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
if not is_tensor_obj_call and not self._code_analyzer.is_standard_external_ref :
standard_api_call_name = self._mapping_standard_api_name(api_name )
if not is_tensor_obj_call:
standard_api_call_name = self._get_api_whole_name(call_func_node, check_context )
if standard_api_call_name in ALL_TORCH_APIS:
match_case = ApiMatchingEnum.API_FOUND
if (not is_forward and standard_api_call_name in NN_LIST) or \
(is_forward and standard_api_call_name in ALL_2P_LIST):
match_case = ApiMatchingEnum.API_MATCHED
else:
if standard_api_call_name and standard_api_call_name.startswith('torch.nn.init'):
match_case = ApiMatchingEnum.API_MATCHED
return standard_api_call_name, match_case
@staticmethod
@@ -502,6 +511,25 @@ class AstEditVisitor(ast.NodeVisitor):
parameters_str = call_str[left_parenthesis_pos + 1:right_parenthesis_pos]
return parameters_str
def _get_api_whole_name(self, call_func_node, check_context=True):
"""
Get the whole name for the call node.
Args:
call_func_node (AST): The func attribute of ast.Call.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the whole name.
"""
api_name, match_case = self._infer_api_name(call_func_node, check_context)
if match_case == ApiMatchingEnum.API_STANDARD:
api_name_splits = api_name.split('.')
api_name_splits[0] = self._get_external_ref_whole_name(api_name_splits[0])
if api_name_splits[0]:
api_name = '.'.join(api_name_splits)
return api_name
def mapping_api(self, call_node, check_context=True):
"""
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
@@ -522,6 +550,26 @@ class AstEditVisitor(ast.NodeVisitor):
if api_call_name.startswith('self.'):
return code
new_code = self._mapping_api(call_node, check_context)
return new_code
def _mapping_api(self, call_node, check_context=True):
"""
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
If do not check context of the script, the code represented by the node must be written in the standard way.
Args:
call_node (ast.Call): The ast node to convert.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the converted code.
"""
code = pasta.dump(call_node)
api_call_name = pasta.dump(call_node.func)
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
args_str = '(' + self._get_call_parameters_str(call_node) + ')'
@@ -551,26 +599,37 @@ class AstEditVisitor(ast.NodeVisitor):
code = pasta.dump(node)
api_name = pasta.dump(node.func)
# parent node first call is equal to this node, skip when parent node is replaced.
for parent_node in self._stack[:-1]:
# The parent node first call is equal to this node, skip when parent node is replaced.
# This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to
# P.Reshape()(out, (out.size(0). -1)), will skip P.Reshape() in following visiting.
# Access from the penultimate element in reverse order.
for parent_node in self._stack[-2::-1]:
if parent_node in self._new_call_nodes and pasta.dump(parent_node).startswith(api_name):
return
parent = self._stack[-2]
new_node = None
new_code = code
matched_api_name, match_case = self.match_api(node.func, self._is_forward_function)
if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]:
warning_info = get_prompt_info(matched_api_name)
if warning_info is None:
warning_info = ''
if matched_api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", node.lineno, api_name)
new_code = self.mapping_api(node)
if new_code != code:
new_node = pasta.parse(new_code).body[0].value
# find the first call name
new_api_name = new_code[:new_code.find('(')]
self._process_log.info(node.lineno, node.col_offset, LOG_FMT_CONVERT % (api_name, new_api_name))
if matched_api_name in ALL_UNSUPPORTED:
warn_info = UNSUPPORTED_WARN_INFOS.get(api_name, '')
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warn_info)
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warn_info))
try:
new_node = pasta.parse(new_code).body[0].value
# find the first call name
new_api_name = new_code[:new_code.find('(')]
except AttributeError:
new_node = pasta.parse(new_code).body[0]
new_api_name = new_code
self._process_log.info(node.lineno, node.col_offset,
LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info))
else:
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info)
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info))
elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]:
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, ''))
@@ -602,3 +661,57 @@ class AstEditVisitor(ast.NodeVisitor):
elif ref_name != 'F' and external_ref_info.name == 'torch.nn.functional':
renames[ref_name] = 'F'
return renames
def _get_external_ref_whole_name(self, ref_name):
"""
Find out external reference whole name.
For example:
In the parsed source code, there is import statement
import torch.nn as new_name
_get_external_ref_whole_name('new_name') will return 'torch.nn' string.
"""
external_refs = self._code_analyzer.external_references
for external_ref_name, ref_info in external_refs.items():
external_ref_info = ref_info['external_ref_info']
if external_ref_name == ref_name:
return external_ref_info.name
return None
def _check_isinstance_parameter(self, node):
"""Check whether the second parameter of isinstance function contains the torch type."""
is_isinstance_arg = False
# Check whether node is the second parameter of the isinstance function call.
# Access from the penultimate element in reverse order.
for parent_node in self._stack[-2::-1]:
if isinstance(parent_node, ast.Call) and pasta.dump(parent_node.func) == 'isinstance':
isinstance_node = parent_node
seconde_arg_type_nodes = []
if isinstance(isinstance_node.args[1], ast.Tuple):
seconde_arg_type_nodes.extend(isinstance_node.args[1].elts)
else:
seconde_arg_type_nodes.append(isinstance_node.args[1])
if node in seconde_arg_type_nodes:
is_isinstance_arg = True
break
if not is_isinstance_arg:
return False
isinstance_type_arg = pasta.dump(node)
check_torch_type = False
if isinstance_type_arg:
type_splits = isinstance_type_arg.split('.')
whole_name = self._get_external_ref_whole_name(type_splits[0])
if whole_name and whole_name.startswith('torch'):
check_torch_type = True
if check_torch_type:
_, match_case = self.match_api(node, False)
if match_case != ApiMatchingEnum.NOT_API:
warn_info = 'Manually determine the conversion type.'
self._process_log.warning(node.lineno, node.col_offset,
LOG_FMT_NOT_CONVERT % (isinstance_type_arg, warn_info))
return check_torch_type
def visit_Attribute(self, node):
"""Callback function when visit AST tree"""
self._check_isinstance_parameter(node)