| @@ -39,6 +39,7 @@ class ConvertToolLoader: | |||
| self.log = None | |||
| self.compare_none_error = None | |||
| self.compare_exception = None | |||
| self.toolkit_path = self.find_toolkit_path() | |||
| self.load_convert_tool() | |||
| @staticmethod | |||
| @@ -63,10 +64,9 @@ class ConvertToolLoader: | |||
| def load_convert_tool(self): | |||
| """load CANN conversion tool from the toolkit path.""" | |||
| toolkit_path = self.find_toolkit_path() | |||
| # add toolkit path to system searching module path | |||
| if str(toolkit_path) not in sys.path: | |||
| sys.path.append(str(toolkit_path)) | |||
| if str(self.toolkit_path) not in sys.path: | |||
| sys.path.insert(0, str(self.toolkit_path)) | |||
| try: | |||
| self.utils = import_module('utils') | |||
| self.common = import_module('common') | |||
| @@ -75,12 +75,10 @@ class ConvertToolLoader: | |||
| self.format_conversion = import_module( | |||
| 'shape_conversion').FormatConversionMain | |||
| except ModuleNotFoundError: | |||
| # restore system searching module path | |||
| if str(toolkit_path) in sys.path: | |||
| sys.path.remove(str(toolkit_path)) | |||
| self.reset_system_path() | |||
| raise ModuleNotFoundError( | |||
| "Failed to load CANN conversion tools under {}. Please make sure Ascend " \ | |||
| "toolkit has been installed properly.".format(toolkit_path)) | |||
| "toolkit has been installed properly.".format(self.toolkit_path)) | |||
| try: | |||
| self.progress = import_module("progress").Progress | |||
| @@ -100,9 +98,10 @@ class ConvertToolLoader: | |||
| self.compare_none_error = self.utils.VECTOR_COMPARISON_NONE_ERROR | |||
| self.compare_exception = self.utils.CompareError | |||
| def reset_system_path(self): | |||
| # restore system searching module path | |||
| if str(toolkit_path) in sys.path: | |||
| sys.path.remove(str(toolkit_path)) | |||
| if str(self.toolkit_path) in sys.path: | |||
| sys.path.remove(str(self.toolkit_path)) | |||
| def parse_args(file_list, output_path): | |||
| @@ -147,16 +146,23 @@ class AsyncDumpConverter: | |||
| def convert_files(self): | |||
| """Main entry of the converter to convert async dump files into npy format.""" | |||
| self.convert_tool.log.print_info_log('Start to convert async dump files.') | |||
| ret_code = self.convert_tool.compare_none_error | |||
| if self.args.format is not None: | |||
| convert = self.convert_tool.format_conversion(self.args) | |||
| else: | |||
| convert = self.convert_tool.dump_data_parser(self.args) | |||
| ret_code = self.handle_multi_process(convert, self.files_to_convert) | |||
| self._rename_generated_npy_files() | |||
| if ret_code != self.convert_tool.compare_none_error: | |||
| if os.path.exists(self.failed_file_path): | |||
| self.convert_failed_tensors() | |||
| try: | |||
| ret_code = self.convert_tool.compare_none_error | |||
| if self.args.format is not None: | |||
| convert = self.convert_tool.format_conversion(self.args) | |||
| else: | |||
| convert = self.convert_tool.dump_data_parser(self.args) | |||
| # 1. check if arguments are valid | |||
| convert.check_arguments_valid() | |||
| # 2. convert format for dump data | |||
| ret_code = self.handle_multi_process(convert, self.files_to_convert) | |||
| self._rename_generated_npy_files() | |||
| if ret_code != self.convert_tool.compare_none_error: | |||
| if os.path.exists(self.failed_file_path): | |||
| self.convert_failed_tensors() | |||
| finally: | |||
| # clean up sys.path no matter conversion is successful or not to avoid pollution | |||
| self.convert_tool.reset_system_path() | |||
| self.convert_tool.log.print_info_log('Finish to convert async dump files.') | |||
| def convert_failed_tensors(self): | |||