From: @moran3 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -301,6 +301,8 @@ class SourceFilesSaveError(MindConverterException): | |||||
| NODE_INPUT_TYPE_NOT_SUPPORT = 1 | NODE_INPUT_TYPE_NOT_SUPPORT = 1 | ||||
| SCRIPT_GENERATE_FAIL = 2 | SCRIPT_GENERATE_FAIL = 2 | ||||
| REPORT_GENERATE_FAIL = 3 | REPORT_GENERATE_FAIL = 3 | ||||
| CKPT_GENERATE_FAIL = 4 | |||||
| MAP_GENERATE_FAIL = 5 | |||||
| BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value | BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value | ||||
| ERROR_CODE = ErrCode.UNKNOWN_ERROR.value | ERROR_CODE = ErrCode.UNKNOWN_ERROR.value | ||||
| @@ -315,6 +317,8 @@ class SourceFilesSaveError(MindConverterException): | |||||
| except_source = (NodeInputTypeNotSupportError, | except_source = (NodeInputTypeNotSupportError, | ||||
| ScriptGenerationError, | ScriptGenerationError, | ||||
| ReportGenerationError, | ReportGenerationError, | ||||
| CheckPointGenerationError, | |||||
| WeightMapGenerationError, | |||||
| IOError, cls) | IOError, cls) | ||||
| return except_source | return except_source | ||||
| @@ -437,6 +441,32 @@ class ReportGenerationError(SourceFilesSaveError): | |||||
| return ZeroDivisionError, cls | return ZeroDivisionError, cls | ||||
| class CheckPointGenerationError(SourceFilesSaveError): | |||||
| """The checkpoint generate fail error.""" | |||||
| ERROR_CODE = SourceFilesSaveError.ErrCode.CKPT_GENERATE_FAIL.value | |||||
| def __init__(self, msg): | |||||
| super(CheckPointGenerationError, self).__init__(msg=msg) | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| """Raise from exceptions below.""" | |||||
| return cls | |||||
| class WeightMapGenerationError(SourceFilesSaveError): | |||||
| """The weight names map generate fail error.""" | |||||
| ERROR_CODE = SourceFilesSaveError.ErrCode.MAP_GENERATE_FAIL.value | |||||
| def __init__(self, msg): | |||||
| super(WeightMapGenerationError, self).__init__(msg=msg) | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| """Raise from exception below.""" | |||||
| return cls | |||||
| class SubGraphSearchingError(MindConverterException): | class SubGraphSearchingError(MindConverterException): | ||||
| """Sub-graph searching exception.""" | """Sub-graph searching exception.""" | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -15,5 +15,5 @@ | |||||
| """Graph based scripts converter definition.""" | """Graph based scripts converter definition.""" | ||||
| __all__ = ["graph_based_converter_pytorch_to_ms", "graph_based_converter_tf_to_ms"] | __all__ = ["graph_based_converter_pytorch_to_ms", "graph_based_converter_tf_to_ms"] | ||||
| from .framework import graph_based_converter_pytorch_to_ms | |||||
| from .framework import graph_based_converter_tf_to_ms | |||||
| from mindinsight.mindconverter.graph_based_converter.framework import graph_based_converter_pytorch_to_ms | |||||
| from mindinsight.mindconverter.graph_based_converter.framework import graph_based_converter_tf_to_ms | |||||
| @@ -191,18 +191,23 @@ class CodeFragment(Fragment): | |||||
| """ | """ | ||||
| def __init__(self, operation, actual_args, settings, input_shape, output_shape, | def __init__(self, operation, actual_args, settings, input_shape, output_shape, | ||||
| trainable_params=None): | |||||
| trainable_params=None, trainable_weights=None): | |||||
| super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args, | super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args, | ||||
| input_shape=input_shape, output_shape=output_shape, | input_shape=input_shape, output_shape=output_shape, | ||||
| settings=settings) | settings=settings) | ||||
| self._trainable_params = dict() # External weights, like Matmul. | self._trainable_params = dict() # External weights, like Matmul. | ||||
| self._init_trainable_params = trainable_params # Can put into operation init method, like Conv2d. | self._init_trainable_params = trainable_params # Can put into operation init method, like Conv2d. | ||||
| self._trainable_weights = trainable_weights | |||||
| @property | @property | ||||
| def trainable_params(self): | def trainable_params(self): | ||||
| """Return the trainable parameters.""" | """Return the trainable parameters.""" | ||||
| return self._trainable_params | return self._trainable_params | ||||
| @property | |||||
| def trainable_weights(self): | |||||
| return self._trainable_weights | |||||
| class ModuleFragment(Fragment): | class ModuleFragment(Fragment): | ||||
| """Manage module type code variables.""" | """Manage module type code variables.""" | ||||
| @@ -14,7 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Define GlobalContext class to save required resources during whole conversion procedure.""" | """Define GlobalContext class to save required resources during whole conversion procedure.""" | ||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from .outputs import OutputStorage | |||||
| from mindinsight.mindconverter.graph_based_converter.common.outputs import OutputStorage | |||||
| class Singleton(type): | class Singleton(type): | ||||
| @@ -13,16 +13,21 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Define common utils.""" | """Define common utils.""" | ||||
| import json | |||||
| import os | import os | ||||
| import stat | import stat | ||||
| from importlib import import_module | from importlib import import_module | ||||
| from importlib.util import find_spec | |||||
| from typing import List, Tuple, Mapping | from typing import List, Tuple, Mapping | ||||
| from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, UnknownModelError | |||||
| from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, \ | |||||
| UnknownModelError, CheckPointGenerationError, WeightMapGenerationError | |||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, BINARY_HEADER_PYTORCH_BITS, \ | from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, BINARY_HEADER_PYTORCH_BITS, \ | ||||
| FrameworkType, BINARY_HEADER_PYTORCH_FILE, TENSORFLOW_MODEL_SUFFIX | FrameworkType, BINARY_HEADER_PYTORCH_FILE, TENSORFLOW_MODEL_SUFFIX | ||||
| from mindspore.train.serialization import save_checkpoint | |||||
| def is_converted(operation: str): | def is_converted(operation: str): | ||||
| """ | """ | ||||
| @@ -96,7 +101,6 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple], | |||||
| code_lines (dict): Code lines. | code_lines (dict): Code lines. | ||||
| out_folder (str): Output folder. | out_folder (str): Output folder. | ||||
| report_folder (str): Report output folder. | report_folder (str): Report output folder. | ||||
| """ | """ | ||||
| flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | ||||
| modes = stat.S_IRUSR | stat.S_IWUSR | modes = stat.S_IRUSR | stat.S_IWUSR | ||||
| @@ -114,7 +118,7 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple], | |||||
| os.makedirs(report_folder, modes_usr) | os.makedirs(report_folder, modes_usr) | ||||
| for file_name in code_lines: | for file_name in code_lines: | ||||
| code, report = code_lines[file_name] | |||||
| code, report, trainable_weights, weight_map = code_lines[file_name] | |||||
| code_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.py")) | code_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.py")) | ||||
| report_file_path = os.path.realpath(os.path.join(report_folder, f"report_of_{model_name}.txt")) | report_file_path = os.path.realpath(os.path.join(report_folder, f"report_of_{model_name}.txt")) | ||||
| try: | try: | ||||
| @@ -133,6 +137,31 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple], | |||||
| except (IOError, FileExistsError) as error: | except (IOError, FileExistsError) as error: | ||||
| raise ReportGenerationError(str(error)) | raise ReportGenerationError(str(error)) | ||||
| ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt")) | |||||
| try: | |||||
| if os.path.exists(ckpt_file_path): | |||||
| raise CheckPointGenerationError("Checkpoint file with the same name already exists.") | |||||
| save_checkpoint(trainable_weights, ckpt_file_path) | |||||
| except TypeError as error: | |||||
| raise CheckPointGenerationError(str(error)) | |||||
| weight_map_path = os.path.realpath(os.path.join(out_folder, f"weight_map_of_{model_name}.json")) | |||||
| try: | |||||
| if os.path.exists(weight_map_path): | |||||
| raise WeightMapGenerationError("Weight map file with the same name already exists.") | |||||
| with os.fdopen(os.open(weight_map_path, flags, stat.S_IRUSR), 'w') as map_f: | |||||
| weight_map_json = {f"{model_name}": weight_map} | |||||
| json.dump(weight_map_json, map_f) | |||||
| except (IOError, FileExistsError) as error: | |||||
| raise WeightMapGenerationError(str(error)) | |||||
| def onnx_satisfied(): | |||||
| """Validate ONNX , ONNXRUNTIME, ONNXOPTIMIZER installation.""" | |||||
| if not find_spec("onnx") or not find_spec("onnxruntime") or not find_spec("onnxoptimizer"): | |||||
| return False | |||||
| return True | |||||
| def lib_version_satisfied(current_ver: str, mini_ver_limited: str, | def lib_version_satisfied(current_ver: str, mini_ver_limited: str, | ||||
| newest_ver_limited: str = ""): | newest_ver_limited: str = ""): | ||||
| @@ -220,6 +249,7 @@ def reset_init_or_construct(template, variable_slot, new_data, scope): | |||||
| template[variable_slot][scope] += new_data | template[variable_slot][scope] += new_data | ||||
| return template | return template | ||||
| def replace_string_in_list(str_list: list, original_str: str, target_str: str): | def replace_string_in_list(str_list: list, original_str: str, target_str: str): | ||||
| """ | """ | ||||
| Replace a string in a list by provided string. | Replace a string in a list by provided string. | ||||
| @@ -41,6 +41,7 @@ UNKNOWN_DIM_VAL = "unk__001" | |||||
| ONNX_MIN_VER = "1.8.0" | ONNX_MIN_VER = "1.8.0" | ||||
| TF2ONNX_MIN_VER = "1.7.1" | TF2ONNX_MIN_VER = "1.7.1" | ||||
| ONNXRUNTIME_MIN_VER = "1.5.2" | ONNXRUNTIME_MIN_VER = "1.5.2" | ||||
| ONNXOPTIMIZER_MIN_VER = "0.1.2" | |||||
| @unique | @unique | ||||
| @@ -21,10 +21,10 @@ from importlib.util import find_spec | |||||
| import mindinsight | import mindinsight | ||||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | ||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \ | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \ | |||||
| save_code_file_and_report, get_framework_type | save_code_file_and_report, get_framework_type | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ | from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ | ||||
| ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER | |||||
| ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER | |||||
| from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes | from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | ||||
| @@ -53,6 +53,18 @@ parser.add_argument("--report", type=str, required=False, | |||||
| help="Generated reports output folder path.") | help="Generated reports output folder path.") | ||||
| def onnx_lib_version_satisfied(): | |||||
| """Check onnx libs version whether is satisfied.""" | |||||
| onnx = import_module("onnx") | |||||
| ort = import_module("onnxruntime") | |||||
| optimizer = import_module("onnxoptimizer.version") | |||||
| if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ | |||||
| or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \ | |||||
| or not lib_version_satisfied(getattr(optimizer, "version"), ONNXOPTIMIZER_MIN_VER): | |||||
| return False | |||||
| return True | |||||
| def torch_installation_validation(func): | def torch_installation_validation(func): | ||||
| """ | """ | ||||
| Validate args of func. | Validate args of func. | ||||
| @@ -68,26 +80,33 @@ def torch_installation_validation(func): | |||||
| input_nodes: str, output_nodes: str, | input_nodes: str, output_nodes: str, | ||||
| output_folder: str, report_folder: str = None): | output_folder: str, report_folder: str = None): | ||||
| # Check whether pytorch is installed. | # Check whether pytorch is installed. | ||||
| if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"): | |||||
| error = RuntimeIntegrityError(f"PyTorch, onnx(>={ONNX_MIN_VER}) and " | |||||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) " | |||||
| f"are required when using graph based " | |||||
| f"scripts converter, and PyTorch version must " | |||||
| f"be consisted with model generation runtime.") | |||||
| error_info = None | |||||
| if graph_path.endswith('.onnx'): | |||||
| if not onnx_satisfied(): | |||||
| error_info = f"onnx(>={ONNX_MIN_VER}, onnxruntime(>={ONNXRUNTIME_MIN_VER}) and " \ | |||||
| f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " \ | |||||
| f"are required when using graph based scripts converter." | |||||
| else: | |||||
| if not find_spec("torch") or not onnx_satisfied(): | |||||
| error_info = f"PyTorch, onnx(>={ONNX_MIN_VER}), " \ | |||||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and " \ | |||||
| f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " \ | |||||
| f"are required when using graph based " \ | |||||
| f"scripts converter, and PyTorch version must " \ | |||||
| f"be consisted with model generation runtime." | |||||
| if error_info: | |||||
| error = RuntimeIntegrityError(error_info) | |||||
| log.error(error) | log.error(error) | ||||
| log_console.error("\n") | log_console.error("\n") | ||||
| log_console.error(str(error)) | log_console.error(str(error)) | ||||
| log_console.error("\n") | log_console.error("\n") | ||||
| sys.exit(0) | sys.exit(0) | ||||
| onnx = import_module("onnx") | |||||
| ort = import_module("onnxruntime") | |||||
| if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ | |||||
| or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER): | |||||
| if not onnx_lib_version_satisfied(): | |||||
| error = RuntimeIntegrityError( | error = RuntimeIntegrityError( | ||||
| f"onnx(>={ONNX_MIN_VER}) and " | |||||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " | |||||
| f"onnx(>={ONNX_MIN_VER}), " | |||||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and " | |||||
| f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) are required when using graph " | |||||
| f"based scripts converter for Pytorch conversion." | f"based scripts converter for Pytorch conversion." | ||||
| ) | ) | ||||
| log.error(error) | log.error(error) | ||||
| @@ -128,11 +147,11 @@ def tf_installation_validation(func): | |||||
| output_folder: str, report_folder: str = None, | output_folder: str, report_folder: str = None, | ||||
| input_nodes: str = None, output_nodes: str = None): | input_nodes: str = None, output_nodes: str = None): | ||||
| # Check whether tensorflow is installed. | # Check whether tensorflow is installed. | ||||
| if not _check_tf_installation() or not find_spec("tf2onnx") \ | |||||
| or not find_spec("onnx") or not find_spec("onnxruntime"): | |||||
| if not _check_tf_installation() or not onnx_satisfied(): | |||||
| error = RuntimeIntegrityError( | error = RuntimeIntegrityError( | ||||
| f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " | |||||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " | |||||
| f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}), " | |||||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " | |||||
| f"are required when using graph " | |||||
| f"based scripts converter for TensorFlow conversion." | f"based scripts converter for TensorFlow conversion." | ||||
| ) | ) | ||||
| log.error(error) | log.error(error) | ||||
| @@ -141,15 +160,14 @@ def tf_installation_validation(func): | |||||
| log_console.error("\n") | log_console.error("\n") | ||||
| sys.exit(0) | sys.exit(0) | ||||
| onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx") | |||||
| ort = import_module("onnxruntime") | |||||
| tf2onnx = import_module("tf2onnx") | |||||
| if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ | |||||
| or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \ | |||||
| or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER): | |||||
| if not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER) \ | |||||
| or not onnx_lib_version_satisfied(): | |||||
| error = RuntimeIntegrityError( | error = RuntimeIntegrityError( | ||||
| f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " | |||||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " | |||||
| f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}), " | |||||
| f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " | |||||
| f"are required when using graph " | |||||
| f"based scripts converter for TensorFlow conversion." | f"based scripts converter for TensorFlow conversion." | ||||
| ) | ) | ||||
| log.error(error) | log.error(error) | ||||
| @@ -258,12 +276,12 @@ def main_graph_base_converter(file_config): | |||||
| raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") | raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") | ||||
| if frame_type == FrameworkType.PYTORCH.value: | if frame_type == FrameworkType.PYTORCH.value: | ||||
| check_params = ['input_nodes', 'output_nodes'] | |||||
| check_params_exist(check_params, file_config) | |||||
| graph_based_converter_pytorch_to_ms(graph_path=graph_path, | graph_based_converter_pytorch_to_ms(graph_path=graph_path, | ||||
| sample_shape=file_config['shape'], | sample_shape=file_config['shape'], | ||||
| input_nodes=file_config['input_nodes'], | |||||
| output_nodes=file_config['output_nodes'], | |||||
| input_nodes=file_config['input_nodes'] if file_config['input_nodes'] | |||||
| else 'input.1', | |||||
| output_nodes=file_config['output_nodes'] if file_config['output_nodes'] | |||||
| else '', | |||||
| output_folder=file_config['outfile_dir'], | output_folder=file_config['outfile_dir'], | ||||
| report_folder=file_config['report_dir']) | report_folder=file_config['report_dir']) | ||||
| elif frame_type == FrameworkType.TENSORFLOW.value: | elif frame_type == FrameworkType.TENSORFLOW.value: | ||||
| @@ -18,10 +18,10 @@ __all__ = ["batch_add_nodes"] | |||||
| import re | import re | ||||
| import copy | import copy | ||||
| from .generator import Generator, CodeStruct | |||||
| from ..common.code_fragment import CodeFragment, NewFragment | |||||
| from ..common.outputs import NodeOutputManager | |||||
| from ..constant import ExchangeMessageKeywords | |||||
| from mindinsight.mindconverter.graph_based_converter.generator.generator import Generator, CodeStruct | |||||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment | |||||
| from mindinsight.mindconverter.graph_based_converter.common.outputs import NodeOutputManager | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords | |||||
| def _tf_model_node_name_reformat(node, node_name): | def _tf_model_node_name_reformat(node, node_name): | ||||
| @@ -16,6 +16,7 @@ | |||||
| import copy | import copy | ||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from mindspore import Tensor | |||||
| from yapf.yapflib.yapf_api import FormatCode | from yapf.yapflib.yapf_api import FormatCode | ||||
| from mindinsight.mindconverter.common.exceptions import GeneratorError | from mindinsight.mindconverter.common.exceptions import GeneratorError | ||||
| @@ -28,7 +29,7 @@ from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseO | |||||
| from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config | from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config | ||||
| from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr | from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \ | from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \ | ||||
| FIRST_LEVEL_INDENT, get_imported_module | |||||
| FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID | |||||
| from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator | from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator | ||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list | from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list | ||||
| @@ -469,6 +470,74 @@ class Generator: | |||||
| """Return all ModuleStructs in this model.""" | """Return all ModuleStructs in this model.""" | ||||
| return self._module_struct_collections | return self._module_struct_collections | ||||
| def generate_weight_scope_name(self, node): | |||||
| """Generate weight scope name for checkpoint.""" | |||||
| replaced_module_dict = self.node_structs[node].global_context_mgr.known_module_name | |||||
| scope_list = self.node_structs[node].scope.scope_list | |||||
| ms_var_name = self.node_structs[node].ms_var_name | |||||
| weight_scope_name = None | |||||
| for scope in scope_list[1:]: | |||||
| replaced_module = replaced_module_dict.get(scope.split(SEPARATOR_BTW_NAME_AND_ID)[0]) | |||||
| if replaced_module: | |||||
| scope = scope.replace(scope.split(SEPARATOR_BTW_NAME_AND_ID)[0], replaced_module) | |||||
| if not weight_scope_name: | |||||
| weight_scope_name = scope | |||||
| else: | |||||
| weight_scope_name = '.'.join((weight_scope_name, scope)) | |||||
| if not weight_scope_name: | |||||
| weight_scope_name = ms_var_name | |||||
| else: | |||||
| weight_scope_name = '.'.join((weight_scope_name, ms_var_name)) | |||||
| return weight_scope_name.lower() | |||||
| def generate_checkpoint(self): | |||||
| """Generate checkpoint.""" | |||||
| trainable_weights_dict = dict() | |||||
| weight_map = list() | |||||
| for node_name, node_inst in self.node_structs.items(): | |||||
| if node_inst.fragment.exchange_msg['var_0']['trainable_params']: | |||||
| weights_scope_name = self.generate_weight_scope_name(node_name) | |||||
| onnx_weight_inst = node_inst.fragment.exchange_msg['var_0']['weights'] | |||||
| for idx, (weight_key, weight_value) in \ | |||||
| enumerate(node_inst.fragment.exchange_msg['var_0']['trainable_params'].items()): | |||||
| weight_name = '.'.join((weights_scope_name, weight_key)) | |||||
| weight_shape = Tensor(weight_value).shape | |||||
| data_type = Tensor(weight_value).dtype | |||||
| trainable_weights_dict[weight_name] = weight_value | |||||
| onnx_weight_name = onnx_weight_inst[idx].name | |||||
| onnx_weight_shape = onnx_weight_inst[idx].value.shape | |||||
| onnx_data_type = onnx_weight_inst[idx].value.dtype | |||||
| weight_map.append( | |||||
| { | |||||
| 'converted_weight': { | |||||
| 'name': weight_name, | |||||
| 'shape': weight_shape, | |||||
| 'data_type': str(data_type) | |||||
| }, | |||||
| 'source_weight': { | |||||
| 'name': onnx_weight_name, | |||||
| 'shape': onnx_weight_shape, | |||||
| 'data_type': str(onnx_data_type) | |||||
| } | |||||
| } | |||||
| ) | |||||
| save_obj = list() | |||||
| for weight_name, weight_value in trainable_weights_dict.items(): | |||||
| obj = { | |||||
| 'name': weight_name, | |||||
| 'data': Tensor(weight_value) | |||||
| } | |||||
| save_obj.append(obj) | |||||
| return save_obj, weight_map | |||||
| @GeneratorError.check_except("Generator occurs an error when generating code statements.") | @GeneratorError.check_except("Generator occurs an error when generating code statements.") | ||||
| def generate(self): | def generate(self): | ||||
| """ | """ | ||||
| @@ -479,6 +548,9 @@ class Generator: | |||||
| """ | """ | ||||
| self._form_bottom_submodule() | self._form_bottom_submodule() | ||||
| self._recursive_form_module() | self._recursive_form_module() | ||||
| ckpt_data_list, weight_map = self.generate_checkpoint() | |||||
| CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) | CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) | ||||
| outputs = [get_imported_module()] | outputs = [get_imported_module()] | ||||
| @@ -494,7 +566,7 @@ class Generator: | |||||
| report = report_generator.gen_report(formatted_code) | report = report_generator.gen_report(formatted_code) | ||||
| del self._global_context | del self._global_context | ||||
| return {"model": (formatted_code, report)} | |||||
| return {"model": (formatted_code, report, ckpt_data_list, weight_map)} | |||||
| def get_node_struct(self, node_identifier): | def get_node_struct(self, node_identifier): | ||||
| """ | """ | ||||
| @@ -17,13 +17,13 @@ | |||||
| import copy | import copy | ||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from .node_struct import NodeStruct | |||||
| from .scope_utils import Scope | |||||
| from ..common.utils import get_dict_key_by_value | |||||
| from .args_translator import ArgsTranslation | |||||
| from ..common.code_fragment import ModuleFragment | |||||
| from ..common.global_context import GlobalContext | |||||
| from ..common.name_mgr import LocalVarNameMgr | |||||
| from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct | |||||
| from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import get_dict_key_by_value | |||||
| from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation | |||||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import ModuleFragment | |||||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||||
| from mindinsight.mindconverter.graph_based_converter.common.name_mgr import LocalVarNameMgr | |||||
| class ModuleStruct: | class ModuleStruct: | ||||
| @@ -17,11 +17,11 @@ from collections import OrderedDict | |||||
| from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment | from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment | ||||
| from mindinsight.mindconverter.graph_based_converter.generator.fragment_utils import FragmentHandler | from mindinsight.mindconverter.graph_based_converter.generator.fragment_utils import FragmentHandler | ||||
| from .scope_utils import Scope | |||||
| from .args_translator import ArgsTranslation | |||||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | |||||
| from ..common.global_context import GlobalContext | |||||
| from ...common.exceptions import GeneratorError | |||||
| from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope | |||||
| from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode | |||||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||||
| from mindinsight.mindconverter.common.exceptions import GeneratorError | |||||
| class NodeStruct: | class NodeStruct: | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -16,4 +16,4 @@ | |||||
| __all__ = ["ONNXToMindSporeMapper"] | __all__ = ["ONNXToMindSporeMapper"] | ||||
| from .base import ONNXToMindSporeMapper | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||||
| @@ -108,18 +108,21 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||||
| try: | try: | ||||
| converter_name = op_name_converter(params=params, weights=weights, op_name=op_name) | converter_name = op_name_converter(params=params, weights=weights, op_name=op_name) | ||||
| converted_params = params_converter(params=params, weights=weights) | converted_params = params_converter(params=params, weights=weights) | ||||
| if "input_shape" in converted_params: | if "input_shape" in converted_params: | ||||
| converted_params.pop("input_shape") | converted_params.pop("input_shape") | ||||
| if "output_shape" in converted_params: | if "output_shape" in converted_params: | ||||
| converted_params.pop("output_shape") | converted_params.pop("output_shape") | ||||
| # set to converted_weights to enable weight migration | # set to converted_weights to enable weight migration | ||||
| _ = weights_converter(weights=weights) if weights else dict() | |||||
| converted_weights = weights_converter(weights=weights) if weights else dict() | |||||
| code_template, exchange_msg, outputs_list, outputs_mapping = template_generator( | code_template, exchange_msg, outputs_list, outputs_mapping = template_generator( | ||||
| operation=converter_name, | operation=converter_name, | ||||
| converted_params=converted_params, | converted_params=converted_params, | ||||
| raw_params=params, | raw_params=params, | ||||
| weights=weights | |||||
| weights=weights, | |||||
| trainable_params=converted_weights | |||||
| ) | ) | ||||
| except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: | except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: | ||||
| err_msg = f"Converting {op_name} failed, see {str(e)}" | err_msg = f"Converting {op_name} failed, see {str(e)}" | ||||
| log.error(err_msg) | log.error(err_msg) | ||||
| @@ -148,6 +151,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||||
| op = kwargs.get("operation") | op = kwargs.get("operation") | ||||
| args = kwargs.get("converted_params", dict()) | args = kwargs.get("converted_params", dict()) | ||||
| weights = kwargs.get("weights") | weights = kwargs.get("weights") | ||||
| trainable_params = kwargs.get("trainable_params", dict()) | |||||
| if not op: | if not op: | ||||
| raise ValueError("Can not get MindSpore operation name.") | raise ValueError("Can not get MindSpore operation name.") | ||||
| variable_slot = "var_0" | variable_slot = "var_0" | ||||
| @@ -169,7 +173,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | ||||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | ||||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | ||||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} | |||||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||||
| } | } | ||||
| } | } | ||||
| outputs_list = [f"opt_{{{variable_slot}}}"] | outputs_list = [f"opt_{{{variable_slot}}}"] | ||||
| @@ -177,11 +181,14 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | return template, exchange_msg, outputs_list, outputs_mapping | ||||
| @staticmethod | @staticmethod | ||||
| def _find_val_by_index(loc_index, values_dict): | |||||
| """Find value by location index of values_dict.""" | |||||
| def _find_val_by_index(loc_index, weights_list): | |||||
| """Find value by location index of weights_list.""" | |||||
| result = None | result = None | ||||
| for idx, dict_val in enumerate(values_dict.values()): | |||||
| if loc_index < 0: | |||||
| return weights_list[loc_index].value | |||||
| for idx, weight in enumerate(weights_list): | |||||
| if idx == loc_index: | if idx == loc_index: | ||||
| result = dict_val | |||||
| result = weight.value | |||||
| break | break | ||||
| return result | return result | ||||
| @@ -14,7 +14,6 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class BatchNormMapper(ONNXToMindSporeMapper): | class BatchNormMapper(ONNXToMindSporeMapper): | ||||
| @@ -36,8 +35,14 @@ class BatchNormMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_trained_weights(**kwargs): | def _convert_trained_weights(**kwargs): | ||||
| return dict() | |||||
| @staticmethod | |||||
| def _convert_settings(**kwargs): | |||||
| return Setting() | |||||
| weights = kwargs['weights'] | |||||
| gamma = BatchNormMapper._find_val_by_index(0, weights) | |||||
| beta = BatchNormMapper._find_val_by_index(1, weights) | |||||
| moving_mean = BatchNormMapper._find_val_by_index(2, weights) | |||||
| moving_variance = BatchNormMapper._find_val_by_index(3, weights) | |||||
| return { | |||||
| 'gamma': gamma, | |||||
| 'beta': beta, | |||||
| 'moving_mean': moving_mean, | |||||
| 'moving_variance': moving_variance | |||||
| } | |||||
| @@ -14,7 +14,6 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| import numpy as np | import numpy as np | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string | from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string | ||||
| @@ -42,7 +41,7 @@ class ConvMapper(ONNXToMindSporeMapper): | |||||
| """Convert params from PyTorch to MindSpore""" | """Convert params from PyTorch to MindSpore""" | ||||
| weights = kwargs['weights'] | weights = kwargs['weights'] | ||||
| params = kwargs['params'] | params = kwargs['params'] | ||||
| weight = weights['weight'] | |||||
| weight = ConvMapper._find_val_by_index(0, weights) | |||||
| weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0]) | weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0]) | ||||
| if isinstance(params['dilations'], list): | if isinstance(params['dilations'], list): | ||||
| dilation = tuple(params['dilations']) | dilation = tuple(params['dilations']) | ||||
| @@ -76,11 +75,13 @@ class ConvMapper(ONNXToMindSporeMapper): | |||||
| """Convert params from Tensorflow to MindSpore""" | """Convert params from Tensorflow to MindSpore""" | ||||
| weights = kwargs['weights'] | weights = kwargs['weights'] | ||||
| params = kwargs['params'] | params = kwargs['params'] | ||||
| # regex to find Conv weight | |||||
| weight = list(weights.values())[0] | |||||
| weight = ConvMapper._find_val_by_index(0, weights) | |||||
| bias = ConvMapper._find_val_by_index(1, weights) | |||||
| if weight is None: | if weight is None: | ||||
| raise ValueError("Conv. Mapper cannot get the weight.") | raise ValueError("Conv. Mapper cannot get the weight.") | ||||
| has_bias = isinstance(bias, np.ndarray) | |||||
| auto_pad = None | auto_pad = None | ||||
| if params.get("auto_pad") is not None: | if params.get("auto_pad") is not None: | ||||
| auto_pad = convert_bytes_string_to_string(params.get("auto_pad")) | auto_pad = convert_bytes_string_to_string(params.get("auto_pad")) | ||||
| @@ -119,18 +120,14 @@ class ConvMapper(ONNXToMindSporeMapper): | |||||
| 'padding': padding, | 'padding': padding, | ||||
| 'pad_mode': pad_mode, | 'pad_mode': pad_mode, | ||||
| 'dilation': dilation, | 'dilation': dilation, | ||||
| 'group': params.get('group', 1)} | |||||
| 'group': params.get('group', 1), | |||||
| 'has_bias': has_bias | |||||
| } | |||||
| @staticmethod | @staticmethod | ||||
| def _operation_name_in_ms(*args, **kwargs): | def _operation_name_in_ms(*args, **kwargs): | ||||
| weight = kwargs['weights'].get('weight', 'empty') | |||||
| if weight == 'empty': # is from tf | |||||
| kernel_size = kwargs['params'].get('kernel_shape') | |||||
| dim = len(kernel_size) | |||||
| return f"nn.Conv{dim}d" | |||||
| dim = weight.ndim - 2 | |||||
| kernel_size = kwargs['params'].get('kernel_shape') | |||||
| dim = len(kernel_size) | |||||
| return f"nn.Conv{dim}d" | return f"nn.Conv{dim}d" | ||||
| @staticmethod | @staticmethod | ||||
| @@ -138,14 +135,16 @@ class ConvMapper(ONNXToMindSporeMapper): | |||||
| weights = kwargs['weights'] | weights = kwargs['weights'] | ||||
| params = kwargs['params'] | params = kwargs['params'] | ||||
| if weights.get('weight', 'empty') == 'empty': # is from tf | |||||
| return ConvMapper.convert_params_tf(params=params, weights=weights) | |||||
| return ConvMapper.convert_params_torch(params=params, weights=weights) | |||||
| return ConvMapper.convert_params_tf(params=params, weights=weights) | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_trained_weights(**kwargs): | def _convert_trained_weights(**kwargs): | ||||
| return dict() | |||||
| weights = kwargs['weights'] | |||||
| weight = ConvMapper._find_val_by_index(0, weights) | |||||
| bias = ConvMapper._find_val_by_index(1, weights) | |||||
| @staticmethod | |||||
| def _convert_settings(**kwargs): | |||||
| return Setting() | |||||
| converted_weights = {'weight': weight} | |||||
| if isinstance(bias, np.ndarray): | |||||
| converted_weights['bias'] = bias | |||||
| return converted_weights | |||||
| @@ -15,7 +15,6 @@ | |||||
| """Mapper module.""" | """Mapper module.""" | ||||
| import numpy as np | import numpy as np | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| class DenseMapper(ONNXToMindSporeMapper): | class DenseMapper(ONNXToMindSporeMapper): | ||||
| @@ -42,8 +41,10 @@ class DenseMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_trained_weights(**kwargs): | def _convert_trained_weights(**kwargs): | ||||
| return dict() | |||||
| @staticmethod | |||||
| def _convert_settings(**kwargs): | |||||
| return Setting() | |||||
| weights = kwargs['weights'] | |||||
| weight = DenseMapper._find_val_by_index(0, weights) | |||||
| bias = DenseMapper._find_val_by_index(1, weights) | |||||
| return { | |||||
| 'weight': weight, | |||||
| 'bias': bias | |||||
| } | |||||
| @@ -30,7 +30,9 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_trained_weights(**kwargs): | def _convert_trained_weights(**kwargs): | ||||
| return dict() | |||||
| weights = kwargs['weights'] | |||||
| weight = MatMulMapper._find_val_by_index(0, weights) | |||||
| return {'weight': weight} | |||||
| @staticmethod | @staticmethod | ||||
| def _generate_snippet_template(**kwargs): | def _generate_snippet_template(**kwargs): | ||||
| @@ -44,8 +46,7 @@ class MatMulMapper(ONNXToMindSporeMapper): | |||||
| if not weights: | if not weights: | ||||
| return template, exchange_msg, outputs_list, outputs_mapping | return template, exchange_msg, outputs_list, outputs_mapping | ||||
| weight = list(weights.items())[0] | |||||
| _, tensor = weight | |||||
| tensor = MatMulMapper._find_val_by_index(0, weights) | |||||
| variable_slot = "var_0" | variable_slot = "var_0" | ||||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | ||||
| @@ -15,7 +15,6 @@ | |||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string | from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| def _padding_format_convert(padding: list): | def _padding_format_convert(padding: list): | ||||
| @@ -49,7 +48,7 @@ class PadMapper(ONNXToMindSporeMapper): | |||||
| weights = kwargs.get("weights") | weights = kwargs.get("weights") | ||||
| params = kwargs.get("params") | params = kwargs.get("params") | ||||
| mode = convert_bytes_string_to_string(params.get('mode', 'constant')) | mode = convert_bytes_string_to_string(params.get('mode', 'constant')) | ||||
| pads_onnx = params.get("pads") if params.get("pads") else list(weights.values())[0].tolist() | |||||
| pads_onnx = params.get("pads") if params.get("pads") else PadMapper._find_val_by_index(0, weights).tolist() | |||||
| if mode == 'constant' and params.get('value') is None: | if mode == 'constant' and params.get('value') is None: | ||||
| if params.get('pads') or weights: | if params.get('pads') or weights: | ||||
| if isinstance(pads_onnx, list): | if isinstance(pads_onnx, list): | ||||
| @@ -76,7 +75,3 @@ class PadMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_trained_weights(**kwargs): | def _convert_trained_weights(**kwargs): | ||||
| return dict() | return dict() | ||||
| @staticmethod | |||||
| def _convert_settings(**kwargs): | |||||
| return Setting() | |||||
| @@ -13,12 +13,16 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| import math | |||||
| import numpy as np | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||||
| class PoolMapper(ONNXToMindSporeMapper): | class PoolMapper(ONNXToMindSporeMapper): | ||||
| """MaxPool mapper.""" | |||||
| """Pool mapper.""" | |||||
| @staticmethod | @staticmethod | ||||
| def _operation_name_in_ms(*args, **kwargs): | def _operation_name_in_ms(*args, **kwargs): | ||||
| @@ -35,12 +39,6 @@ class PoolMapper(ONNXToMindSporeMapper): | |||||
| transformed_params = dict() | transformed_params = dict() | ||||
| transformed_params["kernel_size"] = tuple(params['kernel_shape']) | transformed_params["kernel_size"] = tuple(params['kernel_shape']) | ||||
| transformed_params["stride"] = tuple(params['strides']) | transformed_params["stride"] = tuple(params['strides']) | ||||
| if "pads" in params: | |||||
| if sum(params['pads']) == 0 and not params.get('ceil_mode', None): | |||||
| pad_mode = '\"valid\"' | |||||
| else: | |||||
| pad_mode = '\"same\"' | |||||
| transformed_params["pad_mode"] = pad_mode | |||||
| return transformed_params | return transformed_params | ||||
| @@ -49,5 +47,100 @@ class PoolMapper(ONNXToMindSporeMapper): | |||||
| return dict() | return dict() | ||||
| @staticmethod | @staticmethod | ||||
| def _convert_settings(**kwargs): | |||||
| return Setting() | |||||
| def _get_ms_opt_shape(**kwargs): | |||||
| """Get output shape in MindSpore.""" | |||||
| params = kwargs['raw_params'] | |||||
| input_shape = params['input_shape'] | |||||
| kernel_shape = params['kernel_shape'] | |||||
| strides = params['strides'] | |||||
| dilations = params.get('dilations', (1, 1)) | |||||
| # For mindspore, | |||||
| # output_shape[i] = ceil((input_shape[i] - ((kernel_shape[i] - 1) * dilations[i] + 1) + 1) / strides[i]) | |||||
| ms_opt_shape = np.true_divide(np.subtract(np.array(input_shape[-len(kernel_shape):], dtype=np.float32), | |||||
| ((np.array(kernel_shape, dtype=np.float32) - 1) * | |||||
| np.array(dilations, dtype=np.float32) + 1)) + 1, | |||||
| np.array(strides, dtype=np.float32)).tolist() | |||||
| ms_opt_shape_ceil = tuple(math.ceil(ms_opt_shape_axis) for ms_opt_shape_axis in ms_opt_shape) | |||||
| return ms_opt_shape_ceil | |||||
| @staticmethod | |||||
| def _generate_snippet_template(**kwargs): | |||||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||||
| **kwargs) | |||||
| op = kwargs.get("operation") | |||||
| args = kwargs.get("converted_params", dict()) | |||||
| ms_opt_shape = PoolMapper._get_ms_opt_shape(**kwargs) | |||||
| tensor_opt_shape = kwargs['raw_params']['output_shape'] | |||||
| tensor_ipt_shape = kwargs['raw_params']['input_shape'] | |||||
| kernel_shape = kwargs['raw_params']['kernel_shape'] | |||||
| dilations = kwargs['raw_params'].get('dilations', (1, 1)) | |||||
| strides = kwargs['raw_params']['strides'] | |||||
| onnx_opt_shape = tensor_opt_shape[-len(ms_opt_shape):] | |||||
| if np.all(np.array(ms_opt_shape) == np.array(onnx_opt_shape)): | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| variable_slot = "var_0" | |||||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}(opt_{{{variable_slot}}})" | |||||
| init_template_pad, construct_template_pad, paddings = \ | |||||
| PoolMapper._generate_pad_init_and_construct(tensor_opt_shape, tensor_ipt_shape, | |||||
| ms_opt_shape, variable_slot, | |||||
| kernel_shape, dilations, strides) | |||||
| template = { | |||||
| variable_slot: { | |||||
| TemplateKeywords.INIT.value: [init_template_pad, init_template], | |||||
| TemplateKeywords.CONSTRUCT.value: [construct_template_pad, construct_template] | |||||
| } | |||||
| } | |||||
| args['paddings'] = paddings | |||||
| exchange_msg = { | |||||
| variable_slot: { | |||||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: dict(), | |||||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: dict() | |||||
| } | |||||
| } | |||||
| return template, exchange_msg, outputs_list, outputs_mapping | |||||
| @staticmethod | |||||
| def _generate_pad_init_and_construct(tensor_opt_shape, tensor_ipt_shape, | |||||
| ms_opt_shape, variable_slot, kernel_shape, dilations, strides): | |||||
| """Generate pad code in init and construct.""" | |||||
| onnx_opt_shape = tensor_opt_shape[-len(ms_opt_shape):] | |||||
| onnx_ipt_shape = tensor_ipt_shape[-len(ms_opt_shape):] | |||||
| if np.any(np.array(ms_opt_shape) > np.array(onnx_opt_shape)): | |||||
| raise ValueError(f"ms_opt_shape[{ms_opt_shape}] should be no larger than onnx_opt_shape[{onnx_opt_shape}].") | |||||
| # shape_diff[i] = (onnx_opt_shape[i] - 1)*strides[i] - | |||||
| # (onnx_ipt_shape[i] - ((kernel_shape[i] - 1)*dilations[i] + 1)) | |||||
| shape_diff = np.subtract((np.array(onnx_opt_shape) - 1)*np.array(strides), | |||||
| np.subtract(np.array(onnx_ipt_shape), | |||||
| (np.array(kernel_shape) - 1)*np.array(dilations) + 1)).tolist() | |||||
| zero_pad_single = (0, 0) | |||||
| paddings = [zero_pad_single] | |||||
| num_zero_pads = len(tensor_opt_shape) - len(ms_opt_shape) | |||||
| for _ in range(num_zero_pads - 1): | |||||
| paddings.append(zero_pad_single) | |||||
| for axis_diff in shape_diff: | |||||
| paddings.append((int(axis_diff//2), int(axis_diff//2 + axis_diff % 2))) | |||||
| init_template_pad = f"self.pad_{{{variable_slot}}} = nn.Pad(paddings={{paddings}})" | |||||
| construct_template_pad = f"opt_{{{variable_slot}}} = self.pad_{{{variable_slot}}}" \ | |||||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})" | |||||
| return init_template_pad, construct_template_pad, tuple(paddings) | |||||
| @@ -44,8 +44,7 @@ class AddMapper(ONNXToMindSporeMapper): | |||||
| if not weights: | if not weights: | ||||
| return template, exchange_msg, outputs_list, outputs_mapping | return template, exchange_msg, outputs_list, outputs_mapping | ||||
| bias = list(weights.items())[0] | |||||
| _, tensor = bias | |||||
| tensor = AddMapper._find_val_by_index(0, weights) | |||||
| variable_slot = "var_0" | variable_slot = "var_0" | ||||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | ||||
| @@ -53,8 +53,7 @@ class MulMapper(ONNXToMindSporeMapper): | |||||
| if not weights: | if not weights: | ||||
| return template, exchange_msg, outputs_list, outputs_mapping | return template, exchange_msg, outputs_list, outputs_mapping | ||||
| weight = list(weights.items())[0] | |||||
| _, tensor = weight | |||||
| tensor = MulMapper._find_val_by_index(0, weights) | |||||
| variable_slot = "var_0" | variable_slot = "var_0" | ||||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | ||||
| @@ -60,7 +60,7 @@ class ResizeMapper(ONNXToMindSporeMapper): | |||||
| align_corners = True | align_corners = True | ||||
| # Get requested size for resize | # Get requested size for resize | ||||
| size = list(weights.values())[-1][-2:].tolist() | |||||
| size = ResizeMapper._find_val_by_index(-1, weights)[-2:].tolist() | |||||
| return {"size": tuple(size), | return {"size": tuple(size), | ||||
| "align_corners": align_corners} | "align_corners": align_corners} | ||||
| @@ -48,7 +48,7 @@ class SliceMapper(ONNXToMindSporeMapper): | |||||
| def _generate_snippet_template(**kwargs): | def _generate_snippet_template(**kwargs): | ||||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | ||||
| **kwargs) | **kwargs) | ||||
| weights = list(kwargs.get("weights").values()) # start, end, axis | |||||
| weights = [weight.value for weight in kwargs.get('weights')] # start, end, axis | |||||
| opt_shape = kwargs["raw_params"]["output_shape"] | opt_shape = kwargs["raw_params"]["output_shape"] | ||||
| if not weights: | if not weights: | ||||
| raise ValueError("Cannot get required params from slice.") | raise ValueError("Cannot get required params from slice.") | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -15,4 +15,4 @@ | |||||
| """Searcher of scope name.""" | """Searcher of scope name.""" | ||||
| __all__ = ["generate_scope_name"] | __all__ = ["generate_scope_name"] | ||||
| from .searcher import generate_scope_name | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.searcher import generate_scope_name | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -16,8 +16,8 @@ | |||||
| __all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"] | __all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"] | ||||
| from .common import cal_matching_score | |||||
| from .pattern import Pattern | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import cal_matching_score | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern | |||||
| BUILT_IN_PATTERN = dict() | BUILT_IN_PATTERN = dict() | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -17,12 +17,15 @@ import copy | |||||
| import uuid | import uuid | ||||
| from typing import Dict, List, Callable, Union | from typing import Dict, List, Callable, Union | ||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score | |||||
| from .known_module_name import BUILT_IN_MODULE_NAME | |||||
| from .pattern import Pattern, scope_name_mapping | |||||
| from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern | |||||
| from .pattern_fuzzy_matching import pattern_fuzzy_matching | |||||
| from ..third_party_graph.onnx_utils import OnnxNode, BaseNode | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \ | |||||
| MAX_OUT_DEGREE, cal_matching_score | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \ | |||||
| is_built_in_pattern | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \ | |||||
| pattern_fuzzy_matching | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxNode, BaseNode | |||||
| module_name_to_src = {} | module_name_to_src = {} | ||||
| used_module_name = dict() | used_module_name = dict() | ||||
| @@ -16,12 +16,15 @@ | |||||
| from queue import PriorityQueue | from queue import PriorityQueue | ||||
| from typing import Dict, List | from typing import Dict, List | ||||
| from .common import context, DagGraph, gen_hash_key, ACCEPTABLE_RESULT_COUNT | |||||
| from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE | |||||
| from ..common.global_context import GlobalContext | |||||
| from ..third_party_graph.onnx_utils import BaseNode | |||||
| from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern | |||||
| from ...common.exceptions import SubGraphSearchingError | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \ | |||||
| ACCEPTABLE_RESULT_COUNT | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \ | |||||
| MAX_ITERATION_DEPTH, SATISFIED_SCORE | |||||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \ | |||||
| generate_pattern, find_built_in_pattern | |||||
| from mindinsight.mindconverter.common.exceptions import SubGraphSearchingError | |||||
| def _is_satisfied(path): | def _is_satisfied(path): | ||||
| @@ -17,7 +17,7 @@ | |||||
| __all__ = ["GraphFactory"] | __all__ = ["GraphFactory"] | ||||
| from importlib import import_module | from importlib import import_module | ||||
| from .base import Graph | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import Graph | |||||
| class GraphFactory: | class GraphFactory: | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -15,8 +15,8 @@ | |||||
| """Define PyTorch graph node.""" | """Define PyTorch graph node.""" | ||||
| import os | import os | ||||
| from .base import GraphNode | |||||
| from ..constant import SEPARATOR_IN_SCOPE, NodeType | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphNode | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_SCOPE, NodeType | |||||
| class InputNode(GraphNode): | class InputNode(GraphNode): | ||||
| @@ -25,7 +25,6 @@ class InputNode(GraphNode): | |||||
| Args: | Args: | ||||
| input_shape: Input shape of module. | input_shape: Input shape of module. | ||||
| """ | """ | ||||
| def _get_arg_name(self, arg, variable_name): | def _get_arg_name(self, arg, variable_name): | ||||
| @@ -17,12 +17,12 @@ from importlib import import_module | |||||
| from typing import Dict, NoReturn | from typing import Dict, NoReturn | ||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from .base import Graph | |||||
| from .input_node import InputNode | |||||
| from .onnx_graph_node import OnnxGraphNode | |||||
| from .pytorch_graph_parser import PyTorchGraphParser | |||||
| from .tf_graph_parser import TFGraphParser | |||||
| from .onnx_utils import OnnxDataLoader | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import Graph | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.input_node import InputNode | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_parser import PyTorchGraphParser | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.tf_graph_parser import TFGraphParser | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxDataLoader, NodeWeight | |||||
| NONE_SCOPE_OP = { | NONE_SCOPE_OP = { | ||||
| "onnx::Add": "Add", | "onnx::Add": "Add", | ||||
| @@ -126,7 +126,7 @@ class OnnxGraph(Graph): | |||||
| self._shape_dict = model_data.node_output_shape_dict | self._shape_dict = model_data.node_output_shape_dict | ||||
| for ind, (node_name, node) in enumerate(model_data.nodes_dict.items()): | for ind, (node_name, node) in enumerate(model_data.nodes_dict.items()): | ||||
| node_weight = {} | |||||
| node_weights = list() | |||||
| node.scope_name = scope_name_list[ind] | node.scope_name = scope_name_list[ind] | ||||
| inputs = node.input_name_list | inputs = node.input_name_list | ||||
| # check each input from node or tensors | # check each input from node or tensors | ||||
| @@ -135,8 +135,8 @@ class OnnxGraph(Graph): | |||||
| tensor = model_data.tensors_dict[i] | tensor = model_data.tensors_dict[i] | ||||
| t_name = tensor.name | t_name = tensor.name | ||||
| t_value = tensor.to_array() | t_value = tensor.to_array() | ||||
| node_weight[t_name] = t_value | |||||
| self._nodes_collection[node_name] = OnnxGraphNode(node, node_weight) | |||||
| node_weights.append(NodeWeight(t_name, t_value)) | |||||
| self._nodes_collection[node_name] = OnnxGraphNode(node, node_weights) | |||||
| self._nodes_record[node_name] = node_name | self._nodes_record[node_name] = node_name | ||||
| for nd_ipt_name in node.precursor_onnx_node_dict: | for nd_ipt_name in node.precursor_onnx_node_dict: | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -15,11 +15,11 @@ | |||||
| """Define ONNX graph node.""" | """Define ONNX graph node.""" | ||||
| from importlib import import_module | from importlib import import_module | ||||
| from .base import GraphNode | |||||
| from ..common.utils import is_converted | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphNode | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import is_converted | |||||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | |||||
| SEPARATOR_IN_ONNX_OP | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import NodeType, SEPARATOR_IN_SCOPE, \ | |||||
| SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP | |||||
| class OnnxGraphNode(GraphNode): | class OnnxGraphNode(GraphNode): | ||||
| @@ -28,7 +28,7 @@ class OnnxGraphNode(GraphNode): | |||||
| Args: | Args: | ||||
| node (OnnxNode): OnnxNode Object. | node (OnnxNode): OnnxNode Object. | ||||
| weight (dict): Dictionary records weight and bias. | |||||
| weight (list): List of recording node weights. | |||||
| """ | """ | ||||
| _type_frozen = False | _type_frozen = False | ||||
| _module_name_frozen = False | _module_name_frozen = False | ||||
| @@ -227,7 +227,6 @@ class OnnxGraphNode(GraphNode): | |||||
| Args: | Args: | ||||
| src_arg (str): Original arg name. | src_arg (str): Original arg name. | ||||
| tgt_arg (str): Target arg name. | tgt_arg (str): Target arg name. | ||||
| """ | """ | ||||
| self._args_in_code[src_arg] = tgt_arg | self._args_in_code[src_arg] = tgt_arg | ||||
| @@ -22,13 +22,13 @@ from typing import Union | |||||
| import numpy as np | import numpy as np | ||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from ..common.utils import fetch_output_from_onnx_model | |||||
| from ..common.global_context import GlobalContext | |||||
| from .optimizer import OnnxSimplify | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import fetch_output_from_onnx_model | |||||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer import OnnxSimplify | |||||
| from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | |||||
| ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | ||||
| from ...common.exceptions import GraphInitError, ModelLoadingError | |||||
| from mindinsight.mindconverter.common.exceptions import GraphInitError, ModelLoadingError | |||||
| def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): | def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): | ||||
| @@ -128,7 +128,6 @@ class ParamsAttribute: | |||||
| raw_attribute (onnx.AttributeProto): onnx.AttributeProto instance. | raw_attribute (onnx.AttributeProto): onnx.AttributeProto instance. | ||||
| node (onnx.NodeProto): Must pass the onnx.NodeProto instance | node (onnx.NodeProto): Must pass the onnx.NodeProto instance | ||||
| containing the same AttributeProto. | containing the same AttributeProto. | ||||
| """ | """ | ||||
| def __init__(self, raw_attribute, node): | def __init__(self, raw_attribute, node): | ||||
| @@ -148,7 +147,6 @@ class ParamsAttribute: | |||||
| Args: | Args: | ||||
| attrs (onnx.AttributeProto): onnx.AttributeProto instance. | attrs (onnx.AttributeProto): onnx.AttributeProto instance. | ||||
| """ | """ | ||||
| if not attrs: | if not attrs: | ||||
| return | return | ||||
| @@ -604,3 +602,18 @@ class OnnxDataLoader: | |||||
| eliminated_nodes = _traceback_precursor_nodes_until_shape_op(to_shape) | eliminated_nodes = _traceback_precursor_nodes_until_shape_op(to_shape) | ||||
| self.dynamic_resize_node.append(nd_name) | self.dynamic_resize_node.append(nd_name) | ||||
| self.eliminated_nodes += eliminated_nodes | self.eliminated_nodes += eliminated_nodes | ||||
| class NodeWeight: | |||||
| """Node weight struct.""" | |||||
| def __init__(self, weight_name, weight_value): | |||||
| self._weight_name = weight_name | |||||
| self._weight_value = weight_value | |||||
| @property | |||||
| def name(self): | |||||
| return self._weight_name | |||||
| @property | |||||
| def value(self): | |||||
| return self._weight_value | |||||
| @@ -18,7 +18,7 @@ from importlib import import_module | |||||
| import numpy as np | import numpy as np | ||||
| from ..common.utils import fetch_output_from_onnx_model | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import fetch_output_from_onnx_model | |||||
| class OnnxSimplify: | class OnnxSimplify: | ||||
| @@ -20,6 +20,7 @@ from mindinsight.mindconverter.common.log import logger as log | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphParser | from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphParser | ||||
| from mindinsight.mindconverter.common.exceptions import ModelNotSupportError | from mindinsight.mindconverter.common.exceptions import ModelNotSupportError | ||||
| class PyTorchGraphParser(GraphParser): | class PyTorchGraphParser(GraphParser): | ||||
| """Define pytorch graph parser.""" | """Define pytorch graph parser.""" | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -18,8 +18,8 @@ import re | |||||
| from importlib import import_module | from importlib import import_module | ||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from .base import GraphParser | |||||
| from ...common.exceptions import ModelNotSupportError | |||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphParser | |||||
| from mindinsight.mindconverter.common.exceptions import ModelNotSupportError | |||||
| class TFGraphParser(GraphParser): | class TFGraphParser(GraphParser): | ||||
| @@ -89,7 +89,9 @@ class TestMappers: | |||||
| 'input': {'op_name': 'onnx::MaxPool', | 'input': {'op_name': 'onnx::MaxPool', | ||||
| 'params': {'kernel_shape': [3, 3], | 'params': {'kernel_shape': [3, 3], | ||||
| 'pads': [1, 1, 1, 1], | 'pads': [1, 1, 1, 1], | ||||
| 'strides': [2, 2]}, | |||||
| 'strides': [2, 2], | |||||
| 'input_shape': (1, 3, 224, 224), | |||||
| 'output_shape': (1, 3, 112, 112)}, | |||||
| 'weights': dict()}, | 'weights': dict()}, | ||||
| 'expected_output': {'converter_name': 'nn.MaxPool2d', | 'expected_output': {'converter_name': 'nn.MaxPool2d', | ||||
| 'converted_params': {'kernel_size': (3, 3), | 'converted_params': {'kernel_size': (3, 3), | ||||
| @@ -100,7 +102,9 @@ class TestMappers: | |||||
| 'input': {'op_name': 'onnx::AveragePool', | 'input': {'op_name': 'onnx::AveragePool', | ||||
| 'params': {'kernel_shape': [3, 3], | 'params': {'kernel_shape': [3, 3], | ||||
| 'pads': [1, 1, 1, 1], | 'pads': [1, 1, 1, 1], | ||||
| 'strides': [2, 2]}, | |||||
| 'strides': [2, 2], | |||||
| 'input_shape': (1, 3, 224, 224), | |||||
| 'output_shape': (1, 3, 112, 112)}, | |||||
| 'weights': dict()}, | 'weights': dict()}, | ||||
| 'expected_output': {'converter_name': 'nn.AvgPool2d', | 'expected_output': {'converter_name': 'nn.AvgPool2d', | ||||
| 'converted_params': {'kernel_size': (3, 3), | 'converted_params': {'kernel_size': (3, 3), | ||||
| @@ -0,0 +1,26 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock the MindSpore mindspore/train/serialization.py.""" | |||||
| def save_checkpoint(trainable_weights, ckpt_file_name): | |||||
| """ | |||||
| Mock save_checkpoint. | |||||
| Args: | |||||
| trainable_weights (list): List of weights. | |||||
| ckpt_file_name (str): Path to save checkpoint file. | |||||
| """ | |||||
| return len(trainable_weights), ckpt_file_name | |||||