|
|
|
@@ -28,6 +28,8 @@ from mindinsight.mindconverter.config import ALL_UNSUPPORTED |
|
|
|
from mindinsight.mindconverter.common.log import logger |
|
|
|
from mindinsight.mindconverter.forward_call import ForwardCall |
|
|
|
|
|
|
|
LINE_NO_INDEX_DIFF = 1 |
|
|
|
|
|
|
|
|
|
|
|
class Converter: |
|
|
|
"""Convert class""" |
|
|
|
@@ -197,6 +199,7 @@ class Converter: |
|
|
|
raise ValueError('"(" not found, {} should work with "("'.format(call_name)) |
|
|
|
right = self.find_right_parentheses(code, left) |
|
|
|
end = right |
|
|
|
|
|
|
|
expr = code[start:end + 1] |
|
|
|
args_str = code[left:right + 1] |
|
|
|
|
|
|
|
@@ -336,6 +339,96 @@ class Converter: |
|
|
|
mapping.update(convert_fun(*args)) |
|
|
|
return mapping |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def get_code_start_line_num(source_lines): |
|
|
|
""" |
|
|
|
Get the start code line number exclude comments. |
|
|
|
|
|
|
|
Args: |
|
|
|
source_lines (list[str]): Split results of original code. |
|
|
|
|
|
|
|
Returns: |
|
|
|
int, the start line number. |
|
|
|
""" |
|
|
|
stack = [] |
|
|
|
index = 0 |
|
|
|
for i, line in enumerate(source_lines): |
|
|
|
if line.strip().startswith('#'): |
|
|
|
continue |
|
|
|
if line.strip().startswith('"""'): |
|
|
|
if not line.endswith('"""\n'): |
|
|
|
stack.append('"""') |
|
|
|
continue |
|
|
|
if line.strip().startswith("'''"): |
|
|
|
if not line.endswith("'''\n"): |
|
|
|
stack.append("'''") |
|
|
|
continue |
|
|
|
if line.endswith('"""\n') or line.endswith("'''\n"): |
|
|
|
stack.pop() |
|
|
|
continue |
|
|
|
if line.strip() != '' and not stack: |
|
|
|
index = i |
|
|
|
break |
|
|
|
return index |
|
|
|
|
|
|
|
def update_code_and_convert_info(self, code, mapping): |
|
|
|
""" |
|
|
|
Replace code according to mapping, and update convert info. |
|
|
|
|
|
|
|
Args: |
|
|
|
code (str): The code to replace. |
|
|
|
mapping (dict): Mapping for original code and the replaced code. |
|
|
|
|
|
|
|
Returns: |
|
|
|
str, the replaced code. |
|
|
|
""" |
|
|
|
|
|
|
|
for key, value in mapping.items(): |
|
|
|
code = code.replace(key, value) |
|
|
|
|
|
|
|
source_lines = code.splitlines(keepends=True) |
|
|
|
start_line_number = self.get_code_start_line_num(source_lines) |
|
|
|
add_import_infos = ['import mindspore\n', |
|
|
|
'import mindspore.nn as nn\n', |
|
|
|
'import mindspore.ops.operations as P\n'] |
|
|
|
for i, add_import_info in enumerate(add_import_infos): |
|
|
|
source_lines.insert(start_line_number + i, add_import_info) |
|
|
|
self.convert_info += '[Add Import] {}.\n'.format(add_import_info.strip()) |
|
|
|
|
|
|
|
insert_count = len(add_import_infos) |
|
|
|
line_diff = insert_count - LINE_NO_INDEX_DIFF |
|
|
|
|
|
|
|
for i in range(start_line_number + insert_count, len(source_lines)): |
|
|
|
line = source_lines[i] |
|
|
|
|
|
|
|
if (line.startswith('from torch') and 'import' in line) or line.startswith('import torch'): |
|
|
|
new_line = '# ' + line |
|
|
|
source_lines[i] = new_line |
|
|
|
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, line.strip()) |
|
|
|
if line.strip().startswith('class') and '(nn.Module)' in line: |
|
|
|
new_line = line.replace('nn.Module', 'nn.Cell') |
|
|
|
source_lines[i] = new_line |
|
|
|
self.convert_info += '[Convert][Line{:3d}] nn.Module is converted.\n'.format(i - line_diff) |
|
|
|
if line.strip().startswith('def forward('): |
|
|
|
new_line = line.replace('forward', 'construct') |
|
|
|
source_lines[i] = new_line |
|
|
|
self.convert_info += '[Convert][Line{:3d}] forward is converted.\n'.format(i - line_diff) |
|
|
|
if 'nn.Linear' in line: |
|
|
|
new_line = line.replace('nn.Linear', 'nn.Dense') |
|
|
|
source_lines[i] = new_line |
|
|
|
self.convert_info += '[Convert][Line{:3d}] nn.Linear is converted.\n'.format(i - line_diff) |
|
|
|
if '(nn.Sequential)' in line: |
|
|
|
new_line = line.replace('nn.Sequential', 'nn.SequentialCell') |
|
|
|
source_lines[i] = new_line |
|
|
|
self.convert_info += '[Convert][Line{:3d}] nn.Sequential is converted.\n'.format(i - line_diff) |
|
|
|
if 'nn.init.' in line: |
|
|
|
new_line = line.replace('nn.init', 'pass # nn.init') |
|
|
|
source_lines[i] = new_line |
|
|
|
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, 'nn.init') |
|
|
|
|
|
|
|
code = ''.join(source_lines) |
|
|
|
return code |
|
|
|
|
|
|
|
def convert(self, import_name, output_dir, report_dir): |
|
|
|
""" |
|
|
|
Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir. |
|
|
|
@@ -346,10 +439,10 @@ class Converter: |
|
|
|
report_dir (str): The path to save report file. |
|
|
|
""" |
|
|
|
logger.info("Start converting %s", import_name) |
|
|
|
self.convert_info += '[Start Convert]\nThe module is {}\n'.format(import_name) |
|
|
|
start_info = '[Start Convert]\n' |
|
|
|
module_info = 'The module is {}.\n'.format(import_name) |
|
|
|
|
|
|
|
import_mod = importlib.import_module(import_name) |
|
|
|
|
|
|
|
srcfile = inspect.getsourcefile(import_mod) |
|
|
|
logger.info("Script file is %s", srcfile) |
|
|
|
|
|
|
|
@@ -358,40 +451,14 @@ class Converter: |
|
|
|
|
|
|
|
# replace python function under nn.Module |
|
|
|
mapping = self.get_mapping(import_mod, forward_list) |
|
|
|
|
|
|
|
code = inspect.getsource(import_mod) |
|
|
|
for key, value in mapping.items(): |
|
|
|
code = code.replace(key, value) |
|
|
|
|
|
|
|
code = 'import mindspore.ops.operations as P\n' + code |
|
|
|
code = 'import mindspore.nn as nn\n' + code |
|
|
|
code = 'import mindspore\n' + code |
|
|
|
|
|
|
|
self.convert_info += '||[Import Add] Add follow import sentences:\n' |
|
|
|
self.convert_info += 'import mindspore.ops.operations as P\n' |
|
|
|
self.convert_info += 'import mindspore.nn as nn\n' |
|
|
|
self.convert_info += 'import mindspore\n\n' |
|
|
|
|
|
|
|
code = code.replace('import torch', '# import torch') |
|
|
|
code = code.replace('from torch', '# from torch') |
|
|
|
code = code.replace('(nn.Module):', '(nn.Cell):') |
|
|
|
code = code.replace('forward(', 'construct(') |
|
|
|
code = code.replace('nn.Linear', 'nn.Dense') |
|
|
|
code = code.replace('(nn.Sequential)', '(nn.SequentialCell)') |
|
|
|
code = code.replace('nn.init.', 'pass # nn.init.') |
|
|
|
|
|
|
|
self.convert_info += '||[Import Annotated] Annotated follow import sentences:\n' |
|
|
|
self.convert_info += 'import sentence on torch as follows are annotated:\n' |
|
|
|
self.convert_info += 'import torch\n' |
|
|
|
self.convert_info += 'from torch ...\n' |
|
|
|
|
|
|
|
self.convert_info += '||[Explicit Convert] Module or function are explicitly converted as follows:\n' |
|
|
|
self.convert_info += '[nn.Module] is converted to [nn.Cell]\n' |
|
|
|
self.convert_info += '[forward] is converted to [construct]\n' |
|
|
|
self.convert_info += '[nn.Linear] is converted to [nn.Dense]\n' |
|
|
|
self.convert_info += '[nn.Sequential] is converted to [nn.SequentialCell]\n' |
|
|
|
self.convert_info += '[nn.init] is not converted and annotated\n' |
|
|
|
self.convert_info += '[Convert over]' |
|
|
|
code = self.update_code_and_convert_info(code, mapping) |
|
|
|
convert_info_split = self.convert_info.splitlines(keepends=True) |
|
|
|
convert_info_split = sorted(convert_info_split) |
|
|
|
convert_info_split.insert(0, start_info) |
|
|
|
convert_info_split.insert(1, module_info) |
|
|
|
convert_info_split.append('[Convert Over]') |
|
|
|
self.convert_info = ''.join(convert_info_split) |
|
|
|
|
|
|
|
dest_file = os.path.join(output_dir, os.path.basename(srcfile)) |
|
|
|
with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file: |
|
|
|
@@ -428,7 +495,6 @@ def _path_split(file): |
|
|
|
|
|
|
|
Returns: |
|
|
|
list[str], list of file tail |
|
|
|
|
|
|
|
""" |
|
|
|
file_dir, name = os.path.split(file) |
|
|
|
if file_dir: |
|
|
|
@@ -456,6 +522,6 @@ def main(files_config): |
|
|
|
module_name = '.'.join(in_file_split) |
|
|
|
convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir']) |
|
|
|
|
|
|
|
in_module = files_config['in_module'] |
|
|
|
in_module = files_config.get('in_module') |
|
|
|
if in_module: |
|
|
|
convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir']) |