| @@ -170,6 +170,9 @@ class InFileAction(argparse.Action): | |||||
| if not os.path.isfile(outfile_dir): | if not os.path.isfile(outfile_dir): | ||||
| parser_in.error(f'{option_string} {outfile_dir} is not a file') | parser_in.error(f'{option_string} {outfile_dir} is not a file') | ||||
| if not os.path.basename(outfile_dir).endswith("py"): | |||||
| parser_in.error(f'{option_string} {outfile_dir} is not a valid python file') | |||||
| setattr(namespace, self.dest, outfile_dir) | setattr(namespace, self.dest, outfile_dir) | ||||
| @@ -282,32 +285,32 @@ class NodeAction(argparse.Action): | |||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
| prog='mindconverter', | |||||
| description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__), | |||||
| allow_abbrev=False) | |||||
| prog='mindconverter', | |||||
| description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__), | |||||
| allow_abbrev=False) | |||||
| parser.add_argument( | parser.add_argument( | ||||
| '--version', | |||||
| action='version', | |||||
| version='%(prog)s ({})'.format(mindinsight.__version__)) | |||||
| '--version', | |||||
| action='version', | |||||
| version='%(prog)s ({})'.format(mindinsight.__version__)) | |||||
| parser.add_argument( | parser.add_argument( | ||||
| '--in_file', | |||||
| type=str, | |||||
| action=InFileAction, | |||||
| required=False, | |||||
| default=None, | |||||
| help=""" | |||||
| '--in_file', | |||||
| type=str, | |||||
| action=InFileAction, | |||||
| required=False, | |||||
| default=None, | |||||
| help=""" | |||||
| Specify path for script file to use AST schema to | Specify path for script file to use AST schema to | ||||
| do script conversation. | do script conversation. | ||||
| """) | """) | ||||
| parser.add_argument( | parser.add_argument( | ||||
| '--model_file', | |||||
| type=str, | |||||
| action=ModelFileAction, | |||||
| required=False, | |||||
| help=""" | |||||
| '--model_file', | |||||
| type=str, | |||||
| action=ModelFileAction, | |||||
| required=False, | |||||
| help=""" | |||||
| PyTorch .pth or Tensorflow .pb model file path to use graph | PyTorch .pth or Tensorflow .pb model file path to use graph | ||||
| based schema to do script generation. When | based schema to do script generation. When | ||||
| `--in_file` and `--model_file` are both provided, | `--in_file` and `--model_file` are both provided, | ||||
| @@ -315,12 +318,12 @@ parser.add_argument( | |||||
| """) | """) | ||||
| parser.add_argument( | parser.add_argument( | ||||
| '--shape', | |||||
| type=str, | |||||
| action=ShapeAction, | |||||
| default=None, | |||||
| required=False, | |||||
| help=""" | |||||
| '--shape', | |||||
| type=str, | |||||
| action=ShapeAction, | |||||
| default=None, | |||||
| required=False, | |||||
| help=""" | |||||
| Optional, expected input tensor shape of | Optional, expected input tensor shape of | ||||
| `--model_file`. It's required when use graph based | `--model_file`. It's required when use graph based | ||||
| schema. | schema. | ||||
| @@ -328,55 +331,55 @@ parser.add_argument( | |||||
| """) | """) | ||||
| parser.add_argument( | parser.add_argument( | ||||
| '--input_nodes', | |||||
| type=str, | |||||
| action=NodeAction, | |||||
| default=None, | |||||
| required=False, | |||||
| help=""" | |||||
| '--input_nodes', | |||||
| type=str, | |||||
| action=NodeAction, | |||||
| default=None, | |||||
| required=False, | |||||
| help=""" | |||||
| Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model. | Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model. | ||||
| Usage: --input_nodes input_1:0,input_2:0 | Usage: --input_nodes input_1:0,input_2:0 | ||||
| """) | """) | ||||
| parser.add_argument( | parser.add_argument( | ||||
| '--output_nodes', | |||||
| type=str, | |||||
| action=NodeAction, | |||||
| default=None, | |||||
| required=False, | |||||
| help=""" | |||||
| '--output_nodes', | |||||
| type=str, | |||||
| action=NodeAction, | |||||
| default=None, | |||||
| required=False, | |||||
| help=""" | |||||
| Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model. | Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model. | ||||
| Usage: --output_nodes output_1:0,output_2:0 | Usage: --output_nodes output_1:0,output_2:0 | ||||
| """) | """) | ||||
| parser.add_argument( | parser.add_argument( | ||||
| '--output', | |||||
| type=str, | |||||
| action=OutputDirAction, | |||||
| default=os.path.join(os.getcwd(), 'output'), | |||||
| help=""" | |||||
| '--output', | |||||
| type=str, | |||||
| action=OutputDirAction, | |||||
| default=os.path.join(os.getcwd(), 'output'), | |||||
| help=""" | |||||
| Optional, specify path for converted script file | Optional, specify path for converted script file | ||||
| directory. Default output directory is `output` folder | directory. Default output directory is `output` folder | ||||
| in the current working directory. | in the current working directory. | ||||
| """) | """) | ||||
| parser.add_argument( | parser.add_argument( | ||||
| '--report', | |||||
| type=str, | |||||
| action=LogFileAction, | |||||
| default=None, | |||||
| help=""" | |||||
| '--report', | |||||
| type=str, | |||||
| action=LogFileAction, | |||||
| default=None, | |||||
| help=""" | |||||
| Optional, specify report directory. Default is | Optional, specify report directory. Default is | ||||
| converted script directory. | converted script directory. | ||||
| """) | """) | ||||
| parser.add_argument( | parser.add_argument( | ||||
| '--project_path', | |||||
| type=str, | |||||
| action=ProjectPathAction, | |||||
| required=False, | |||||
| default=None, | |||||
| help=""" | |||||
| '--project_path', | |||||
| type=str, | |||||
| action=ProjectPathAction, | |||||
| required=False, | |||||
| default=None, | |||||
| help=""" | |||||
| Optional, PyTorch scripts project path. If PyTorch | Optional, PyTorch scripts project path. If PyTorch | ||||
| project is not in PYTHONPATH, please assign | project is not in PYTHONPATH, please assign | ||||
| `--project_path` when use graph based schema. | `--project_path` when use graph based schema. | ||||
| @@ -13,6 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Define custom exception.""" | """Define custom exception.""" | ||||
| import abc | |||||
| import sys | import sys | ||||
| from enum import unique | from enum import unique | ||||
| from importlib import import_module | from importlib import import_module | ||||
| @@ -23,7 +24,7 @@ from treelib.exceptions import DuplicatedNodeIdError, MultipleRootError, NodeIDA | |||||
| from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | ||||
| from mindinsight.utils.constant import ScriptConverterErrors | from mindinsight.utils.constant import ScriptConverterErrors | ||||
| from mindinsight.utils.exceptions import MindInsightException, ParamMissError | |||||
| from mindinsight.utils.exceptions import MindInsightException | |||||
| @unique | @unique | ||||
| @@ -40,12 +41,16 @@ class ConverterErrors(ScriptConverterErrors): | |||||
| SCRIPT_GENERATE_FAIL = 9 | SCRIPT_GENERATE_FAIL = 9 | ||||
| REPORT_GENERATE_FAIL = 10 | REPORT_GENERATE_FAIL = 10 | ||||
| NODE_CONVERSION_ERROR = 11 | NODE_CONVERSION_ERROR = 11 | ||||
| INPUT_SHAPE_ERROR = 12 | |||||
| TF_RUNTIME_ERROR = 13 | |||||
| BASE_CONVERTER_FAIL = 000 | BASE_CONVERTER_FAIL = 000 | ||||
| GRAPH_INIT_FAIL = 100 | GRAPH_INIT_FAIL = 100 | ||||
| TREE_CREATE_FAIL = 200 | TREE_CREATE_FAIL = 200 | ||||
| SOURCE_FILES_SAVE_FAIL = 300 | SOURCE_FILES_SAVE_FAIL = 300 | ||||
| GENERATOR_FAIL = 400 | GENERATOR_FAIL = 400 | ||||
| SUB_GRAPH_SEARCHING_FAIL = 500 | |||||
| MODEL_LOADING_FAIL = 600 | |||||
| class ScriptNotSupport(MindInsightException): | class ScriptNotSupport(MindInsightException): | ||||
| @@ -80,7 +85,6 @@ class MindConverterException(Exception): | |||||
| def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
| """Initialization of MindInsightException.""" | """Initialization of MindInsightException.""" | ||||
| error = kwargs.get('error', None) | error = kwargs.get('error', None) | ||||
| user_msg = kwargs.get('user_msg', '') | user_msg = kwargs.get('user_msg', '') | ||||
| debug_msg = kwargs.get('debug_msg', '') | debug_msg = kwargs.get('debug_msg', '') | ||||
| @@ -97,6 +101,9 @@ class MindConverterException(Exception): | |||||
| def __str__(self): | def __str__(self): | ||||
| return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code(), self.user_msg) | return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code(), self.user_msg) | ||||
| def __repr__(self): | |||||
| return self.__str__() | |||||
| def error_code(self): | def error_code(self): | ||||
| """" | """" | ||||
| Calculate error code. | Calculate error code. | ||||
| @@ -109,54 +116,59 @@ class MindConverterException(Exception): | |||||
| Returns: | Returns: | ||||
| str, Hex string representing the composed MindConverter error code. | str, Hex string representing the composed MindConverter error code. | ||||
| """ | """ | ||||
| num = 0xFFFF & self.error.value | num = 0xFFFF & self.error.value | ||||
| error_code = ''.join((f'{self.cls_code}'.zfill(3), hex(num)[2:].zfill(4).upper())) | error_code = ''.join((f'{self.cls_code}'.zfill(3), hex(num)[2:].zfill(4).upper())) | ||||
| return error_code | return error_code | ||||
| @staticmethod | |||||
| def raise_from(): | |||||
| @classmethod | |||||
| @abc.abstractmethod | |||||
| def raise_from(cls): | |||||
| """Raise from below exceptions.""" | """Raise from below exceptions.""" | ||||
| return None | |||||
| @classmethod | @classmethod | ||||
| def check_except_with_print_pytorch(cls, msg): | |||||
| """Check except in pytorch.""" | |||||
| def uniform_catcher(cls, msg): | |||||
| """Uniform exception catcher.""" | |||||
| def decorator(func): | def decorator(func): | ||||
| def _f(graph_path, sample_shape, output_folder, report_folder): | |||||
| def _f(*args, **kwargs): | |||||
| try: | try: | ||||
| func(graph_path=graph_path, sample_shape=sample_shape, | |||||
| output_folder=output_folder, report_folder=report_folder) | |||||
| res = func(*args, **kwargs) | |||||
| except cls.raise_from() as e: | except cls.raise_from() as e: | ||||
| error = cls(msg=msg) | error = cls(msg=msg) | ||||
| detail_info = f"Error detail: {str(e)}" | detail_info = f"Error detail: {str(e)}" | ||||
| log_console.error(str(error)) | log_console.error(str(error)) | ||||
| log_console.error(detail_info) | log_console.error(detail_info) | ||||
| log.exception(e) | log.exception(e) | ||||
| sys.exit(-1) | |||||
| sys.exit(0) | |||||
| except ModuleNotFoundError as e: | |||||
| detail_info = f"Error detail: Required package not found, please check the runtime environment." | |||||
| log_console.error(str(e)) | |||||
| log_console.error(detail_info) | |||||
| log.exception(e) | |||||
| sys.exit(0) | |||||
| return res | |||||
| return _f | return _f | ||||
| return decorator | return decorator | ||||
| @classmethod | @classmethod | ||||
| def check_except_with_print_tf(cls, msg): | |||||
| """Check except in tf.""" | |||||
| def check_except(cls, msg): | |||||
| """Check except.""" | |||||
| def decorator(func): | def decorator(func): | ||||
| def _f(graph_path, sample_shape, | |||||
| input_nodes, output_nodes, | |||||
| output_folder, report_folder): | |||||
| def _f(*args, **kwargs): | |||||
| try: | try: | ||||
| func(graph_path=graph_path, sample_shape=sample_shape, | |||||
| input_nodes=input_nodes, output_nodes=output_nodes, | |||||
| output_folder=output_folder, report_folder=report_folder) | |||||
| output = func(*args, **kwargs) | |||||
| except cls.raise_from() as e: | except cls.raise_from() as e: | ||||
| error = cls(msg=msg) | |||||
| detail_info = f"Error detail: {str(e)}" | |||||
| log_console.error(str(error)) | |||||
| log_console.error(detail_info) | |||||
| log.error(msg) | |||||
| log.exception(e) | |||||
| raise cls(msg=msg) | |||||
| except Exception as e: | |||||
| log.error(msg) | |||||
| log.exception(e) | log.exception(e) | ||||
| sys.exit(-1) | |||||
| raise e | |||||
| return output | |||||
| return _f | return _f | ||||
| @@ -170,31 +182,12 @@ class BaseConverterFail(MindConverterException): | |||||
| super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL, | super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL, | ||||
| user_msg=msg) | user_msg=msg) | ||||
| @staticmethod | |||||
| def raise_from(): | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| """Raise from exceptions below.""" | """Raise from exceptions below.""" | ||||
| except_source = (UnknownModel, | |||||
| ParamMissError) | |||||
| except_source = Exception, cls | |||||
| return except_source | return except_source | ||||
| @classmethod | |||||
| def check_except(cls, msg): | |||||
| """Check except.""" | |||||
| def decorator(func): | |||||
| def _f(file_config): | |||||
| try: | |||||
| func(file_config=file_config) | |||||
| except cls.raise_from() as e: | |||||
| error = cls(msg=msg) | |||||
| detail_info = f"Error detail: {str(e)}" | |||||
| log_console.error(str(error)) | |||||
| log_console.error(detail_info) | |||||
| log.exception(e) | |||||
| sys.exit(-1) | |||||
| return _f | |||||
| return decorator | |||||
| class UnknownModel(MindConverterException): | class UnknownModel(MindConverterException): | ||||
| """The unknown model error.""" | """The unknown model error.""" | ||||
| @@ -203,6 +196,10 @@ class UnknownModel(MindConverterException): | |||||
| super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL, | super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL, | ||||
| user_msg=msg) | user_msg=msg) | ||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| return cls | |||||
| class GraphInitFail(MindConverterException): | class GraphInitFail(MindConverterException): | ||||
| """The graph init fail error.""" | """The graph init fail error.""" | ||||
| @@ -211,27 +208,19 @@ class GraphInitFail(MindConverterException): | |||||
| super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL, | super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL, | ||||
| user_msg=kwargs.get('msg', '')) | user_msg=kwargs.get('msg', '')) | ||||
| @staticmethod | |||||
| def raise_from(): | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| """Raise from exceptions below.""" | """Raise from exceptions below.""" | ||||
| except_source = (FileNotFoundError, | except_source = (FileNotFoundError, | ||||
| ModuleNotFoundError, | ModuleNotFoundError, | ||||
| ModelNotSupport, | ModelNotSupport, | ||||
| SubGraphSearchingFail, | |||||
| TypeError, | TypeError, | ||||
| ZeroDivisionError, | ZeroDivisionError, | ||||
| RuntimeError) | |||||
| RuntimeError, | |||||
| cls) | |||||
| return except_source | return except_source | ||||
| @classmethod | |||||
| def check_except_pytorch(cls, msg): | |||||
| """Check except for pytorch.""" | |||||
| return super().check_except_with_print_pytorch(msg) | |||||
| @classmethod | |||||
| def check_except_tf(cls, msg): | |||||
| """Check except for tf.""" | |||||
| return super().check_except_with_print_tf(msg) | |||||
| class TreeCreateFail(MindConverterException): | class TreeCreateFail(MindConverterException): | ||||
| """The tree create fail.""" | """The tree create fail.""" | ||||
| @@ -240,23 +229,13 @@ class TreeCreateFail(MindConverterException): | |||||
| super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL, | super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL, | ||||
| user_msg=msg) | user_msg=msg) | ||||
| @staticmethod | |||||
| def raise_from(): | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| """Raise from exceptions below.""" | """Raise from exceptions below.""" | ||||
| except_source = (NodeInputMissing, | except_source = (NodeInputMissing, | ||||
| TreeNodeInsertFail) | |||||
| TreeNodeInsertFail, cls) | |||||
| return except_source | return except_source | ||||
| @classmethod | |||||
| def check_except_pytorch(cls, msg): | |||||
| """Check except.""" | |||||
| return super().check_except_with_print_pytorch(msg) | |||||
| @classmethod | |||||
| def check_except_tf(cls, msg): | |||||
| """Check except for tf.""" | |||||
| return super().check_except_with_print_tf(msg) | |||||
| class SourceFilesSaveFail(MindConverterException): | class SourceFilesSaveFail(MindConverterException): | ||||
| """The source files save fail error.""" | """The source files save fail error.""" | ||||
| @@ -265,25 +244,15 @@ class SourceFilesSaveFail(MindConverterException): | |||||
| super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL, | super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL, | ||||
| user_msg=msg) | user_msg=msg) | ||||
| @staticmethod | |||||
| def raise_from(): | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| """Raise from exceptions below.""" | """Raise from exceptions below.""" | ||||
| except_source = (NodeInputTypeNotSupport, | except_source = (NodeInputTypeNotSupport, | ||||
| ScriptGenerateFail, | ScriptGenerateFail, | ||||
| ReportGenerateFail, | ReportGenerateFail, | ||||
| IOError) | |||||
| IOError, cls) | |||||
| return except_source | return except_source | ||||
| @classmethod | |||||
| def check_except_pytorch(cls, msg): | |||||
| """Check except.""" | |||||
| return super().check_except_with_print_pytorch(msg) | |||||
| @classmethod | |||||
| def check_except_tf(cls, msg): | |||||
| """Check except for tf.""" | |||||
| return super().check_except_with_print_tf(msg) | |||||
| class ModelNotSupport(MindConverterException): | class ModelNotSupport(MindConverterException): | ||||
| """The model not support error.""" | """The model not support error.""" | ||||
| @@ -293,55 +262,32 @@ class ModelNotSupport(MindConverterException): | |||||
| user_msg=msg, | user_msg=msg, | ||||
| cls_code=ConverterErrors.GRAPH_INIT_FAIL.value) | cls_code=ConverterErrors.GRAPH_INIT_FAIL.value) | ||||
| @staticmethod | |||||
| def raise_from(): | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| """Raise from exceptions below.""" | """Raise from exceptions below.""" | ||||
| except_source = (RuntimeError, | except_source = (RuntimeError, | ||||
| ModuleNotFoundError, | |||||
| ValueError, | ValueError, | ||||
| AssertionError, | |||||
| TypeError, | TypeError, | ||||
| OSError, | OSError, | ||||
| ZeroDivisionError) | |||||
| ZeroDivisionError, cls) | |||||
| return except_source | return except_source | ||||
| @classmethod | |||||
| def check_except_pytorch(cls, msg): | |||||
| """Check except.""" | |||||
| def decorator(func): | |||||
| def _f(arch, model_path, **kwargs): | |||||
| try: | |||||
| output = func(arch, model_path=model_path, **kwargs) | |||||
| except cls.raise_from() as e: | |||||
| error = cls(msg=msg) | |||||
| log.error(msg) | |||||
| log.exception(e) | |||||
| raise error from e | |||||
| return output | |||||
| return _f | |||||
| return decorator | |||||
| class TfRuntimeError(MindConverterException): | |||||
| """Catch tf runtime error.""" | |||||
| def __init__(self, msg): | |||||
| super(TfRuntimeError, self).__init__(error=ConverterErrors.TF_RUNTIME_ERROR, | |||||
| user_msg=msg, | |||||
| cls_code=ConverterErrors.GRAPH_INIT_FAIL.value) | |||||
| @classmethod | @classmethod | ||||
| def check_except_tf(cls, msg): | |||||
| """Check except.""" | |||||
| def raise_from(cls): | |||||
| tf_error_module = import_module('tensorflow.python.framework.errors_impl') | tf_error_module = import_module('tensorflow.python.framework.errors_impl') | ||||
| tf_error = getattr(tf_error_module, 'OpError') | tf_error = getattr(tf_error_module, 'OpError') | ||||
| cls._error = cls.raise_from() + (tf_error,) | |||||
| def decorator(func): | |||||
| def _f(arch, model_path, **kwargs): | |||||
| try: | |||||
| output = func(arch, model_path=model_path, **kwargs) | |||||
| except cls._error as e: | |||||
| error = cls(msg=msg) | |||||
| log.error(msg) | |||||
| log.exception(e) | |||||
| raise error from e | |||||
| return output | |||||
| return _f | |||||
| return decorator | |||||
| return tf_error, ValueError, RuntimeError, cls | |||||
| class NodeInputMissing(MindConverterException): | class NodeInputMissing(MindConverterException): | ||||
| @@ -352,6 +298,10 @@ class NodeInputMissing(MindConverterException): | |||||
| user_msg=msg, | user_msg=msg, | ||||
| cls_code=ConverterErrors.TREE_CREATE_FAIL.value) | cls_code=ConverterErrors.TREE_CREATE_FAIL.value) | ||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| return ValueError, IndexError, KeyError, AttributeError, cls | |||||
| class TreeNodeInsertFail(MindConverterException): | class TreeNodeInsertFail(MindConverterException): | ||||
| """The tree node create fail error.""" | """The tree node create fail error.""" | ||||
| @@ -361,32 +311,15 @@ class TreeNodeInsertFail(MindConverterException): | |||||
| user_msg=msg, | user_msg=msg, | ||||
| cls_code=ConverterErrors.TREE_CREATE_FAIL.value) | cls_code=ConverterErrors.TREE_CREATE_FAIL.value) | ||||
| @staticmethod | |||||
| def raise_from(): | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| """Raise from exceptions below.""" | """Raise from exceptions below.""" | ||||
| except_source = (OSError, | except_source = (OSError, | ||||
| DuplicatedNodeIdError, | DuplicatedNodeIdError, | ||||
| MultipleRootError, | MultipleRootError, | ||||
| NodeIDAbsentError) | |||||
| NodeIDAbsentError, cls) | |||||
| return except_source | return except_source | ||||
| @classmethod | |||||
| def check_except(cls, msg): | |||||
| """Check except.""" | |||||
| def decorator(func): | |||||
| def _f(arch, graph): | |||||
| try: | |||||
| output = func(arch, graph=graph) | |||||
| except cls.raise_from() as e: | |||||
| error = cls(msg=msg) | |||||
| log.error(msg) | |||||
| log.exception(e) | |||||
| raise error from e | |||||
| return output | |||||
| return _f | |||||
| return decorator | |||||
| class NodeInputTypeNotSupport(MindConverterException): | class NodeInputTypeNotSupport(MindConverterException): | ||||
| """The node input type NOT support error.""" | """The node input type NOT support error.""" | ||||
| @@ -396,6 +329,10 @@ class NodeInputTypeNotSupport(MindConverterException): | |||||
| user_msg=msg, | user_msg=msg, | ||||
| cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | ||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| return ValueError, TypeError, IndexError, cls | |||||
| class ScriptGenerateFail(MindConverterException): | class ScriptGenerateFail(MindConverterException): | ||||
| """The script generate fail error.""" | """The script generate fail error.""" | ||||
| @@ -405,31 +342,14 @@ class ScriptGenerateFail(MindConverterException): | |||||
| user_msg=msg, | user_msg=msg, | ||||
| cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | ||||
| @staticmethod | |||||
| def raise_from(): | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| """Raise from exceptions below.""" | """Raise from exceptions below.""" | ||||
| except_source = (RuntimeError, | except_source = (RuntimeError, | ||||
| parse.ParseError, | parse.ParseError, | ||||
| AttributeError) | |||||
| AttributeError, cls) | |||||
| return except_source | return except_source | ||||
| @classmethod | |||||
| def check_except(cls, msg): | |||||
| """Check except.""" | |||||
| def decorator(func): | |||||
| def _f(arch, mapper): | |||||
| try: | |||||
| output = func(arch, mapper=mapper) | |||||
| except cls.raise_from() as e: | |||||
| error = cls(msg=msg) | |||||
| log.error(msg) | |||||
| log.exception(e) | |||||
| raise error from e | |||||
| return output | |||||
| return _f | |||||
| return decorator | |||||
| class ReportGenerateFail(MindConverterException): | class ReportGenerateFail(MindConverterException): | ||||
| """The report generate fail error.""" | """The report generate fail error.""" | ||||
| @@ -439,28 +359,24 @@ class ReportGenerateFail(MindConverterException): | |||||
| user_msg=msg, | user_msg=msg, | ||||
| cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) | ||||
| @staticmethod | |||||
| def raise_from(): | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| """Raise from exceptions below.""" | """Raise from exceptions below.""" | ||||
| except_source = ZeroDivisionError | |||||
| return except_source | |||||
| return ZeroDivisionError, cls | |||||
| @classmethod | |||||
| def check_except(cls, msg): | |||||
| """Check except.""" | |||||
| def decorator(func): | |||||
| def _f(arch, mapper): | |||||
| try: | |||||
| output = func(arch, mapper=mapper) | |||||
| except cls.raise_from() as e: | |||||
| error = cls(msg=msg) | |||||
| log.error(msg) | |||||
| log.exception(e) | |||||
| raise error from e | |||||
| return output | |||||
| return _f | |||||
| return decorator | |||||
| class SubGraphSearchingFail(MindConverterException): | |||||
| """Sub-graph searching exception.""" | |||||
| def __init__(self, msg): | |||||
| super(SubGraphSearchingFail, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT, | |||||
| cls_code=ConverterErrors.SUB_GRAPH_SEARCHING_FAIL.value, | |||||
| user_msg=msg) | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| """Define exception in sub-graph searching module.""" | |||||
| return IndexError, KeyError, ValueError, AttributeError, ZeroDivisionError, cls | |||||
| class GeneratorFail(MindConverterException): | class GeneratorFail(MindConverterException): | ||||
| @@ -470,10 +386,23 @@ class GeneratorFail(MindConverterException): | |||||
| super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR, | super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR, | ||||
| user_msg=msg, | user_msg=msg, | ||||
| cls_code=ConverterErrors.GENERATOR_FAIL.value) | cls_code=ConverterErrors.GENERATOR_FAIL.value) | ||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| """Raise from exceptions below.""" | """Raise from exceptions below.""" | ||||
| except_source = (ValueError, | |||||
| TypeError, | |||||
| cls) | |||||
| except_source = (ValueError, TypeError, cls) | |||||
| return except_source | return except_source | ||||
| class ModelLoadingFail(MindConverterException): | |||||
| """Model loading fail.""" | |||||
| def __init__(self, msg): | |||||
| super(ModelLoadingFail, self).__init__(error=ConverterErrors.INPUT_SHAPE_ERROR, | |||||
| cls_code=ConverterErrors.MODEL_LOADING_FAIL.value, | |||||
| user_msg=msg) | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| """Define exception when model loading fail.""" | |||||
| return ValueError, cls | |||||
| @@ -135,9 +135,11 @@ def lib_version_satisfied(current_ver: str, mini_ver_limited: str, | |||||
| if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited): | if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited): | ||||
| return False | return False | ||||
| return True | return True | ||||
| def get_dict_key_by_value(val, dic): | def get_dict_key_by_value(val, dic): | ||||
| """ | """ | ||||
| Return the first appeared key of a dictionay by given value. | |||||
| Return the first appeared key of a dictionary by given value. | |||||
| Args: | Args: | ||||
| val (Any): Value of the key. | val (Any): Value of the key. | ||||
| @@ -16,6 +16,7 @@ | |||||
| import os | import os | ||||
| import re | import re | ||||
| import argparse | import argparse | ||||
| import sys | |||||
| from importlib import import_module | from importlib import import_module | ||||
| from importlib.util import find_spec | from importlib.util import find_spec | ||||
| @@ -25,12 +26,11 @@ from mindinsight.mindconverter.graph_based_converter.common.utils import lib_ver | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ | from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ | ||||
| BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.common.log import logger as log | |||||
| from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | |||||
| from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \ | from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \ | ||||
| BaseConverterFail, UnknownModel | |||||
| BaseConverterFail, UnknownModel, GeneratorFail, TfRuntimeError | |||||
| from mindinsight.utils.exceptions import ParamMissError | from mindinsight.utils.exceptions import ParamMissError | ||||
| permissions = os.R_OK | os.W_OK | os.X_OK | permissions = os.R_OK | os.W_OK | os.X_OK | ||||
| os.umask(permissions << 3 | permissions) | os.umask(permissions << 3 | permissions) | ||||
| @@ -71,8 +71,10 @@ def torch_installation_validation(func): | |||||
| "scripts converter, and PyTorch vision must " | "scripts converter, and PyTorch vision must " | ||||
| "be consisted with model generation runtime.") | "be consisted with model generation runtime.") | ||||
| log.error(str(error)) | log.error(str(error)) | ||||
| log.exception(error) | |||||
| raise error | |||||
| detail_info = f"Error detail: {str(error)}" | |||||
| log_console.error(str(error)) | |||||
| log_console.error(detail_info) | |||||
| sys.exit(0) | |||||
| func(graph_path=graph_path, sample_shape=sample_shape, | func(graph_path=graph_path, sample_shape=sample_shape, | ||||
| output_folder=output_folder, report_folder=report_folder) | output_folder=output_folder, report_folder=report_folder) | ||||
| @@ -103,7 +105,10 @@ def tf_installation_validation(func): | |||||
| f"based scripts converter for TensorFlow conversion." | f"based scripts converter for TensorFlow conversion." | ||||
| ) | ) | ||||
| log.error(str(error)) | log.error(str(error)) | ||||
| raise error | |||||
| detail_info = f"Error detail: {str(error)}" | |||||
| log_console.error(str(error)) | |||||
| log_console.error(detail_info) | |||||
| sys.exit(0) | |||||
| onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx") | onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx") | ||||
| ort = import_module("onnxruntime") | ort = import_module("onnxruntime") | ||||
| @@ -117,7 +122,10 @@ def tf_installation_validation(func): | |||||
| f"based scripts converter for TensorFlow conversion." | f"based scripts converter for TensorFlow conversion." | ||||
| ) | ) | ||||
| log.error(str(error)) | log.error(str(error)) | ||||
| raise error | |||||
| detail_info = f"Error detail: {str(error)}" | |||||
| log_console.error(str(error)) | |||||
| log_console.error(detail_info) | |||||
| sys.exit(0) | |||||
| func(graph_path=graph_path, sample_shape=sample_shape, | func(graph_path=graph_path, sample_shape=sample_shape, | ||||
| output_folder=output_folder, report_folder=report_folder, | output_folder=output_folder, report_folder=report_folder, | ||||
| @@ -142,9 +150,10 @@ def _extract_model_name(model_path): | |||||
| @torch_installation_validation | @torch_installation_validation | ||||
| @GraphInitFail.check_except_pytorch("Error occurred when init graph object.") | |||||
| @TreeCreateFail.check_except_pytorch("Error occurred when create hierarchical tree.") | |||||
| @SourceFilesSaveFail.check_except_pytorch("Error occurred when save source files.") | |||||
| @GraphInitFail.uniform_catcher("Error occurred when init graph object.") | |||||
| @TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.") | |||||
| @SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.") | |||||
| @GeneratorFail.uniform_catcher("Error occurred when generate code.") | |||||
| def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | ||||
| output_folder: str, report_folder: str = None): | output_folder: str, report_folder: str = None): | ||||
| """ | """ | ||||
| @@ -176,9 +185,11 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||||
| @tf_installation_validation | @tf_installation_validation | ||||
| @GraphInitFail.check_except_tf("Error occurred when init graph object.") | |||||
| @TreeCreateFail.check_except_tf("Error occurred when create hierarchical tree.") | |||||
| @SourceFilesSaveFail.check_except_tf("Error occurred when save source files.") | |||||
| @GraphInitFail.uniform_catcher("Error occurred when init graph object.") | |||||
| @TfRuntimeError.uniform_catcher("Error occurred when init graph, TensorFlow runtime error.") | |||||
| @TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.") | |||||
| @SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.") | |||||
| @GeneratorFail.uniform_catcher("Error occurred when generate code.") | |||||
| def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | ||||
| input_nodes: str, output_nodes: str, | input_nodes: str, output_nodes: str, | ||||
| output_folder: str, report_folder: str = None): | output_folder: str, report_folder: str = None): | ||||
| @@ -210,7 +221,7 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||||
| save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) | ||||
| @BaseConverterFail.check_except("Failed to start base converter.") | |||||
| @BaseConverterFail.uniform_catcher("Failed to start base converter.") | |||||
| def main_graph_base_converter(file_config): | def main_graph_base_converter(file_config): | ||||
| """ | """ | ||||
| The entrance for converter, script files will be converted. | The entrance for converter, script files will be converted. | ||||
| @@ -201,6 +201,7 @@ class ArgsTranslation: | |||||
| class ArgsTranslationHelper: | class ArgsTranslationHelper: | ||||
| """Define operations related to ArgsTranslation instances.""" | """Define operations related to ArgsTranslation instances.""" | ||||
| @staticmethod | @staticmethod | ||||
| def find_formal_args_in_modules(args_translators): | def find_formal_args_in_modules(args_translators): | ||||
| """ | """ | ||||
| @@ -541,7 +541,7 @@ class ModuleStruct: | |||||
| for output in output_list: | for output in output_list: | ||||
| (provider_succ, provider_closet_opt_var) = output | (provider_succ, provider_closet_opt_var) = output | ||||
| if provider_closet_opt_var in struct.matched_inputs: | if provider_closet_opt_var in struct.matched_inputs: | ||||
| continue # skip repeat | |||||
| continue # skip repeat | |||||
| if provider_succ == struct.onnx_name: | if provider_succ == struct.onnx_name: | ||||
| struct.matched_inputs.append(provider_closet_opt_var) | struct.matched_inputs.append(provider_closet_opt_var) | ||||
| @@ -695,20 +695,5 @@ class ModuleStruct: | |||||
| """Register submodule outputs to this module's return.""" | """Register submodule outputs to this module's return.""" | ||||
| submodule_returns = md_struct.get_returned_opt_var_name() | submodule_returns = md_struct.get_returned_opt_var_name() | ||||
| submodule_opt_var_name = md_struct.ms_opt_var_name | submodule_opt_var_name = md_struct.ms_opt_var_name | ||||
| for (submodule_ext_succ, opt_var_name_in_this_module, ith_output) in submodule_returns: | |||||
| for (submodule_ext_succ, _, ith_output) in submodule_returns: | |||||
| self.external_successor_local_returns_map[submodule_ext_succ] = (submodule_opt_var_name, ith_output) | self.external_successor_local_returns_map[submodule_ext_succ] = (submodule_opt_var_name, ith_output) | ||||
| # edit external succ 's inputs in parent module | |||||
| ext_node = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(submodule_ext_succ) | |||||
| ext_node_parent = ext_node.parent_module_struct | |||||
| while ext_node_parent != self.parent_module_struct: | |||||
| ext_node_parent.inputs_in_parent_module[ext_node.onnx_name] = md_struct.ms_opt_var_name | |||||
| ext_node_parent = ext_node_parent.parent_module_struct | |||||
| # need find the prec_name? | |||||
| for ext_node_prec, opt_var_name in ext_node.inputs_in_parent_module.copy().items(): | |||||
| if isinstance(opt_var_name, str): | |||||
| if opt_var_name == opt_var_name_in_this_module: | |||||
| ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output) | |||||
| if isinstance(opt_var_name, tuple): | |||||
| if opt_var_name[0] == opt_var_name_in_this_module: | |||||
| ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output) | |||||
| @@ -34,6 +34,7 @@ class Pattern: | |||||
| # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN, | # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN, | ||||
| # the pattern will get additional score. | # the pattern will get additional score. | ||||
| self.additional_score = 0 | self.additional_score = 0 | ||||
| self.know_module_name = None | |||||
| def insert(self, idx, seq_len): | def insert(self, idx, seq_len): | ||||
| """ | """ | ||||
| @@ -18,7 +18,7 @@ import uuid | |||||
| from typing import Dict, List, Callable, Union | from typing import Dict, List, Callable, Union | ||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score | from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score | ||||
| from .known_module_name import BUILT_IN_MODULE_NAME, is_built_in_module_name | |||||
| from .known_module_name import BUILT_IN_MODULE_NAME | |||||
| from .pattern import Pattern, scope_name_mapping | from .pattern import Pattern, scope_name_mapping | ||||
| from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern | from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern | ||||
| from .pattern_fuzzy_matching import pattern_fuzzy_matching | from .pattern_fuzzy_matching import pattern_fuzzy_matching | ||||
| @@ -85,13 +85,15 @@ def _is_valid_pattern(pattern, dag): | |||||
| return True | return True | ||||
| def generate_module_name(pattern): | |||||
| def match_known_module_name(pattern): | |||||
| """ | """ | ||||
| Generate module name. | |||||
| Matching with know module name. | |||||
| Args: | Args: | ||||
| pattern (Pattern): To be replaced pattern. | pattern (Pattern): To be replaced pattern. | ||||
| Returns: | |||||
| str, matched module name, return None if not matched. | |||||
| """ | """ | ||||
| matched_result = [] | matched_result = [] | ||||
| for ptn, module_name in BUILT_IN_MODULE_NAME.items(): | for ptn, module_name in BUILT_IN_MODULE_NAME.items(): | ||||
| @@ -109,7 +111,11 @@ def generate_module_name(pattern): | |||||
| module_name = f"{module_name}{used_module_name[pattern.pattern]}" | module_name = f"{module_name}{used_module_name[pattern.pattern]}" | ||||
| used_module_name[pattern.pattern] += 1 | used_module_name[pattern.pattern] += 1 | ||||
| return module_name | return module_name | ||||
| return None | |||||
| def generate_module_name(): | |||||
| """Generate module name.""" | |||||
| global global_idx | global global_idx | ||||
| name = f"Module{global_idx}" | name = f"Module{global_idx}" | ||||
| global_idx += 1 | global_idx += 1 | ||||
| @@ -439,13 +445,16 @@ class SearchPath: | |||||
| to recover the sequence. | to recover the sequence. | ||||
| """ | """ | ||||
| if self.pattern.pattern not in scope_name_mapping: | if self.pattern.pattern not in scope_name_mapping: | ||||
| module_name = generate_module_name(self.pattern) | |||||
| scope_name_mapping[self.pattern.pattern] = module_name | |||||
| module_name = generate_module_name() | |||||
| known_module_name = match_known_module_name(self.pattern) | |||||
| scope_name_mapping[self.pattern] = module_name | |||||
| module_name_to_src[module_name] = self.pattern.pattern | module_name_to_src[module_name] = self.pattern.pattern | ||||
| else: | else: | ||||
| module_name = scope_name_mapping[self.pattern.pattern] | module_name = scope_name_mapping[self.pattern.pattern] | ||||
| known_module_name = module_name_to_src[module_name].known_module_name | |||||
| self.pattern.module_name = module_name | self.pattern.module_name = module_name | ||||
| if is_built_in_module_name(module_name): | |||||
| self.pattern.known_module_name = known_module_name | |||||
| if known_module_name: | |||||
| self.pattern.additional_score += cal_matching_score(self.pattern.ptn_length) | self.pattern.additional_score += cal_matching_score(self.pattern.ptn_length) | ||||
| topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl) | topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl) | ||||
| return topo_order, inverted_index | return topo_order, inverted_index | ||||
| @@ -18,8 +18,10 @@ from typing import Dict, List | |||||
| from .common import context, DagGraph, gen_hash_key, ACCEPTABLE_RESULT_COUNT | from .common import context, DagGraph, gen_hash_key, ACCEPTABLE_RESULT_COUNT | ||||
| from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE | from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE | ||||
| from ..common.global_context import GlobalContext | |||||
| from ..third_party_graph.onnx_utils import BaseNode | from ..third_party_graph.onnx_utils import BaseNode | ||||
| from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern | from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern | ||||
| from ...common.exceptions import SubGraphSearchingFail | |||||
| def _is_satisfied(path): | def _is_satisfied(path): | ||||
| @@ -249,6 +251,23 @@ def validate_topo_order_succession(): | |||||
| return True | return True | ||||
| def _add_known_module_name(search_path): | |||||
| """ | |||||
| Add known module name to GlobalContext. | |||||
| Args: | |||||
| search_path (SearchPath): Search path. | |||||
| """ | |||||
| ctx = GlobalContext() | |||||
| if search_path.pattern.known_module_name: | |||||
| ctx.known_module_name[search_path.pattern.module_name] = search_path.pattern.known_module_name | |||||
| for it in search_path.recursion_path: | |||||
| if it.pattern.known_module_name: | |||||
| ctx.known_module_name[it.pattern.module_name] = it.pattern.known_module_name | |||||
| @SubGraphSearchingFail.check_except("Sub-Graph searching fail.") | |||||
| def generate_scope_name(data_loader): | def generate_scope_name(data_loader): | ||||
| """ | """ | ||||
| Generate scope name according to computation graph. | Generate scope name according to computation graph. | ||||
| @@ -270,6 +289,9 @@ def generate_scope_name(data_loader): | |||||
| if len(topo_order_with_scope_name_list) != len(data_loader.nodes_dict): | if len(topo_order_with_scope_name_list) != len(data_loader.nodes_dict): | ||||
| topo_order_with_scope_name_list = flatten_graph(init_dag) | topo_order_with_scope_name_list = flatten_graph(init_dag) | ||||
| if result: | |||||
| _add_known_module_name(result) | |||||
| except (ValueError, IndexError, AttributeError, KeyError) as _: | except (ValueError, IndexError, AttributeError, KeyError) as _: | ||||
| topo_order_with_scope_name_list = flatten_graph(init_dag) | topo_order_with_scope_name_list = flatten_graph(init_dag) | ||||
| return topo_order_with_scope_name_list | return topo_order_with_scope_name_list | ||||
| @@ -27,7 +27,7 @@ from ..common.global_context import GlobalContext | |||||
| from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | ||||
| ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | ||||
| from ...common.exceptions import GraphInitFail, ModelNotSupport | |||||
| from ...common.exceptions import GraphInitFail, ModelNotSupport, ModelLoadingFail | |||||
| def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): | def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): | ||||
| @@ -308,7 +308,7 @@ class OnnxDataLoader: | |||||
| w = int(match.group('w')) | w = int(match.group('w')) | ||||
| c = int(match.group('c')) | c = int(match.group('c')) | ||||
| if [h, w, c] != list(self.graph_input_shape)[1:4]: | if [h, w, c] != list(self.graph_input_shape)[1:4]: | ||||
| raise ValueError(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}") | |||||
| raise ModelLoadingFail(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}") | |||||
| return True | return True | ||||
| return False | return False | ||||
| @@ -25,7 +25,9 @@ class PyTorchGraphParser(GraphParser): | |||||
| """Define pytorch graph parser.""" | """Define pytorch graph parser.""" | ||||
| @classmethod | @classmethod | ||||
| @ModelNotSupport.check_except_pytorch("Error occurs in loading model, make sure model.pth correct.") | |||||
| @ModelNotSupport.check_except( | |||||
| "Error occurs in loading model, please check your model or runtime environment integrity." | |||||
| ) | |||||
| def parse(cls, model_path: str, **kwargs): | def parse(cls, model_path: str, **kwargs): | ||||
| """ | """ | ||||
| Parser pytorch graph. | Parser pytorch graph. | ||||
| @@ -50,11 +52,9 @@ class PyTorchGraphParser(GraphParser): | |||||
| else: | else: | ||||
| model = torch.load(f=model_path, map_location="cpu") | model = torch.load(f=model_path, map_location="cpu") | ||||
| except ModuleNotFoundError: | except ModuleNotFoundError: | ||||
| error_msg = \ | |||||
| "Cannot find model scripts in system path, " \ | |||||
| "set `--project_path` to the path of model scripts folder correctly." | |||||
| error_msg = "Cannot find model scripts in system path, " \ | |||||
| "set `--project_path` to the path of model scripts folder correctly." | |||||
| error = ModuleNotFoundError(error_msg) | error = ModuleNotFoundError(error_msg) | ||||
| log.error(str(error)) | |||||
| raise error from None | |||||
| raise error | |||||
| return model | return model | ||||
| @@ -25,7 +25,9 @@ class TFGraphParser(GraphParser): | |||||
| """Define TF graph parser.""" | """Define TF graph parser.""" | ||||
| @classmethod | @classmethod | ||||
| @ModelNotSupport.check_except_tf("Error occurs in loading model, make sure model.pb correct.") | |||||
| @ModelNotSupport.check_except( | |||||
| "Error occurs in loading model, please check your model or runtime environment integrity." | |||||
| ) | |||||
| def parse(cls, model_path: str, **kwargs): | def parse(cls, model_path: str, **kwargs): | ||||
| """ | """ | ||||
| Parse TF Computational Graph File (.pb) | Parse TF Computational Graph File (.pb) | ||||
| @@ -36,7 +38,6 @@ class TFGraphParser(GraphParser): | |||||
| Returns: | Returns: | ||||
| object, ONNX model. | object, ONNX model. | ||||
| """ | """ | ||||
| onnx_utils = import_module( | onnx_utils = import_module( | ||||
| "mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils") | "mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils") | ||||
| convert_tf_graph_to_onnx = getattr(onnx_utils, "convert_tf_graph_to_onnx") | convert_tf_graph_to_onnx = getattr(onnx_utils, "convert_tf_graph_to_onnx") | ||||
| @@ -50,6 +51,5 @@ class TFGraphParser(GraphParser): | |||||
| model = convert_tf_graph_to_onnx(model_path, | model = convert_tf_graph_to_onnx(model_path, | ||||
| model_inputs=tf_input_nodes, | model_inputs=tf_input_nodes, | ||||
| model_outputs=tf_output_nodes, | |||||
| ) | |||||
| model_outputs=tf_output_nodes) | |||||
| return model | return model | ||||
| @@ -21,13 +21,10 @@ Usage: | |||||
| """ | """ | ||||
| import difflib | import difflib | ||||
| import os | import os | ||||
| import re | |||||
| import sys | import sys | ||||
| import pytest | import pytest | ||||
| from mindinsight.mindconverter.converter import main | from mindinsight.mindconverter.converter import main | ||||
| from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter | |||||
| @pytest.mark.usefixtures('create_output_dir') | @pytest.mark.usefixtures('create_output_dir') | ||||
| @@ -82,35 +79,3 @@ class TestConverter: | |||||
| converted_ratio = 100 - (diff_lines * 100) / (len(expect_source)) | converted_ratio = 100 - (diff_lines * 100) / (len(expect_source)) | ||||
| assert converted_ratio >= 80 | assert converted_ratio >= 80 | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_single | |||||
| def test_main_graph_based_converter(self, output): | |||||
| """Test main graph based converter.""" | |||||
| pytorch_filename = "resnet50.pth" | |||||
| expected_model_filename = "resnet50.py" | |||||
| expected_report_filename = "report_of_resnet50.txt" | |||||
| file_config = { | |||||
| 'model_file': os.path.join(self.pytorch_dir, pytorch_filename), | |||||
| 'shape': (1, 3, 224, 224), | |||||
| 'outfile_dir': output, | |||||
| 'report_dir': output | |||||
| } | |||||
| with pytest.raises(ValueError) as e: | |||||
| main_graph_base_converter(file_config=file_config) | |||||
| assert os.path.isfile(os.path.join(output, expected_model_filename)) | |||||
| assert os.path.isfile(os.path.join(output, expected_report_filename)) | |||||
| with open(os.path.join(output, expected_report_filename)) as converted_r: | |||||
| converted_report = converted_r.readlines() | |||||
| converted_rate = re.findall(r".*(?:Converted Rate: )(.*)[.]", converted_report[-1]) | |||||
| assert converted_rate[0] == '100.00%' | |||||
| exec_msg = e.value.args[0] | |||||
| assert exec_msg == "torch.__spec__ is None" | |||||