Browse Source

!187 more clear report on mindconverter

Merge pull request !187 from quyongxiu1/mindconvert_report_fix_0.3
pull/187/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
16fcb03446
1 changed files with 103 additions and 37 deletions
  1. +103
    -37
      mindinsight/mindconverter/converter.py

+ 103
- 37
mindinsight/mindconverter/converter.py View File

@@ -28,6 +28,8 @@ from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.forward_call import ForwardCall

LINE_NO_INDEX_DIFF = 1


class Converter:
"""Convert class"""
@@ -197,6 +199,7 @@ class Converter:
raise ValueError('"(" not found, {} should work with "("'.format(call_name))
right = self.find_right_parentheses(code, left)
end = right

expr = code[start:end + 1]
args_str = code[left:right + 1]

@@ -336,6 +339,96 @@ class Converter:
mapping.update(convert_fun(*args))
return mapping

@staticmethod
def get_code_start_line_num(source_lines):
"""
Get the start code line number exclude comments.

Args:
source_lines (list[str]): Split results of original code.

Returns:
int, the start line number.
"""
stack = []
index = 0
for i, line in enumerate(source_lines):
if line.strip().startswith('#'):
continue
if line.strip().startswith('"""'):
if not line.endswith('"""\n'):
stack.append('"""')
continue
if line.strip().startswith("'''"):
if not line.endswith("'''\n"):
stack.append("'''")
continue
if line.endswith('"""\n') or line.endswith("'''\n"):
stack.pop()
continue
if line.strip() != '' and not stack:
index = i
break
return index

def update_code_and_convert_info(self, code, mapping):
"""
Replace code according to mapping, and update convert info.

Args:
code (str): The code to replace.
mapping (dict): Mapping for original code and the replaced code.

Returns:
str, the replaced code.
"""

for key, value in mapping.items():
code = code.replace(key, value)

source_lines = code.splitlines(keepends=True)
start_line_number = self.get_code_start_line_num(source_lines)
add_import_infos = ['import mindspore\n',
'import mindspore.nn as nn\n',
'import mindspore.ops.operations as P\n']
for i, add_import_info in enumerate(add_import_infos):
source_lines.insert(start_line_number + i, add_import_info)
self.convert_info += '[Add Import] {}.\n'.format(add_import_info.strip())

insert_count = len(add_import_infos)
line_diff = insert_count - LINE_NO_INDEX_DIFF

for i in range(start_line_number + insert_count, len(source_lines)):
line = source_lines[i]

if (line.startswith('from torch') and 'import' in line) or line.startswith('import torch'):
new_line = '# ' + line
source_lines[i] = new_line
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, line.strip())
if line.strip().startswith('class') and '(nn.Module)' in line:
new_line = line.replace('nn.Module', 'nn.Cell')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Module is converted.\n'.format(i - line_diff)
if line.strip().startswith('def forward('):
new_line = line.replace('forward', 'construct')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] forward is converted.\n'.format(i - line_diff)
if 'nn.Linear' in line:
new_line = line.replace('nn.Linear', 'nn.Dense')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Linear is converted.\n'.format(i - line_diff)
if '(nn.Sequential)' in line:
new_line = line.replace('nn.Sequential', 'nn.SequentialCell')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Sequential is converted.\n'.format(i - line_diff)
if 'nn.init.' in line:
new_line = line.replace('nn.init', 'pass # nn.init')
source_lines[i] = new_line
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, 'nn.init')

code = ''.join(source_lines)
return code

def convert(self, import_name, output_dir, report_dir):
"""
Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.
@@ -346,10 +439,10 @@ class Converter:
report_dir (str): The path to save report file.
"""
logger.info("Start converting %s", import_name)
self.convert_info += '[Start Convert]\nThe module is {}\n'.format(import_name)
start_info = '[Start Convert]\n'
module_info = 'The module is {}.\n'.format(import_name)

import_mod = importlib.import_module(import_name)

srcfile = inspect.getsourcefile(import_mod)
logger.info("Script file is %s", srcfile)

@@ -358,40 +451,14 @@ class Converter:

# replace python function under nn.Module
mapping = self.get_mapping(import_mod, forward_list)

code = inspect.getsource(import_mod)
for key, value in mapping.items():
code = code.replace(key, value)

code = 'import mindspore.ops.operations as P\n' + code
code = 'import mindspore.nn as nn\n' + code
code = 'import mindspore\n' + code

self.convert_info += '||[Import Add] Add follow import sentences:\n'
self.convert_info += 'import mindspore.ops.operations as P\n'
self.convert_info += 'import mindspore.nn as nn\n'
self.convert_info += 'import mindspore\n\n'

code = code.replace('import torch', '# import torch')
code = code.replace('from torch', '# from torch')
code = code.replace('(nn.Module):', '(nn.Cell):')
code = code.replace('forward(', 'construct(')
code = code.replace('nn.Linear', 'nn.Dense')
code = code.replace('(nn.Sequential)', '(nn.SequentialCell)')
code = code.replace('nn.init.', 'pass # nn.init.')

self.convert_info += '||[Import Annotated] Annotated follow import sentences:\n'
self.convert_info += 'import sentence on torch as follows are annotated:\n'
self.convert_info += 'import torch\n'
self.convert_info += 'from torch ...\n'

self.convert_info += '||[Explicit Convert] Module or function are explicitly converted as follows:\n'
self.convert_info += '[nn.Module] is converted to [nn.Cell]\n'
self.convert_info += '[forward] is converted to [construct]\n'
self.convert_info += '[nn.Linear] is converted to [nn.Dense]\n'
self.convert_info += '[nn.Sequential] is converted to [nn.SequentialCell]\n'
self.convert_info += '[nn.init] is not converted and annotated\n'
self.convert_info += '[Convert over]'
code = self.update_code_and_convert_info(code, mapping)
convert_info_split = self.convert_info.splitlines(keepends=True)
convert_info_split = sorted(convert_info_split)
convert_info_split.insert(0, start_info)
convert_info_split.insert(1, module_info)
convert_info_split.append('[Convert Over]')
self.convert_info = ''.join(convert_info_split)

dest_file = os.path.join(output_dir, os.path.basename(srcfile))
with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file:
@@ -428,7 +495,6 @@ def _path_split(file):

Returns:
list[str], list of file tail

"""
file_dir, name = os.path.split(file)
if file_dir:
@@ -456,6 +522,6 @@ def main(files_config):
module_name = '.'.join(in_file_split)
convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir'])

in_module = files_config['in_module']
in_module = files_config.get('in_module')
if in_module:
convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir'])

Loading…
Cancel
Save