Browse Source

Add module rename function and add exception define.

tags/v1.1.0
liuchongming 5 years ago
parent
commit
d58a5bcbbb
13 changed files with 249 additions and 321 deletions
  1. +54
    -51
      mindinsight/mindconverter/cli.py
  2. +114
    -185
      mindinsight/mindconverter/common/exceptions.py
  3. +3
    -1
      mindinsight/mindconverter/graph_based_converter/common/utils.py
  4. +25
    -14
      mindinsight/mindconverter/graph_based_converter/framework.py
  5. +1
    -0
      mindinsight/mindconverter/graph_based_converter/generator/args_translator.py
  6. +2
    -17
      mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  7. +1
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
  8. +15
    -6
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  9. +22
    -0
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
  10. +2
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
  11. +6
    -6
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py
  12. +4
    -4
      mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py
  13. +0
    -35
      tests/st/func/mindconverter/test_converter.py

+ 54
- 51
mindinsight/mindconverter/cli.py View File

@@ -170,6 +170,9 @@ class InFileAction(argparse.Action):
if not os.path.isfile(outfile_dir): if not os.path.isfile(outfile_dir):
parser_in.error(f'{option_string} {outfile_dir} is not a file') parser_in.error(f'{option_string} {outfile_dir} is not a file')


if not os.path.basename(outfile_dir).endswith("py"):
parser_in.error(f'{option_string} {outfile_dir} is not a valid python file')

setattr(namespace, self.dest, outfile_dir) setattr(namespace, self.dest, outfile_dir)




@@ -282,32 +285,32 @@ class NodeAction(argparse.Action):




parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='mindconverter',
description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__),
allow_abbrev=False)
prog='mindconverter',
description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__),
allow_abbrev=False)


parser.add_argument( parser.add_argument(
'--version',
action='version',
version='%(prog)s ({})'.format(mindinsight.__version__))
'--version',
action='version',
version='%(prog)s ({})'.format(mindinsight.__version__))


parser.add_argument( parser.add_argument(
'--in_file',
type=str,
action=InFileAction,
required=False,
default=None,
help="""
'--in_file',
type=str,
action=InFileAction,
required=False,
default=None,
help="""
Specify path for script file to use AST schema to Specify path for script file to use AST schema to
do script conversation. do script conversation.
""") """)


parser.add_argument( parser.add_argument(
'--model_file',
type=str,
action=ModelFileAction,
required=False,
help="""
'--model_file',
type=str,
action=ModelFileAction,
required=False,
help="""
PyTorch .pth or Tensorflow .pb model file path to use graph PyTorch .pth or Tensorflow .pb model file path to use graph
based schema to do script generation. When based schema to do script generation. When
`--in_file` and `--model_file` are both provided, `--in_file` and `--model_file` are both provided,
@@ -315,12 +318,12 @@ parser.add_argument(
""") """)


parser.add_argument( parser.add_argument(
'--shape',
type=str,
action=ShapeAction,
default=None,
required=False,
help="""
'--shape',
type=str,
action=ShapeAction,
default=None,
required=False,
help="""
Optional, expected input tensor shape of Optional, expected input tensor shape of
`--model_file`. It's required when use graph based `--model_file`. It's required when use graph based
schema. schema.
@@ -328,55 +331,55 @@ parser.add_argument(
""") """)


parser.add_argument( parser.add_argument(
'--input_nodes',
type=str,
action=NodeAction,
default=None,
required=False,
help="""
'--input_nodes',
type=str,
action=NodeAction,
default=None,
required=False,
help="""
Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model. Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model.
Usage: --input_nodes input_1:0,input_2:0 Usage: --input_nodes input_1:0,input_2:0
""") """)


parser.add_argument( parser.add_argument(
'--output_nodes',
type=str,
action=NodeAction,
default=None,
required=False,
help="""
'--output_nodes',
type=str,
action=NodeAction,
default=None,
required=False,
help="""
Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model. Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model.
Usage: --output_nodes output_1:0,output_2:0 Usage: --output_nodes output_1:0,output_2:0
""") """)


parser.add_argument( parser.add_argument(
'--output',
type=str,
action=OutputDirAction,
default=os.path.join(os.getcwd(), 'output'),
help="""
'--output',
type=str,
action=OutputDirAction,
default=os.path.join(os.getcwd(), 'output'),
help="""
Optional, specify path for converted script file Optional, specify path for converted script file
directory. Default output directory is `output` folder directory. Default output directory is `output` folder
in the current working directory. in the current working directory.
""") """)


parser.add_argument( parser.add_argument(
'--report',
type=str,
action=LogFileAction,
default=None,
help="""
'--report',
type=str,
action=LogFileAction,
default=None,
help="""
Optional, specify report directory. Default is Optional, specify report directory. Default is
converted script directory. converted script directory.
""") """)


parser.add_argument( parser.add_argument(
'--project_path',
type=str,
action=ProjectPathAction,
required=False,
default=None,
help="""
'--project_path',
type=str,
action=ProjectPathAction,
required=False,
default=None,
help="""
Optional, PyTorch scripts project path. If PyTorch Optional, PyTorch scripts project path. If PyTorch
project is not in PYTHONPATH, please assign project is not in PYTHONPATH, please assign
`--project_path` when use graph based schema. `--project_path` when use graph based schema.


+ 114
- 185
mindinsight/mindconverter/common/exceptions.py View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Define custom exception.""" """Define custom exception."""
import abc
import sys import sys
from enum import unique from enum import unique
from importlib import import_module from importlib import import_module
@@ -23,7 +24,7 @@ from treelib.exceptions import DuplicatedNodeIdError, MultipleRootError, NodeIDA
from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console


from mindinsight.utils.constant import ScriptConverterErrors from mindinsight.utils.constant import ScriptConverterErrors
from mindinsight.utils.exceptions import MindInsightException, ParamMissError
from mindinsight.utils.exceptions import MindInsightException




@unique @unique
@@ -40,12 +41,16 @@ class ConverterErrors(ScriptConverterErrors):
SCRIPT_GENERATE_FAIL = 9 SCRIPT_GENERATE_FAIL = 9
REPORT_GENERATE_FAIL = 10 REPORT_GENERATE_FAIL = 10
NODE_CONVERSION_ERROR = 11 NODE_CONVERSION_ERROR = 11
INPUT_SHAPE_ERROR = 12
TF_RUNTIME_ERROR = 13


BASE_CONVERTER_FAIL = 000 BASE_CONVERTER_FAIL = 000
GRAPH_INIT_FAIL = 100 GRAPH_INIT_FAIL = 100
TREE_CREATE_FAIL = 200 TREE_CREATE_FAIL = 200
SOURCE_FILES_SAVE_FAIL = 300 SOURCE_FILES_SAVE_FAIL = 300
GENERATOR_FAIL = 400 GENERATOR_FAIL = 400
SUB_GRAPH_SEARCHING_FAIL = 500
MODEL_LOADING_FAIL = 600




class ScriptNotSupport(MindInsightException): class ScriptNotSupport(MindInsightException):
@@ -80,7 +85,6 @@ class MindConverterException(Exception):


def __init__(self, **kwargs): def __init__(self, **kwargs):
"""Initialization of MindInsightException.""" """Initialization of MindInsightException."""

error = kwargs.get('error', None) error = kwargs.get('error', None)
user_msg = kwargs.get('user_msg', '') user_msg = kwargs.get('user_msg', '')
debug_msg = kwargs.get('debug_msg', '') debug_msg = kwargs.get('debug_msg', '')
@@ -97,6 +101,9 @@ class MindConverterException(Exception):
def __str__(self): def __str__(self):
return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code(), self.user_msg) return '[{}] code: {}, msg: {}'.format(self.__class__.__name__, self.error_code(), self.user_msg)


def __repr__(self):
return self.__str__()

def error_code(self): def error_code(self):
"""" """"
Calculate error code. Calculate error code.
@@ -109,54 +116,59 @@ class MindConverterException(Exception):
Returns: Returns:
str, Hex string representing the composed MindConverter error code. str, Hex string representing the composed MindConverter error code.
""" """

num = 0xFFFF & self.error.value num = 0xFFFF & self.error.value
error_code = ''.join((f'{self.cls_code}'.zfill(3), hex(num)[2:].zfill(4).upper())) error_code = ''.join((f'{self.cls_code}'.zfill(3), hex(num)[2:].zfill(4).upper()))
return error_code return error_code


@staticmethod
def raise_from():
@classmethod
@abc.abstractmethod
def raise_from(cls):
"""Raise from below exceptions.""" """Raise from below exceptions."""
return None


@classmethod @classmethod
def check_except_with_print_pytorch(cls, msg):
"""Check except in pytorch."""
def uniform_catcher(cls, msg):
"""Uniform exception catcher."""


def decorator(func): def decorator(func):
def _f(graph_path, sample_shape, output_folder, report_folder):
def _f(*args, **kwargs):
try: try:
func(graph_path=graph_path, sample_shape=sample_shape,
output_folder=output_folder, report_folder=report_folder)
res = func(*args, **kwargs)
except cls.raise_from() as e: except cls.raise_from() as e:
error = cls(msg=msg) error = cls(msg=msg)
detail_info = f"Error detail: {str(e)}" detail_info = f"Error detail: {str(e)}"
log_console.error(str(error)) log_console.error(str(error))
log_console.error(detail_info) log_console.error(detail_info)
log.exception(e) log.exception(e)
sys.exit(-1)
sys.exit(0)
except ModuleNotFoundError as e:
detail_info = f"Error detail: Required package not found, please check the runtime environment."
log_console.error(str(e))
log_console.error(detail_info)
log.exception(e)
sys.exit(0)
return res

return _f return _f

return decorator return decorator


@classmethod @classmethod
def check_except_with_print_tf(cls, msg):
"""Check except in tf."""
def check_except(cls, msg):
"""Check except."""


def decorator(func): def decorator(func):
def _f(graph_path, sample_shape,
input_nodes, output_nodes,
output_folder, report_folder):
def _f(*args, **kwargs):
try: try:
func(graph_path=graph_path, sample_shape=sample_shape,
input_nodes=input_nodes, output_nodes=output_nodes,
output_folder=output_folder, report_folder=report_folder)
output = func(*args, **kwargs)
except cls.raise_from() as e: except cls.raise_from() as e:
error = cls(msg=msg)
detail_info = f"Error detail: {str(e)}"
log_console.error(str(error))
log_console.error(detail_info)
log.error(msg)
log.exception(e)
raise cls(msg=msg)
except Exception as e:
log.error(msg)
log.exception(e) log.exception(e)
sys.exit(-1)
raise e
return output


return _f return _f


@@ -170,31 +182,12 @@ class BaseConverterFail(MindConverterException):
super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL, super(BaseConverterFail, self).__init__(error=ConverterErrors.BASE_CONVERTER_FAIL,
user_msg=msg) user_msg=msg)


@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below.""" """Raise from exceptions below."""
except_source = (UnknownModel,
ParamMissError)
except_source = Exception, cls
return except_source return except_source


@classmethod
def check_except(cls, msg):
"""Check except."""

def decorator(func):
def _f(file_config):
try:
func(file_config=file_config)
except cls.raise_from() as e:
error = cls(msg=msg)
detail_info = f"Error detail: {str(e)}"
log_console.error(str(error))
log_console.error(detail_info)
log.exception(e)
sys.exit(-1)
return _f
return decorator



class UnknownModel(MindConverterException): class UnknownModel(MindConverterException):
"""The unknown model error.""" """The unknown model error."""
@@ -203,6 +196,10 @@ class UnknownModel(MindConverterException):
super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL, super(UnknownModel, self).__init__(error=ConverterErrors.UNKNOWN_MODEL,
user_msg=msg) user_msg=msg)


@classmethod
def raise_from(cls):
return cls



class GraphInitFail(MindConverterException): class GraphInitFail(MindConverterException):
"""The graph init fail error.""" """The graph init fail error."""
@@ -211,27 +208,19 @@ class GraphInitFail(MindConverterException):
super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL, super(GraphInitFail, self).__init__(error=ConverterErrors.GRAPH_INIT_FAIL,
user_msg=kwargs.get('msg', '')) user_msg=kwargs.get('msg', ''))


@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below.""" """Raise from exceptions below."""
except_source = (FileNotFoundError, except_source = (FileNotFoundError,
ModuleNotFoundError, ModuleNotFoundError,
ModelNotSupport, ModelNotSupport,
SubGraphSearchingFail,
TypeError, TypeError,
ZeroDivisionError, ZeroDivisionError,
RuntimeError)
RuntimeError,
cls)
return except_source return except_source


@classmethod
def check_except_pytorch(cls, msg):
"""Check except for pytorch."""
return super().check_except_with_print_pytorch(msg)

@classmethod
def check_except_tf(cls, msg):
"""Check except for tf."""
return super().check_except_with_print_tf(msg)



class TreeCreateFail(MindConverterException): class TreeCreateFail(MindConverterException):
"""The tree create fail.""" """The tree create fail."""
@@ -240,23 +229,13 @@ class TreeCreateFail(MindConverterException):
super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL, super(TreeCreateFail, self).__init__(error=ConverterErrors.TREE_CREATE_FAIL,
user_msg=msg) user_msg=msg)


@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below.""" """Raise from exceptions below."""
except_source = (NodeInputMissing, except_source = (NodeInputMissing,
TreeNodeInsertFail)
TreeNodeInsertFail, cls)
return except_source return except_source


@classmethod
def check_except_pytorch(cls, msg):
"""Check except."""
return super().check_except_with_print_pytorch(msg)

@classmethod
def check_except_tf(cls, msg):
"""Check except for tf."""
return super().check_except_with_print_tf(msg)



class SourceFilesSaveFail(MindConverterException): class SourceFilesSaveFail(MindConverterException):
"""The source files save fail error.""" """The source files save fail error."""
@@ -265,25 +244,15 @@ class SourceFilesSaveFail(MindConverterException):
super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL, super(SourceFilesSaveFail, self).__init__(error=ConverterErrors.SOURCE_FILES_SAVE_FAIL,
user_msg=msg) user_msg=msg)


@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below.""" """Raise from exceptions below."""
except_source = (NodeInputTypeNotSupport, except_source = (NodeInputTypeNotSupport,
ScriptGenerateFail, ScriptGenerateFail,
ReportGenerateFail, ReportGenerateFail,
IOError)
IOError, cls)
return except_source return except_source


@classmethod
def check_except_pytorch(cls, msg):
"""Check except."""
return super().check_except_with_print_pytorch(msg)

@classmethod
def check_except_tf(cls, msg):
"""Check except for tf."""
return super().check_except_with_print_tf(msg)



class ModelNotSupport(MindConverterException): class ModelNotSupport(MindConverterException):
"""The model not support error.""" """The model not support error."""
@@ -293,55 +262,32 @@ class ModelNotSupport(MindConverterException):
user_msg=msg, user_msg=msg,
cls_code=ConverterErrors.GRAPH_INIT_FAIL.value) cls_code=ConverterErrors.GRAPH_INIT_FAIL.value)


@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below.""" """Raise from exceptions below."""
except_source = (RuntimeError, except_source = (RuntimeError,
ModuleNotFoundError,
ValueError, ValueError,
AssertionError,
TypeError, TypeError,
OSError, OSError,
ZeroDivisionError)
ZeroDivisionError, cls)
return except_source return except_source


@classmethod
def check_except_pytorch(cls, msg):
"""Check except."""


def decorator(func):
def _f(arch, model_path, **kwargs):
try:
output = func(arch, model_path=model_path, **kwargs)
except cls.raise_from() as e:
error = cls(msg=msg)
log.error(msg)
log.exception(e)
raise error from e
return output
return _f
return decorator
class TfRuntimeError(MindConverterException):
"""Catch tf runtime error."""

def __init__(self, msg):
super(TfRuntimeError, self).__init__(error=ConverterErrors.TF_RUNTIME_ERROR,
user_msg=msg,
cls_code=ConverterErrors.GRAPH_INIT_FAIL.value)


@classmethod @classmethod
def check_except_tf(cls, msg):
"""Check except."""
def raise_from(cls):
tf_error_module = import_module('tensorflow.python.framework.errors_impl') tf_error_module = import_module('tensorflow.python.framework.errors_impl')
tf_error = getattr(tf_error_module, 'OpError') tf_error = getattr(tf_error_module, 'OpError')

cls._error = cls.raise_from() + (tf_error,)

def decorator(func):
def _f(arch, model_path, **kwargs):
try:
output = func(arch, model_path=model_path, **kwargs)
except cls._error as e:
error = cls(msg=msg)
log.error(msg)
log.exception(e)
raise error from e
return output

return _f

return decorator
return tf_error, ValueError, RuntimeError, cls




class NodeInputMissing(MindConverterException): class NodeInputMissing(MindConverterException):
@@ -352,6 +298,10 @@ class NodeInputMissing(MindConverterException):
user_msg=msg, user_msg=msg,
cls_code=ConverterErrors.TREE_CREATE_FAIL.value) cls_code=ConverterErrors.TREE_CREATE_FAIL.value)


@classmethod
def raise_from(cls):
return ValueError, IndexError, KeyError, AttributeError, cls



class TreeNodeInsertFail(MindConverterException): class TreeNodeInsertFail(MindConverterException):
"""The tree node create fail error.""" """The tree node create fail error."""
@@ -361,32 +311,15 @@ class TreeNodeInsertFail(MindConverterException):
user_msg=msg, user_msg=msg,
cls_code=ConverterErrors.TREE_CREATE_FAIL.value) cls_code=ConverterErrors.TREE_CREATE_FAIL.value)


@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below.""" """Raise from exceptions below."""
except_source = (OSError, except_source = (OSError,
DuplicatedNodeIdError, DuplicatedNodeIdError,
MultipleRootError, MultipleRootError,
NodeIDAbsentError)
NodeIDAbsentError, cls)
return except_source return except_source


@classmethod
def check_except(cls, msg):
"""Check except."""

def decorator(func):
def _f(arch, graph):
try:
output = func(arch, graph=graph)
except cls.raise_from() as e:
error = cls(msg=msg)
log.error(msg)
log.exception(e)
raise error from e
return output
return _f
return decorator



class NodeInputTypeNotSupport(MindConverterException): class NodeInputTypeNotSupport(MindConverterException):
"""The node input type NOT support error.""" """The node input type NOT support error."""
@@ -396,6 +329,10 @@ class NodeInputTypeNotSupport(MindConverterException):
user_msg=msg, user_msg=msg,
cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)


@classmethod
def raise_from(cls):
return ValueError, TypeError, IndexError, cls



class ScriptGenerateFail(MindConverterException): class ScriptGenerateFail(MindConverterException):
"""The script generate fail error.""" """The script generate fail error."""
@@ -405,31 +342,14 @@ class ScriptGenerateFail(MindConverterException):
user_msg=msg, user_msg=msg,
cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)


@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below.""" """Raise from exceptions below."""
except_source = (RuntimeError, except_source = (RuntimeError,
parse.ParseError, parse.ParseError,
AttributeError)
AttributeError, cls)
return except_source return except_source


@classmethod
def check_except(cls, msg):
"""Check except."""

def decorator(func):
def _f(arch, mapper):
try:
output = func(arch, mapper=mapper)
except cls.raise_from() as e:
error = cls(msg=msg)
log.error(msg)
log.exception(e)
raise error from e
return output
return _f
return decorator



class ReportGenerateFail(MindConverterException): class ReportGenerateFail(MindConverterException):
"""The report generate fail error.""" """The report generate fail error."""
@@ -439,28 +359,24 @@ class ReportGenerateFail(MindConverterException):
user_msg=msg, user_msg=msg,
cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value) cls_code=ConverterErrors.SOURCE_FILES_SAVE_FAIL.value)


@staticmethod
def raise_from():
@classmethod
def raise_from(cls):
"""Raise from exceptions below.""" """Raise from exceptions below."""
except_source = ZeroDivisionError
return except_source
return ZeroDivisionError, cls


@classmethod
def check_except(cls, msg):
"""Check except."""


def decorator(func):
def _f(arch, mapper):
try:
output = func(arch, mapper=mapper)
except cls.raise_from() as e:
error = cls(msg=msg)
log.error(msg)
log.exception(e)
raise error from e
return output
return _f
return decorator
class SubGraphSearchingFail(MindConverterException):
"""Sub-graph searching exception."""
def __init__(self, msg):
super(SubGraphSearchingFail, self).__init__(error=ConverterErrors.MODEL_NOT_SUPPORT,
cls_code=ConverterErrors.SUB_GRAPH_SEARCHING_FAIL.value,
user_msg=msg)
@classmethod
def raise_from(cls):
"""Define exception in sub-graph searching module."""
return IndexError, KeyError, ValueError, AttributeError, ZeroDivisionError, cls




class GeneratorFail(MindConverterException): class GeneratorFail(MindConverterException):
@@ -470,10 +386,23 @@ class GeneratorFail(MindConverterException):
super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR, super(GeneratorFail, self).__init__(error=ConverterErrors.NODE_CONVERSION_ERROR,
user_msg=msg, user_msg=msg,
cls_code=ConverterErrors.GENERATOR_FAIL.value) cls_code=ConverterErrors.GENERATOR_FAIL.value)

@classmethod @classmethod
def raise_from(cls): def raise_from(cls):
"""Raise from exceptions below.""" """Raise from exceptions below."""
except_source = (ValueError,
TypeError,
cls)
except_source = (ValueError, TypeError, cls)
return except_source return except_source


class ModelLoadingFail(MindConverterException):
"""Model loading fail."""

def __init__(self, msg):
super(ModelLoadingFail, self).__init__(error=ConverterErrors.INPUT_SHAPE_ERROR,
cls_code=ConverterErrors.MODEL_LOADING_FAIL.value,
user_msg=msg)

@classmethod
def raise_from(cls):
"""Define exception when model loading fail."""
return ValueError, cls

+ 3
- 1
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -135,9 +135,11 @@ def lib_version_satisfied(current_ver: str, mini_ver_limited: str,
if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited): if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited):
return False return False
return True return True


def get_dict_key_by_value(val, dic): def get_dict_key_by_value(val, dic):
""" """
Return the first appeared key of a dictionay by given value.
Return the first appeared key of a dictionary by given value.


Args: Args:
val (Any): Value of the key. val (Any): Value of the key.


+ 25
- 14
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -16,6 +16,7 @@
import os import os
import re import re
import argparse import argparse
import sys
from importlib import import_module from importlib import import_module
from importlib.util import find_spec from importlib.util import find_spec


@@ -25,12 +26,11 @@ from mindinsight.mindconverter.graph_based_converter.common.utils import lib_ver
from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \
BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \ from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \
BaseConverterFail, UnknownModel
BaseConverterFail, UnknownModel, GeneratorFail, TfRuntimeError
from mindinsight.utils.exceptions import ParamMissError from mindinsight.utils.exceptions import ParamMissError



permissions = os.R_OK | os.W_OK | os.X_OK permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions) os.umask(permissions << 3 | permissions)


@@ -71,8 +71,10 @@ def torch_installation_validation(func):
"scripts converter, and PyTorch vision must " "scripts converter, and PyTorch vision must "
"be consisted with model generation runtime.") "be consisted with model generation runtime.")
log.error(str(error)) log.error(str(error))
log.exception(error)
raise error
detail_info = f"Error detail: {str(error)}"
log_console.error(str(error))
log_console.error(detail_info)
sys.exit(0)


func(graph_path=graph_path, sample_shape=sample_shape, func(graph_path=graph_path, sample_shape=sample_shape,
output_folder=output_folder, report_folder=report_folder) output_folder=output_folder, report_folder=report_folder)
@@ -103,7 +105,10 @@ def tf_installation_validation(func):
f"based scripts converter for TensorFlow conversion." f"based scripts converter for TensorFlow conversion."
) )
log.error(str(error)) log.error(str(error))
raise error
detail_info = f"Error detail: {str(error)}"
log_console.error(str(error))
log_console.error(detail_info)
sys.exit(0)


onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx") onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx")
ort = import_module("onnxruntime") ort = import_module("onnxruntime")
@@ -117,7 +122,10 @@ def tf_installation_validation(func):
f"based scripts converter for TensorFlow conversion." f"based scripts converter for TensorFlow conversion."
) )
log.error(str(error)) log.error(str(error))
raise error
detail_info = f"Error detail: {str(error)}"
log_console.error(str(error))
log_console.error(detail_info)
sys.exit(0)


func(graph_path=graph_path, sample_shape=sample_shape, func(graph_path=graph_path, sample_shape=sample_shape,
output_folder=output_folder, report_folder=report_folder, output_folder=output_folder, report_folder=report_folder,
@@ -142,9 +150,10 @@ def _extract_model_name(model_path):




@torch_installation_validation @torch_installation_validation
@GraphInitFail.check_except_pytorch("Error occurred when init graph object.")
@TreeCreateFail.check_except_pytorch("Error occurred when create hierarchical tree.")
@SourceFilesSaveFail.check_except_pytorch("Error occurred when save source files.")
@GraphInitFail.uniform_catcher("Error occurred when init graph object.")
@TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.")
@SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.")
@GeneratorFail.uniform_catcher("Error occurred when generate code.")
def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None): output_folder: str, report_folder: str = None):
""" """
@@ -176,9 +185,11 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,




@tf_installation_validation @tf_installation_validation
@GraphInitFail.check_except_tf("Error occurred when init graph object.")
@TreeCreateFail.check_except_tf("Error occurred when create hierarchical tree.")
@SourceFilesSaveFail.check_except_tf("Error occurred when save source files.")
@GraphInitFail.uniform_catcher("Error occurred when init graph object.")
@TfRuntimeError.uniform_catcher("Error occurred when init graph, TensorFlow runtime error.")
@TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.")
@SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.")
@GeneratorFail.uniform_catcher("Error occurred when generate code.")
def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
input_nodes: str, output_nodes: str, input_nodes: str, output_nodes: str,
output_folder: str, report_folder: str = None): output_folder: str, report_folder: str = None):
@@ -210,7 +221,7 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)




@BaseConverterFail.check_except("Failed to start base converter.")
@BaseConverterFail.uniform_catcher("Failed to start base converter.")
def main_graph_base_converter(file_config): def main_graph_base_converter(file_config):
""" """
The entrance for converter, script files will be converted. The entrance for converter, script files will be converted.


+ 1
- 0
mindinsight/mindconverter/graph_based_converter/generator/args_translator.py View File

@@ -201,6 +201,7 @@ class ArgsTranslation:


class ArgsTranslationHelper: class ArgsTranslationHelper:
"""Define operations related to ArgsTranslation instances.""" """Define operations related to ArgsTranslation instances."""

@staticmethod @staticmethod
def find_formal_args_in_modules(args_translators): def find_formal_args_in_modules(args_translators):
""" """


+ 2
- 17
mindinsight/mindconverter/graph_based_converter/generator/module_struct.py View File

@@ -541,7 +541,7 @@ class ModuleStruct:
for output in output_list: for output in output_list:
(provider_succ, provider_closet_opt_var) = output (provider_succ, provider_closet_opt_var) = output
if provider_closet_opt_var in struct.matched_inputs: if provider_closet_opt_var in struct.matched_inputs:
continue # skip repeat
continue # skip repeat
if provider_succ == struct.onnx_name: if provider_succ == struct.onnx_name:
struct.matched_inputs.append(provider_closet_opt_var) struct.matched_inputs.append(provider_closet_opt_var)


@@ -695,20 +695,5 @@ class ModuleStruct:
"""Register submodule outputs to this module's return.""" """Register submodule outputs to this module's return."""
submodule_returns = md_struct.get_returned_opt_var_name() submodule_returns = md_struct.get_returned_opt_var_name()
submodule_opt_var_name = md_struct.ms_opt_var_name submodule_opt_var_name = md_struct.ms_opt_var_name
for (submodule_ext_succ, opt_var_name_in_this_module, ith_output) in submodule_returns:
for (submodule_ext_succ, _, ith_output) in submodule_returns:
self.external_successor_local_returns_map[submodule_ext_succ] = (submodule_opt_var_name, ith_output) self.external_successor_local_returns_map[submodule_ext_succ] = (submodule_opt_var_name, ith_output)
# edit external succ 's inputs in parent module
ext_node = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(submodule_ext_succ)
ext_node_parent = ext_node.parent_module_struct
while ext_node_parent != self.parent_module_struct:
ext_node_parent.inputs_in_parent_module[ext_node.onnx_name] = md_struct.ms_opt_var_name
ext_node_parent = ext_node_parent.parent_module_struct

# need find the prec_name?
for ext_node_prec, opt_var_name in ext_node.inputs_in_parent_module.copy().items():
if isinstance(opt_var_name, str):
if opt_var_name == opt_var_name_in_this_module:
ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output)
if isinstance(opt_var_name, tuple):
if opt_var_name[0] == opt_var_name_in_this_module:
ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output)

+ 1
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py View File

@@ -34,6 +34,7 @@ class Pattern:
# If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN, # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN,
# the pattern will get additional score. # the pattern will get additional score.
self.additional_score = 0 self.additional_score = 0
self.know_module_name = None


def insert(self, idx, seq_len): def insert(self, idx, seq_len):
""" """


+ 15
- 6
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -18,7 +18,7 @@ import uuid
from typing import Dict, List, Callable, Union from typing import Dict, List, Callable, Union
from collections import OrderedDict from collections import OrderedDict
from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score
from .known_module_name import BUILT_IN_MODULE_NAME, is_built_in_module_name
from .known_module_name import BUILT_IN_MODULE_NAME
from .pattern import Pattern, scope_name_mapping from .pattern import Pattern, scope_name_mapping
from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern
from .pattern_fuzzy_matching import pattern_fuzzy_matching from .pattern_fuzzy_matching import pattern_fuzzy_matching
@@ -85,13 +85,15 @@ def _is_valid_pattern(pattern, dag):
return True return True




def generate_module_name(pattern):
def match_known_module_name(pattern):
""" """
Generate module name.
Matching with know module name.


Args: Args:
pattern (Pattern): To be replaced pattern. pattern (Pattern): To be replaced pattern.


Returns:
str, matched module name, return None if not matched.
""" """
matched_result = [] matched_result = []
for ptn, module_name in BUILT_IN_MODULE_NAME.items(): for ptn, module_name in BUILT_IN_MODULE_NAME.items():
@@ -109,7 +111,11 @@ def generate_module_name(pattern):
module_name = f"{module_name}{used_module_name[pattern.pattern]}" module_name = f"{module_name}{used_module_name[pattern.pattern]}"
used_module_name[pattern.pattern] += 1 used_module_name[pattern.pattern] += 1
return module_name return module_name
return None



def generate_module_name():
"""Generate module name."""
global global_idx global global_idx
name = f"Module{global_idx}" name = f"Module{global_idx}"
global_idx += 1 global_idx += 1
@@ -439,13 +445,16 @@ class SearchPath:
to recover the sequence. to recover the sequence.
""" """
if self.pattern.pattern not in scope_name_mapping: if self.pattern.pattern not in scope_name_mapping:
module_name = generate_module_name(self.pattern)
scope_name_mapping[self.pattern.pattern] = module_name
module_name = generate_module_name()
known_module_name = match_known_module_name(self.pattern)
scope_name_mapping[self.pattern] = module_name
module_name_to_src[module_name] = self.pattern.pattern module_name_to_src[module_name] = self.pattern.pattern
else: else:
module_name = scope_name_mapping[self.pattern.pattern] module_name = scope_name_mapping[self.pattern.pattern]
known_module_name = module_name_to_src[module_name].known_module_name
self.pattern.module_name = module_name self.pattern.module_name = module_name
if is_built_in_module_name(module_name):
self.pattern.known_module_name = known_module_name
if known_module_name:
self.pattern.additional_score += cal_matching_score(self.pattern.ptn_length) self.pattern.additional_score += cal_matching_score(self.pattern.ptn_length)
topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl) topo_order, inverted_index = self.replace_sub_graph_completely(self.pattern, self.topo_order_bef_repl)
return topo_order, inverted_index return topo_order, inverted_index


+ 22
- 0
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -18,8 +18,10 @@ from typing import Dict, List


from .common import context, DagGraph, gen_hash_key, ACCEPTABLE_RESULT_COUNT from .common import context, DagGraph, gen_hash_key, ACCEPTABLE_RESULT_COUNT
from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE
from ..common.global_context import GlobalContext
from ..third_party_graph.onnx_utils import BaseNode from ..third_party_graph.onnx_utils import BaseNode
from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern
from ...common.exceptions import SubGraphSearchingFail




def _is_satisfied(path): def _is_satisfied(path):
@@ -249,6 +251,23 @@ def validate_topo_order_succession():
return True return True




def _add_known_module_name(search_path):
"""
Add known module name to GlobalContext.

Args:
search_path (SearchPath): Search path.

"""
ctx = GlobalContext()
if search_path.pattern.known_module_name:
ctx.known_module_name[search_path.pattern.module_name] = search_path.pattern.known_module_name
for it in search_path.recursion_path:
if it.pattern.known_module_name:
ctx.known_module_name[it.pattern.module_name] = it.pattern.known_module_name


@SubGraphSearchingFail.check_except("Sub-Graph searching fail.")
def generate_scope_name(data_loader): def generate_scope_name(data_loader):
""" """
Generate scope name according to computation graph. Generate scope name according to computation graph.
@@ -270,6 +289,9 @@ def generate_scope_name(data_loader):
if len(topo_order_with_scope_name_list) != len(data_loader.nodes_dict): if len(topo_order_with_scope_name_list) != len(data_loader.nodes_dict):
topo_order_with_scope_name_list = flatten_graph(init_dag) topo_order_with_scope_name_list = flatten_graph(init_dag)


if result:
_add_known_module_name(result)

except (ValueError, IndexError, AttributeError, KeyError) as _: except (ValueError, IndexError, AttributeError, KeyError) as _:
topo_order_with_scope_name_list = flatten_graph(init_dag) topo_order_with_scope_name_list = flatten_graph(init_dag)
return topo_order_with_scope_name_list return topo_order_with_scope_name_list

+ 2
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -27,7 +27,7 @@ from ..common.global_context import GlobalContext


from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL
from ...common.exceptions import GraphInitFail, ModelNotSupport
from ...common.exceptions import GraphInitFail, ModelNotSupport, ModelLoadingFail




def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12):
@@ -308,7 +308,7 @@ class OnnxDataLoader:
w = int(match.group('w')) w = int(match.group('w'))
c = int(match.group('c')) c = int(match.group('c'))
if [h, w, c] != list(self.graph_input_shape)[1:4]: if [h, w, c] != list(self.graph_input_shape)[1:4]:
raise ValueError(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}")
raise ModelLoadingFail(f"Shape given should be (N, {h}, {w}, {c}) but got {self.graph_input_shape}")
return True return True
return False return False




+ 6
- 6
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py View File

@@ -25,7 +25,9 @@ class PyTorchGraphParser(GraphParser):
"""Define pytorch graph parser.""" """Define pytorch graph parser."""


@classmethod @classmethod
@ModelNotSupport.check_except_pytorch("Error occurs in loading model, make sure model.pth correct.")
@ModelNotSupport.check_except(
"Error occurs in loading model, please check your model or runtime environment integrity."
)
def parse(cls, model_path: str, **kwargs): def parse(cls, model_path: str, **kwargs):
""" """
Parser pytorch graph. Parser pytorch graph.
@@ -50,11 +52,9 @@ class PyTorchGraphParser(GraphParser):
else: else:
model = torch.load(f=model_path, map_location="cpu") model = torch.load(f=model_path, map_location="cpu")
except ModuleNotFoundError: except ModuleNotFoundError:
error_msg = \
"Cannot find model scripts in system path, " \
"set `--project_path` to the path of model scripts folder correctly."
error_msg = "Cannot find model scripts in system path, " \
"set `--project_path` to the path of model scripts folder correctly."
error = ModuleNotFoundError(error_msg) error = ModuleNotFoundError(error_msg)
log.error(str(error))
raise error from None
raise error


return model return model

+ 4
- 4
mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py View File

@@ -25,7 +25,9 @@ class TFGraphParser(GraphParser):
"""Define TF graph parser.""" """Define TF graph parser."""


@classmethod @classmethod
@ModelNotSupport.check_except_tf("Error occurs in loading model, make sure model.pb correct.")
@ModelNotSupport.check_except(
"Error occurs in loading model, please check your model or runtime environment integrity."
)
def parse(cls, model_path: str, **kwargs): def parse(cls, model_path: str, **kwargs):
""" """
Parse TF Computational Graph File (.pb) Parse TF Computational Graph File (.pb)
@@ -36,7 +38,6 @@ class TFGraphParser(GraphParser):
Returns: Returns:
object, ONNX model. object, ONNX model.
""" """

onnx_utils = import_module( onnx_utils = import_module(
"mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils") "mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils")
convert_tf_graph_to_onnx = getattr(onnx_utils, "convert_tf_graph_to_onnx") convert_tf_graph_to_onnx = getattr(onnx_utils, "convert_tf_graph_to_onnx")
@@ -50,6 +51,5 @@ class TFGraphParser(GraphParser):


model = convert_tf_graph_to_onnx(model_path, model = convert_tf_graph_to_onnx(model_path,
model_inputs=tf_input_nodes, model_inputs=tf_input_nodes,
model_outputs=tf_output_nodes,
)
model_outputs=tf_output_nodes)
return model return model

+ 0
- 35
tests/st/func/mindconverter/test_converter.py View File

@@ -21,13 +21,10 @@ Usage:
""" """
import difflib import difflib
import os import os
import re
import sys import sys

import pytest import pytest


from mindinsight.mindconverter.converter import main from mindinsight.mindconverter.converter import main
from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter




@pytest.mark.usefixtures('create_output_dir') @pytest.mark.usefixtures('create_output_dir')
@@ -82,35 +79,3 @@ class TestConverter:


converted_ratio = 100 - (diff_lines * 100) / (len(expect_source)) converted_ratio = 100 - (diff_lines * 100) / (len(expect_source))
assert converted_ratio >= 80 assert converted_ratio >= 80

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_main_graph_based_converter(self, output):
"""Test main graph based converter."""
pytorch_filename = "resnet50.pth"
expected_model_filename = "resnet50.py"
expected_report_filename = "report_of_resnet50.txt"
file_config = {
'model_file': os.path.join(self.pytorch_dir, pytorch_filename),
'shape': (1, 3, 224, 224),
'outfile_dir': output,
'report_dir': output
}
with pytest.raises(ValueError) as e:
main_graph_base_converter(file_config=file_config)

assert os.path.isfile(os.path.join(output, expected_model_filename))
assert os.path.isfile(os.path.join(output, expected_report_filename))

with open(os.path.join(output, expected_report_filename)) as converted_r:
converted_report = converted_r.readlines()
converted_rate = re.findall(r".*(?:Converted Rate: )(.*)[.]", converted_report[-1])

assert converted_rate[0] == '100.00%'

exec_msg = e.value.args[0]
assert exec_msg == "torch.__spec__ is None"

Loading…
Cancel
Save