|
|
|
@@ -64,7 +64,9 @@ class Converter: |
|
|
|
self._report.append('[Convert Over]') |
|
|
|
dest_file = os.path.join(output_dir, os.path.basename(infile)) |
|
|
|
with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file: |
|
|
|
file.write(pasta.dump(self._tree)) |
|
|
|
script = pasta.dump(self._tree) |
|
|
|
script = adjust_mindspore_import_position(script) |
|
|
|
file.write(script) |
|
|
|
logger.info("Convert success. Result is wrote to %s.", dest_file) |
|
|
|
except ScriptNotSupport as error: |
|
|
|
self._report.append('[ScriptNotSupport] ' + error.message) |
|
|
|
@@ -97,6 +99,61 @@ class Converter: |
|
|
|
return replaced_code |
|
|
|
|
|
|
|
|
|
|
|
def get_code_start_line_num(source_lines): |
|
|
|
""" |
|
|
|
Get the start code line number exclude comments. |
|
|
|
|
|
|
|
Args: |
|
|
|
source_lines (list[str]): Split results of code. |
|
|
|
|
|
|
|
Returns: |
|
|
|
int, the start line number. |
|
|
|
""" |
|
|
|
stack = [] |
|
|
|
index = 0 |
|
|
|
for i, line in enumerate(source_lines): |
|
|
|
line_strip = line.strip() |
|
|
|
if line_strip.startswith('#'): |
|
|
|
continue |
|
|
|
if line_strip.startswith('"""'): |
|
|
|
if not line_strip.endswith('"""'): |
|
|
|
stack.append('"""') |
|
|
|
continue |
|
|
|
if line_strip.startswith("'''"): |
|
|
|
if not line_strip.endswith("'''"): |
|
|
|
stack.append("'''") |
|
|
|
continue |
|
|
|
if line_strip.endswith('"""') or line_strip.endswith("'''"): |
|
|
|
stack.pop() |
|
|
|
continue |
|
|
|
if line_strip != '' and not stack: |
|
|
|
index = i |
|
|
|
break |
|
|
|
return index |
|
|
|
|
|
|
|
|
|
|
|
def adjust_mindspore_import_position(script): |
|
|
|
""" |
|
|
|
Adjust code sentence `import mindspore` in script to a proper position if the sentence is set before a comment. |
|
|
|
|
|
|
|
Args: |
|
|
|
script (str): code script before adjust. |
|
|
|
|
|
|
|
Returns: |
|
|
|
str, code script adjusted. |
|
|
|
""" |
|
|
|
script_list = script.split('\n') |
|
|
|
import_ms_sentence = 'import mindspore' |
|
|
|
if import_ms_sentence in script_list: |
|
|
|
import_index = script_list.index(import_ms_sentence) |
|
|
|
if script_list[import_index + 1].startswith('"""') or script_list[import_index + 1].startswith("'''"): |
|
|
|
script_list.pop(import_index) |
|
|
|
new_index = get_code_start_line_num(script_list) |
|
|
|
script_list.insert(new_index, import_ms_sentence) |
|
|
|
script = '\n'.join(script_list) |
|
|
|
return script |
|
|
|
|
|
|
|
|
|
|
|
def _get_name_ext(file): |
|
|
|
""" |
|
|
|
Split a file name in name and extension. |
|
|
|
|