diff --git a/mindinsight/mindconverter/common/exceptions.py b/mindinsight/mindconverter/common/exceptions.py index 92ac6101..a33e2d1c 100644 --- a/mindinsight/mindconverter/common/exceptions.py +++ b/mindinsight/mindconverter/common/exceptions.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -301,6 +301,8 @@ class SourceFilesSaveError(MindConverterException): NODE_INPUT_TYPE_NOT_SUPPORT = 1 SCRIPT_GENERATE_FAIL = 2 REPORT_GENERATE_FAIL = 3 + CKPT_GENERATE_FAIL = 4 + MAP_GENERATE_FAIL = 5 BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value ERROR_CODE = ErrCode.UNKNOWN_ERROR.value @@ -315,6 +317,8 @@ class SourceFilesSaveError(MindConverterException): except_source = (NodeInputTypeNotSupportError, ScriptGenerationError, ReportGenerationError, + CheckPointGenerationError, + WeightMapGenerationError, IOError, cls) return except_source @@ -437,6 +441,32 @@ class ReportGenerationError(SourceFilesSaveError): return ZeroDivisionError, cls +class CheckPointGenerationError(SourceFilesSaveError): + """The checkpoint generate fail error.""" + ERROR_CODE = SourceFilesSaveError.ErrCode.CKPT_GENERATE_FAIL.value + + def __init__(self, msg): + super(CheckPointGenerationError, self).__init__(msg=msg) + + @classmethod + def raise_from(cls): + """Raise from exceptions below.""" + return cls + + +class WeightMapGenerationError(SourceFilesSaveError): + """The weight names map generate fail error.""" + ERROR_CODE = SourceFilesSaveError.ErrCode.MAP_GENERATE_FAIL.value + + def __init__(self, msg): + super(WeightMapGenerationError, self).__init__(msg=msg) + + @classmethod + def raise_from(cls): + """Raise from exception below.""" + return cls + + class SubGraphSearchingError(MindConverterException): """Sub-graph searching exception.""" diff --git a/mindinsight/mindconverter/graph_based_converter/__init__.py b/mindinsight/mindconverter/graph_based_converter/__init__.py index 888b034f..c25f490f 100644 --- a/mindinsight/mindconverter/graph_based_converter/__init__.py +++ b/mindinsight/mindconverter/graph_based_converter/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,5 +15,5 @@ """Graph based scripts converter definition.""" __all__ = ["graph_based_converter_pytorch_to_ms", "graph_based_converter_tf_to_ms"] -from .framework import graph_based_converter_pytorch_to_ms -from .framework import graph_based_converter_tf_to_ms +from mindinsight.mindconverter.graph_based_converter.framework import graph_based_converter_pytorch_to_ms +from mindinsight.mindconverter.graph_based_converter.framework import graph_based_converter_tf_to_ms diff --git a/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py b/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py index 108d543e..9d2fbd0c 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py +++ b/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py @@ -191,18 +191,23 @@ class CodeFragment(Fragment): """ def __init__(self, operation, actual_args, settings, input_shape, output_shape, - trainable_params=None): + trainable_params=None, trainable_weights=None): super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args, input_shape=input_shape, output_shape=output_shape, settings=settings) self._trainable_params = dict() # External weights, like Matmul. self._init_trainable_params = trainable_params # Can put into operation init method, like Conv2d. + self._trainable_weights = trainable_weights @property def trainable_params(self): """Return the trainable parameters.""" return self._trainable_params + @property + def trainable_weights(self): + return self._trainable_weights + class ModuleFragment(Fragment): """Manage module type code variables.""" diff --git a/mindinsight/mindconverter/graph_based_converter/common/global_context.py b/mindinsight/mindconverter/graph_based_converter/common/global_context.py index 0ab74cf4..212c72d5 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/global_context.py +++ b/mindinsight/mindconverter/graph_based_converter/common/global_context.py @@ -14,7 +14,7 @@ # ============================================================================== """Define GlobalContext class to save required resources during whole conversion procedure.""" from collections import OrderedDict -from .outputs import OutputStorage +from mindinsight.mindconverter.graph_based_converter.common.outputs import OutputStorage class Singleton(type): diff --git a/mindinsight/mindconverter/graph_based_converter/common/utils.py b/mindinsight/mindconverter/graph_based_converter/common/utils.py index 1a4d7a6a..49665b6b 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/utils.py +++ b/mindinsight/mindconverter/graph_based_converter/common/utils.py @@ -13,16 +13,21 @@ # limitations under the License. # ============================================================================ """Define common utils.""" +import json import os import stat from importlib import import_module +from importlib.util import find_spec from typing import List, Tuple, Mapping -from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, UnknownModelError +from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, \ + UnknownModelError, CheckPointGenerationError, WeightMapGenerationError from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, BINARY_HEADER_PYTORCH_BITS, \ FrameworkType, BINARY_HEADER_PYTORCH_FILE, TENSORFLOW_MODEL_SUFFIX +from mindspore.train.serialization import save_checkpoint + def is_converted(operation: str): """ @@ -96,7 +101,6 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple], code_lines (dict): Code lines. out_folder (str): Output folder. report_folder (str): Report output folder. - """ flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL modes = stat.S_IRUSR | stat.S_IWUSR @@ -114,7 +118,7 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple], os.makedirs(report_folder, modes_usr) for file_name in code_lines: - code, report = code_lines[file_name] + code, report, trainable_weights, weight_map = code_lines[file_name] code_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.py")) report_file_path = os.path.realpath(os.path.join(report_folder, f"report_of_{model_name}.txt")) try: @@ -133,6 +137,31 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple], except (IOError, FileExistsError) as error: raise ReportGenerationError(str(error)) + ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt")) + try: + if os.path.exists(ckpt_file_path): + raise CheckPointGenerationError("Checkpoint file with the same name already exists.") + save_checkpoint(trainable_weights, ckpt_file_path) + except TypeError as error: + raise CheckPointGenerationError(str(error)) + + weight_map_path = os.path.realpath(os.path.join(out_folder, f"weight_map_of_{model_name}.json")) + try: + if os.path.exists(weight_map_path): + raise WeightMapGenerationError("Weight map file with the same name already exists.") + with os.fdopen(os.open(weight_map_path, flags, stat.S_IRUSR), 'w') as map_f: + weight_map_json = {f"{model_name}": weight_map} + json.dump(weight_map_json, map_f) + except (IOError, FileExistsError) as error: + raise WeightMapGenerationError(str(error)) + + +def onnx_satisfied(): + """Validate ONNX , ONNXRUNTIME, ONNXOPTIMIZER installation.""" + if not find_spec("onnx") or not find_spec("onnxruntime") or not find_spec("onnxoptimizer"): + return False + return True + def lib_version_satisfied(current_ver: str, mini_ver_limited: str, newest_ver_limited: str = ""): @@ -220,6 +249,7 @@ def reset_init_or_construct(template, variable_slot, new_data, scope): template[variable_slot][scope] += new_data return template + def replace_string_in_list(str_list: list, original_str: str, target_str: str): """ Replace a string in a list by provided string. diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py index 7f138b2e..330ef767 100644 --- a/mindinsight/mindconverter/graph_based_converter/constant.py +++ b/mindinsight/mindconverter/graph_based_converter/constant.py @@ -41,6 +41,7 @@ UNKNOWN_DIM_VAL = "unk__001" ONNX_MIN_VER = "1.8.0" TF2ONNX_MIN_VER = "1.7.1" ONNXRUNTIME_MIN_VER = "1.5.2" +ONNXOPTIMIZER_MIN_VER = "0.1.2" @unique diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index aaeb0d99..804fdd35 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -21,10 +21,10 @@ from importlib.util import find_spec import mindinsight from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext -from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \ +from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \ save_code_file_and_report, get_framework_type from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ - ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER + ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console @@ -53,6 +53,18 @@ parser.add_argument("--report", type=str, required=False, help="Generated reports output folder path.") +def onnx_lib_version_satisfied(): + """Check onnx libs version whether is satisfied.""" + onnx = import_module("onnx") + ort = import_module("onnxruntime") + optimizer = import_module("onnxoptimizer.version") + if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ + or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \ + or not lib_version_satisfied(getattr(optimizer, "version"), ONNXOPTIMIZER_MIN_VER): + return False + return True + + def torch_installation_validation(func): """ Validate args of func. @@ -68,26 +80,33 @@ def torch_installation_validation(func): input_nodes: str, output_nodes: str, output_folder: str, report_folder: str = None): # Check whether pytorch is installed. - if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"): - error = RuntimeIntegrityError(f"PyTorch, onnx(>={ONNX_MIN_VER}) and " - f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) " - f"are required when using graph based " - f"scripts converter, and PyTorch version must " - f"be consisted with model generation runtime.") + error_info = None + if graph_path.endswith('.onnx'): + if not onnx_satisfied(): + error_info = f"onnx(>={ONNX_MIN_VER}, onnxruntime(>={ONNXRUNTIME_MIN_VER}) and " \ + f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " \ + f"are required when using graph based scripts converter." + else: + if not find_spec("torch") or not onnx_satisfied(): + error_info = f"PyTorch, onnx(>={ONNX_MIN_VER}), " \ + f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and " \ + f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " \ + f"are required when using graph based " \ + f"scripts converter, and PyTorch version must " \ + f"be consisted with model generation runtime." + if error_info: + error = RuntimeIntegrityError(error_info) log.error(error) log_console.error("\n") log_console.error(str(error)) log_console.error("\n") sys.exit(0) - onnx = import_module("onnx") - ort = import_module("onnxruntime") - - if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ - or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER): + if not onnx_lib_version_satisfied(): error = RuntimeIntegrityError( - f"onnx(>={ONNX_MIN_VER}) and " - f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " + f"onnx(>={ONNX_MIN_VER}), " + f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and " + f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) are required when using graph " f"based scripts converter for Pytorch conversion." ) log.error(error) @@ -128,11 +147,11 @@ def tf_installation_validation(func): output_folder: str, report_folder: str = None, input_nodes: str = None, output_nodes: str = None): # Check whether tensorflow is installed. - if not _check_tf_installation() or not find_spec("tf2onnx") \ - or not find_spec("onnx") or not find_spec("onnxruntime"): + if not _check_tf_installation() or not onnx_satisfied(): error = RuntimeIntegrityError( - f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " - f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " + f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}), " + f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " + f"are required when using graph " f"based scripts converter for TensorFlow conversion." ) log.error(error) @@ -141,15 +160,14 @@ def tf_installation_validation(func): log_console.error("\n") sys.exit(0) - onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx") - ort = import_module("onnxruntime") + tf2onnx = import_module("tf2onnx") - if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ - or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \ - or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER): + if not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER) \ + or not onnx_lib_version_satisfied(): error = RuntimeIntegrityError( - f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " - f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " + f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}), " + f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " + f"are required when using graph " f"based scripts converter for TensorFlow conversion." ) log.error(error) @@ -258,12 +276,12 @@ def main_graph_base_converter(file_config): raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") if frame_type == FrameworkType.PYTORCH.value: - check_params = ['input_nodes', 'output_nodes'] - check_params_exist(check_params, file_config) graph_based_converter_pytorch_to_ms(graph_path=graph_path, sample_shape=file_config['shape'], - input_nodes=file_config['input_nodes'], - output_nodes=file_config['output_nodes'], + input_nodes=file_config['input_nodes'] if file_config['input_nodes'] + else 'input.1', + output_nodes=file_config['output_nodes'] if file_config['output_nodes'] + else '', output_folder=file_config['outfile_dir'], report_folder=file_config['report_dir']) elif frame_type == FrameworkType.TENSORFLOW.value: diff --git a/mindinsight/mindconverter/graph_based_converter/generator/__init__.py b/mindinsight/mindconverter/graph_based_converter/generator/__init__.py index dafdfc96..849b4f3e 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/__init__.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/__init__.py @@ -18,10 +18,10 @@ __all__ = ["batch_add_nodes"] import re import copy -from .generator import Generator, CodeStruct -from ..common.code_fragment import CodeFragment, NewFragment -from ..common.outputs import NodeOutputManager -from ..constant import ExchangeMessageKeywords +from mindinsight.mindconverter.graph_based_converter.generator.generator import Generator, CodeStruct +from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment +from mindinsight.mindconverter.graph_based_converter.common.outputs import NodeOutputManager +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords def _tf_model_node_name_reformat(node, node_name): diff --git a/mindinsight/mindconverter/graph_based_converter/generator/generator.py b/mindinsight/mindconverter/graph_based_converter/generator/generator.py index 6bdb7bb2..154f6860 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/generator.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/generator.py @@ -16,6 +16,7 @@ import copy from collections import OrderedDict +from mindspore import Tensor from yapf.yapflib.yapf_api import FormatCode from mindinsight.mindconverter.common.exceptions import GeneratorError @@ -28,7 +29,7 @@ from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseO from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \ - FIRST_LEVEL_INDENT, get_imported_module + FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list @@ -469,6 +470,74 @@ class Generator: """Return all ModuleStructs in this model.""" return self._module_struct_collections + def generate_weight_scope_name(self, node): + """Generate weight scope name for checkpoint.""" + replaced_module_dict = self.node_structs[node].global_context_mgr.known_module_name + scope_list = self.node_structs[node].scope.scope_list + ms_var_name = self.node_structs[node].ms_var_name + + weight_scope_name = None + for scope in scope_list[1:]: + replaced_module = replaced_module_dict.get(scope.split(SEPARATOR_BTW_NAME_AND_ID)[0]) + if replaced_module: + scope = scope.replace(scope.split(SEPARATOR_BTW_NAME_AND_ID)[0], replaced_module) + if not weight_scope_name: + weight_scope_name = scope + else: + weight_scope_name = '.'.join((weight_scope_name, scope)) + + if not weight_scope_name: + weight_scope_name = ms_var_name + else: + weight_scope_name = '.'.join((weight_scope_name, ms_var_name)) + + return weight_scope_name.lower() + + def generate_checkpoint(self): + """Generate checkpoint.""" + + trainable_weights_dict = dict() + weight_map = list() + for node_name, node_inst in self.node_structs.items(): + if node_inst.fragment.exchange_msg['var_0']['trainable_params']: + weights_scope_name = self.generate_weight_scope_name(node_name) + onnx_weight_inst = node_inst.fragment.exchange_msg['var_0']['weights'] + for idx, (weight_key, weight_value) in \ + enumerate(node_inst.fragment.exchange_msg['var_0']['trainable_params'].items()): + weight_name = '.'.join((weights_scope_name, weight_key)) + weight_shape = Tensor(weight_value).shape + data_type = Tensor(weight_value).dtype + trainable_weights_dict[weight_name] = weight_value + + onnx_weight_name = onnx_weight_inst[idx].name + onnx_weight_shape = onnx_weight_inst[idx].value.shape + onnx_data_type = onnx_weight_inst[idx].value.dtype + + weight_map.append( + { + 'converted_weight': { + 'name': weight_name, + 'shape': weight_shape, + 'data_type': str(data_type) + }, + 'source_weight': { + 'name': onnx_weight_name, + 'shape': onnx_weight_shape, + 'data_type': str(onnx_data_type) + } + } + ) + + save_obj = list() + for weight_name, weight_value in trainable_weights_dict.items(): + obj = { + 'name': weight_name, + 'data': Tensor(weight_value) + } + save_obj.append(obj) + + return save_obj, weight_map + @GeneratorError.check_except("Generator occurs an error when generating code statements.") def generate(self): """ @@ -479,6 +548,9 @@ class Generator: """ self._form_bottom_submodule() self._recursive_form_module() + + ckpt_data_list, weight_map = self.generate_checkpoint() + CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) outputs = [get_imported_module()] @@ -494,7 +566,7 @@ class Generator: report = report_generator.gen_report(formatted_code) del self._global_context - return {"model": (formatted_code, report)} + return {"model": (formatted_code, report, ckpt_data_list, weight_map)} def get_node_struct(self, node_identifier): """ diff --git a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py index 22ae05ed..3b366f08 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py @@ -17,13 +17,13 @@ import copy from collections import OrderedDict -from .node_struct import NodeStruct -from .scope_utils import Scope -from ..common.utils import get_dict_key_by_value -from .args_translator import ArgsTranslation -from ..common.code_fragment import ModuleFragment -from ..common.global_context import GlobalContext -from ..common.name_mgr import LocalVarNameMgr +from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct +from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope +from mindinsight.mindconverter.graph_based_converter.common.utils import get_dict_key_by_value +from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation +from mindinsight.mindconverter.graph_based_converter.common.code_fragment import ModuleFragment +from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext +from mindinsight.mindconverter.graph_based_converter.common.name_mgr import LocalVarNameMgr class ModuleStruct: diff --git a/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py index 1b3c0e01..d9707d86 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py @@ -17,11 +17,11 @@ from collections import OrderedDict from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment from mindinsight.mindconverter.graph_based_converter.generator.fragment_utils import FragmentHandler -from .scope_utils import Scope -from .args_translator import ArgsTranslation -from ..third_party_graph.onnx_graph_node import OnnxGraphNode -from ..common.global_context import GlobalContext -from ...common.exceptions import GeneratorError +from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope +from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation +from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode +from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext +from mindinsight.mindconverter.common.exceptions import GeneratorError class NodeStruct: diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/__init__.py b/mindinsight/mindconverter/graph_based_converter/mapper/__init__.py index dadf4ad2..f29b23b4 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/__init__.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,4 +16,4 @@ __all__ = ["ONNXToMindSporeMapper"] -from .base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/base.py b/mindinsight/mindconverter/graph_based_converter/mapper/base.py index 609546cb..97b190a0 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/base.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/base.py @@ -108,18 +108,21 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): try: converter_name = op_name_converter(params=params, weights=weights, op_name=op_name) converted_params = params_converter(params=params, weights=weights) + if "input_shape" in converted_params: converted_params.pop("input_shape") if "output_shape" in converted_params: converted_params.pop("output_shape") # set to converted_weights to enable weight migration - _ = weights_converter(weights=weights) if weights else dict() + converted_weights = weights_converter(weights=weights) if weights else dict() code_template, exchange_msg, outputs_list, outputs_mapping = template_generator( operation=converter_name, converted_params=converted_params, raw_params=params, - weights=weights + weights=weights, + trainable_params=converted_weights ) + except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: err_msg = f"Converting {op_name} failed, see {str(e)}" log.error(err_msg) @@ -148,6 +151,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): op = kwargs.get("operation") args = kwargs.get("converted_params", dict()) weights = kwargs.get("weights") + trainable_params = kwargs.get("trainable_params", dict()) if not op: raise ValueError("Can not get MindSpore operation name.") variable_slot = "var_0" @@ -169,7 +173,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, - ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} + ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params } } outputs_list = [f"opt_{{{variable_slot}}}"] @@ -177,11 +181,14 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): return template, exchange_msg, outputs_list, outputs_mapping @staticmethod - def _find_val_by_index(loc_index, values_dict): - """Find value by location index of values_dict.""" + def _find_val_by_index(loc_index, weights_list): + """Find value by location index of weights_list.""" result = None - for idx, dict_val in enumerate(values_dict.values()): + if loc_index < 0: + return weights_list[loc_index].value + + for idx, weight in enumerate(weights_list): if idx == loc_index: - result = dict_val + result = weight.value break return result diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py index e7c51548..5a4591df 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py @@ -14,7 +14,6 @@ # ============================================================================== """Mapper module.""" from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper -from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class BatchNormMapper(ONNXToMindSporeMapper): @@ -36,8 +35,14 @@ class BatchNormMapper(ONNXToMindSporeMapper): @staticmethod def _convert_trained_weights(**kwargs): - return dict() - - @staticmethod - def _convert_settings(**kwargs): - return Setting() + weights = kwargs['weights'] + gamma = BatchNormMapper._find_val_by_index(0, weights) + beta = BatchNormMapper._find_val_by_index(1, weights) + moving_mean = BatchNormMapper._find_val_by_index(2, weights) + moving_variance = BatchNormMapper._find_val_by_index(3, weights) + return { + 'gamma': gamma, + 'beta': beta, + 'moving_mean': moving_mean, + 'moving_variance': moving_variance + } diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py index 3100f555..cdad6167 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py @@ -14,7 +14,6 @@ # ============================================================================== """Mapper module.""" import numpy as np -from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string @@ -42,7 +41,7 @@ class ConvMapper(ONNXToMindSporeMapper): """Convert params from PyTorch to MindSpore""" weights = kwargs['weights'] params = kwargs['params'] - weight = weights['weight'] + weight = ConvMapper._find_val_by_index(0, weights) weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0]) if isinstance(params['dilations'], list): dilation = tuple(params['dilations']) @@ -76,11 +75,13 @@ class ConvMapper(ONNXToMindSporeMapper): """Convert params from Tensorflow to MindSpore""" weights = kwargs['weights'] params = kwargs['params'] - # regex to find Conv weight - weight = list(weights.values())[0] + weight = ConvMapper._find_val_by_index(0, weights) + bias = ConvMapper._find_val_by_index(1, weights) if weight is None: raise ValueError("Conv. Mapper cannot get the weight.") + has_bias = isinstance(bias, np.ndarray) + auto_pad = None if params.get("auto_pad") is not None: auto_pad = convert_bytes_string_to_string(params.get("auto_pad")) @@ -119,18 +120,14 @@ class ConvMapper(ONNXToMindSporeMapper): 'padding': padding, 'pad_mode': pad_mode, 'dilation': dilation, - 'group': params.get('group', 1)} + 'group': params.get('group', 1), + 'has_bias': has_bias + } @staticmethod def _operation_name_in_ms(*args, **kwargs): - weight = kwargs['weights'].get('weight', 'empty') - - if weight == 'empty': # is from tf - kernel_size = kwargs['params'].get('kernel_shape') - dim = len(kernel_size) - return f"nn.Conv{dim}d" - - dim = weight.ndim - 2 + kernel_size = kwargs['params'].get('kernel_shape') + dim = len(kernel_size) return f"nn.Conv{dim}d" @staticmethod @@ -138,14 +135,16 @@ class ConvMapper(ONNXToMindSporeMapper): weights = kwargs['weights'] params = kwargs['params'] - if weights.get('weight', 'empty') == 'empty': # is from tf - return ConvMapper.convert_params_tf(params=params, weights=weights) - return ConvMapper.convert_params_torch(params=params, weights=weights) + return ConvMapper.convert_params_tf(params=params, weights=weights) @staticmethod def _convert_trained_weights(**kwargs): - return dict() + weights = kwargs['weights'] + weight = ConvMapper._find_val_by_index(0, weights) + bias = ConvMapper._find_val_by_index(1, weights) - @staticmethod - def _convert_settings(**kwargs): - return Setting() + converted_weights = {'weight': weight} + if isinstance(bias, np.ndarray): + converted_weights['bias'] = bias + + return converted_weights diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py index 2c1cb219..f65f5a06 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py @@ -15,7 +15,6 @@ """Mapper module.""" import numpy as np from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper -from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting class DenseMapper(ONNXToMindSporeMapper): @@ -42,8 +41,10 @@ class DenseMapper(ONNXToMindSporeMapper): @staticmethod def _convert_trained_weights(**kwargs): - return dict() - - @staticmethod - def _convert_settings(**kwargs): - return Setting() + weights = kwargs['weights'] + weight = DenseMapper._find_val_by_index(0, weights) + bias = DenseMapper._find_val_by_index(1, weights) + return { + 'weight': weight, + 'bias': bias + } diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py index 2cf809d7..805fe9d7 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py @@ -30,7 +30,9 @@ class MatMulMapper(ONNXToMindSporeMapper): @staticmethod def _convert_trained_weights(**kwargs): - return dict() + weights = kwargs['weights'] + weight = MatMulMapper._find_val_by_index(0, weights) + return {'weight': weight} @staticmethod def _generate_snippet_template(**kwargs): @@ -44,8 +46,7 @@ class MatMulMapper(ONNXToMindSporeMapper): if not weights: return template, exchange_msg, outputs_list, outputs_mapping - weight = list(weights.items())[0] - _, tensor = weight + tensor = MatMulMapper._find_val_by_index(0, weights) variable_slot = "var_0" init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py index 6dedee16..027fdedb 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py @@ -15,7 +15,6 @@ """Mapper module.""" from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper -from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting def _padding_format_convert(padding: list): @@ -49,7 +48,7 @@ class PadMapper(ONNXToMindSporeMapper): weights = kwargs.get("weights") params = kwargs.get("params") mode = convert_bytes_string_to_string(params.get('mode', 'constant')) - pads_onnx = params.get("pads") if params.get("pads") else list(weights.values())[0].tolist() + pads_onnx = params.get("pads") if params.get("pads") else PadMapper._find_val_by_index(0, weights).tolist() if mode == 'constant' and params.get('value') is None: if params.get('pads') or weights: if isinstance(pads_onnx, list): @@ -76,7 +75,3 @@ class PadMapper(ONNXToMindSporeMapper): @staticmethod def _convert_trained_weights(**kwargs): return dict() - - @staticmethod - def _convert_settings(**kwargs): - return Setting() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py index feae93d1..b024f096 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py @@ -13,12 +13,16 @@ # limitations under the License. # ============================================================================== """Mapper module.""" +import math + +import numpy as np + from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper -from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords class PoolMapper(ONNXToMindSporeMapper): - """MaxPool mapper.""" + """Pool mapper.""" @staticmethod def _operation_name_in_ms(*args, **kwargs): @@ -35,12 +39,6 @@ class PoolMapper(ONNXToMindSporeMapper): transformed_params = dict() transformed_params["kernel_size"] = tuple(params['kernel_shape']) transformed_params["stride"] = tuple(params['strides']) - if "pads" in params: - if sum(params['pads']) == 0 and not params.get('ceil_mode', None): - pad_mode = '\"valid\"' - else: - pad_mode = '\"same\"' - transformed_params["pad_mode"] = pad_mode return transformed_params @@ -49,5 +47,100 @@ class PoolMapper(ONNXToMindSporeMapper): return dict() @staticmethod - def _convert_settings(**kwargs): - return Setting() + def _get_ms_opt_shape(**kwargs): + """Get output shape in MindSpore.""" + params = kwargs['raw_params'] + input_shape = params['input_shape'] + kernel_shape = params['kernel_shape'] + strides = params['strides'] + dilations = params.get('dilations', (1, 1)) + # For mindspore, + # output_shape[i] = ceil((input_shape[i] - ((kernel_shape[i] - 1) * dilations[i] + 1) + 1) / strides[i]) + ms_opt_shape = np.true_divide(np.subtract(np.array(input_shape[-len(kernel_shape):], dtype=np.float32), + ((np.array(kernel_shape, dtype=np.float32) - 1) * + np.array(dilations, dtype=np.float32) + 1)) + 1, + np.array(strides, dtype=np.float32)).tolist() + ms_opt_shape_ceil = tuple(math.ceil(ms_opt_shape_axis) for ms_opt_shape_axis in ms_opt_shape) + return ms_opt_shape_ceil + + @staticmethod + def _generate_snippet_template(**kwargs): + template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( + **kwargs) + op = kwargs.get("operation") + args = kwargs.get("converted_params", dict()) + + ms_opt_shape = PoolMapper._get_ms_opt_shape(**kwargs) + tensor_opt_shape = kwargs['raw_params']['output_shape'] + tensor_ipt_shape = kwargs['raw_params']['input_shape'] + kernel_shape = kwargs['raw_params']['kernel_shape'] + dilations = kwargs['raw_params'].get('dilations', (1, 1)) + strides = kwargs['raw_params']['strides'] + onnx_opt_shape = tensor_opt_shape[-len(ms_opt_shape):] + + if np.all(np.array(ms_opt_shape) == np.array(onnx_opt_shape)): + return template, exchange_msg, outputs_list, outputs_mapping + + variable_slot = "var_0" + init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}(opt_{{{variable_slot}}})" + + init_template_pad, construct_template_pad, paddings = \ + PoolMapper._generate_pad_init_and_construct(tensor_opt_shape, tensor_ipt_shape, + ms_opt_shape, variable_slot, + kernel_shape, dilations, strides) + + template = { + variable_slot: { + TemplateKeywords.INIT.value: [init_template_pad, init_template], + TemplateKeywords.CONSTRUCT.value: [construct_template_pad, construct_template] + } + } + + args['paddings'] = paddings + + exchange_msg = { + variable_slot: { + ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, + ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, + ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: + ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, + ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], + ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, + ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: dict(), + ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: dict() + } + } + + return template, exchange_msg, outputs_list, outputs_mapping + + @staticmethod + def _generate_pad_init_and_construct(tensor_opt_shape, tensor_ipt_shape, + ms_opt_shape, variable_slot, kernel_shape, dilations, strides): + """Generate pad code in init and construct.""" + onnx_opt_shape = tensor_opt_shape[-len(ms_opt_shape):] + onnx_ipt_shape = tensor_ipt_shape[-len(ms_opt_shape):] + + if np.any(np.array(ms_opt_shape) > np.array(onnx_opt_shape)): + raise ValueError(f"ms_opt_shape[{ms_opt_shape}] should be no larger than onnx_opt_shape[{onnx_opt_shape}].") + + # shape_diff[i] = (onnx_opt_shape[i] - 1)*strides[i] - + # (onnx_ipt_shape[i] - ((kernel_shape[i] - 1)*dilations[i] + 1)) + shape_diff = np.subtract((np.array(onnx_opt_shape) - 1)*np.array(strides), + np.subtract(np.array(onnx_ipt_shape), + (np.array(kernel_shape) - 1)*np.array(dilations) + 1)).tolist() + + zero_pad_single = (0, 0) + paddings = [zero_pad_single] + num_zero_pads = len(tensor_opt_shape) - len(ms_opt_shape) + for _ in range(num_zero_pads - 1): + paddings.append(zero_pad_single) + + for axis_diff in shape_diff: + paddings.append((int(axis_diff//2), int(axis_diff//2 + axis_diff % 2))) + + init_template_pad = f"self.pad_{{{variable_slot}}} = nn.Pad(paddings={{paddings}})" + construct_template_pad = f"opt_{{{variable_slot}}} = self.pad_{{{variable_slot}}}" \ + f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})" + + return init_template_pad, construct_template_pad, tuple(paddings) diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py index bac2973d..db2413d1 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py @@ -44,8 +44,7 @@ class AddMapper(ONNXToMindSporeMapper): if not weights: return template, exchange_msg, outputs_list, outputs_mapping - bias = list(weights.items())[0] - _, tensor = bias + tensor = AddMapper._find_val_by_index(0, weights) variable_slot = "var_0" init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py index 5b3c5d1b..f47b5e7a 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py @@ -53,8 +53,7 @@ class MulMapper(ONNXToMindSporeMapper): if not weights: return template, exchange_msg, outputs_list, outputs_mapping - weight = list(weights.items())[0] - _, tensor = weight + tensor = MulMapper._find_val_by_index(0, weights) variable_slot = "var_0" init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/resize_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/resize_mapper.py index 9aa7a3cf..81b4938e 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/resize_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/resize_mapper.py @@ -60,7 +60,7 @@ class ResizeMapper(ONNXToMindSporeMapper): align_corners = True # Get requested size for resize - size = list(weights.values())[-1][-2:].tolist() + size = ResizeMapper._find_val_by_index(-1, weights)[-2:].tolist() return {"size": tuple(size), "align_corners": align_corners} diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py index 24b27943..481ea272 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py @@ -48,7 +48,7 @@ class SliceMapper(ONNXToMindSporeMapper): def _generate_snippet_template(**kwargs): template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( **kwargs) - weights = list(kwargs.get("weights").values()) # start, end, axis + weights = [weight.value for weight in kwargs.get('weights')] # start, end, axis opt_shape = kwargs["raw_params"]["output_shape"] if not weights: raise ValueError("Cannot get required params from slice.") diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py index c4893f16..5ebea156 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,4 +15,4 @@ """Searcher of scope name.""" __all__ = ["generate_scope_name"] -from .searcher import generate_scope_name +from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.searcher import generate_scope_name diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py index 7049b74d..aafe7055 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,8 +16,8 @@ __all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"] -from .common import cal_matching_score -from .pattern import Pattern +from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import cal_matching_score +from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern BUILT_IN_PATTERN = dict() diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py index bc573ae8..ac706c8c 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,12 +17,15 @@ import copy import uuid from typing import Dict, List, Callable, Union from collections import OrderedDict -from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score -from .known_module_name import BUILT_IN_MODULE_NAME -from .pattern import Pattern, scope_name_mapping -from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern -from .pattern_fuzzy_matching import pattern_fuzzy_matching -from ..third_party_graph.onnx_utils import OnnxNode, BaseNode +from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \ + MAX_OUT_DEGREE, cal_matching_score +from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME +from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping +from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \ + is_built_in_pattern +from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \ + pattern_fuzzy_matching +from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxNode, BaseNode module_name_to_src = {} used_module_name = dict() diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py index d61c8df0..ea8ff3c3 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py @@ -16,12 +16,15 @@ from queue import PriorityQueue from typing import Dict, List -from .common import context, DagGraph, gen_hash_key, ACCEPTABLE_RESULT_COUNT -from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE -from ..common.global_context import GlobalContext -from ..third_party_graph.onnx_utils import BaseNode -from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern -from ...common.exceptions import SubGraphSearchingError +from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \ + ACCEPTABLE_RESULT_COUNT +from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \ + MAX_ITERATION_DEPTH, SATISFIED_SCORE +from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext +from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode +from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \ + generate_pattern, find_built_in_pattern +from mindinsight.mindconverter.common.exceptions import SubGraphSearchingError def _is_satisfied(path): diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py index 55424443..424c0245 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py @@ -17,7 +17,7 @@ __all__ = ["GraphFactory"] from importlib import import_module -from .base import Graph +from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import Graph class GraphFactory: diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py index 8f75451e..a7090078 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +15,8 @@ """Define PyTorch graph node.""" import os -from .base import GraphNode -from ..constant import SEPARATOR_IN_SCOPE, NodeType +from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphNode +from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_SCOPE, NodeType class InputNode(GraphNode): @@ -25,7 +25,6 @@ class InputNode(GraphNode): Args: input_shape: Input shape of module. - """ def _get_arg_name(self, arg, variable_name): diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py index 74b6d4fc..3c3060c6 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py @@ -17,12 +17,12 @@ from importlib import import_module from typing import Dict, NoReturn from mindinsight.mindconverter.common.log import logger as log -from .base import Graph -from .input_node import InputNode -from .onnx_graph_node import OnnxGraphNode -from .pytorch_graph_parser import PyTorchGraphParser -from .tf_graph_parser import TFGraphParser -from .onnx_utils import OnnxDataLoader +from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import Graph +from mindinsight.mindconverter.graph_based_converter.third_party_graph.input_node import InputNode +from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode +from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_parser import PyTorchGraphParser +from mindinsight.mindconverter.graph_based_converter.third_party_graph.tf_graph_parser import TFGraphParser +from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxDataLoader, NodeWeight NONE_SCOPE_OP = { "onnx::Add": "Add", @@ -126,7 +126,7 @@ class OnnxGraph(Graph): self._shape_dict = model_data.node_output_shape_dict for ind, (node_name, node) in enumerate(model_data.nodes_dict.items()): - node_weight = {} + node_weights = list() node.scope_name = scope_name_list[ind] inputs = node.input_name_list # check each input from node or tensors @@ -135,8 +135,8 @@ class OnnxGraph(Graph): tensor = model_data.tensors_dict[i] t_name = tensor.name t_value = tensor.to_array() - node_weight[t_name] = t_value - self._nodes_collection[node_name] = OnnxGraphNode(node, node_weight) + node_weights.append(NodeWeight(t_name, t_value)) + self._nodes_collection[node_name] = OnnxGraphNode(node, node_weights) self._nodes_record[node_name] = node_name for nd_ipt_name in node.precursor_onnx_node_dict: diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py index 3158c082..913c02c1 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,11 +15,11 @@ """Define ONNX graph node.""" from importlib import import_module -from .base import GraphNode -from ..common.utils import is_converted +from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphNode +from mindinsight.mindconverter.graph_based_converter.common.utils import is_converted -from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ - SEPARATOR_IN_ONNX_OP +from mindinsight.mindconverter.graph_based_converter.constant import NodeType, SEPARATOR_IN_SCOPE, \ + SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP class OnnxGraphNode(GraphNode): @@ -28,7 +28,7 @@ class OnnxGraphNode(GraphNode): Args: node (OnnxNode): OnnxNode Object. - weight (dict): Dictionary records weight and bias. + weight (list): List of recording node weights. """ _type_frozen = False _module_name_frozen = False @@ -227,7 +227,6 @@ class OnnxGraphNode(GraphNode): Args: src_arg (str): Original arg name. tgt_arg (str): Target arg name. - """ self._args_in_code[src_arg] = tgt_arg diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py index 59efa333..dc9b3295 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py @@ -22,13 +22,13 @@ from typing import Union import numpy as np from mindinsight.mindconverter.common.log import logger as log -from ..common.utils import fetch_output_from_onnx_model -from ..common.global_context import GlobalContext -from .optimizer import OnnxSimplify +from mindinsight.mindconverter.graph_based_converter.common.utils import fetch_output_from_onnx_model +from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext +from mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer import OnnxSimplify -from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ +from mindinsight.mindconverter.graph_based_converter.constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL -from ...common.exceptions import GraphInitError, ModelLoadingError +from mindinsight.mindconverter.common.exceptions import GraphInitError, ModelLoadingError def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): @@ -128,7 +128,6 @@ class ParamsAttribute: raw_attribute (onnx.AttributeProto): onnx.AttributeProto instance. node (onnx.NodeProto): Must pass the onnx.NodeProto instance containing the same AttributeProto. - """ def __init__(self, raw_attribute, node): @@ -148,7 +147,6 @@ class ParamsAttribute: Args: attrs (onnx.AttributeProto): onnx.AttributeProto instance. - """ if not attrs: return @@ -604,3 +602,18 @@ class OnnxDataLoader: eliminated_nodes = _traceback_precursor_nodes_until_shape_op(to_shape) self.dynamic_resize_node.append(nd_name) self.eliminated_nodes += eliminated_nodes + + +class NodeWeight: + """Node weight struct.""" + def __init__(self, weight_name, weight_value): + self._weight_name = weight_name + self._weight_value = weight_value + + @property + def name(self): + return self._weight_name + + @property + def value(self): + return self._weight_value diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py index 63c0476f..179c220e 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py @@ -18,7 +18,7 @@ from importlib import import_module import numpy as np -from ..common.utils import fetch_output_from_onnx_model +from mindinsight.mindconverter.graph_based_converter.common.utils import fetch_output_from_onnx_model class OnnxSimplify: diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py index 9dd20995..597e96e6 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py @@ -20,6 +20,7 @@ from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphParser from mindinsight.mindconverter.common.exceptions import ModelNotSupportError + class PyTorchGraphParser(GraphParser): """Define pytorch graph parser.""" diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py index 952506f9..0f1c1ba0 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. +# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,8 +18,8 @@ import re from importlib import import_module from mindinsight.mindconverter.common.log import logger as log -from .base import GraphParser -from ...common.exceptions import ModelNotSupportError +from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphParser +from mindinsight.mindconverter.common.exceptions import ModelNotSupportError class TFGraphParser(GraphParser): diff --git a/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py b/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py index cd31fdc3..6ebdfce9 100644 --- a/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py +++ b/tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py @@ -89,7 +89,9 @@ class TestMappers: 'input': {'op_name': 'onnx::MaxPool', 'params': {'kernel_shape': [3, 3], 'pads': [1, 1, 1, 1], - 'strides': [2, 2]}, + 'strides': [2, 2], + 'input_shape': (1, 3, 224, 224), + 'output_shape': (1, 3, 112, 112)}, 'weights': dict()}, 'expected_output': {'converter_name': 'nn.MaxPool2d', 'converted_params': {'kernel_size': (3, 3), @@ -100,7 +102,9 @@ class TestMappers: 'input': {'op_name': 'onnx::AveragePool', 'params': {'kernel_shape': [3, 3], 'pads': [1, 1, 1, 1], - 'strides': [2, 2]}, + 'strides': [2, 2], + 'input_shape': (1, 3, 224, 224), + 'output_shape': (1, 3, 112, 112)}, 'weights': dict()}, 'expected_output': {'converter_name': 'nn.AvgPool2d', 'converted_params': {'kernel_size': (3, 3), diff --git a/tests/utils/mindspore/train/serialization.py b/tests/utils/mindspore/train/serialization.py new file mode 100644 index 00000000..6f2278a4 --- /dev/null +++ b/tests/utils/mindspore/train/serialization.py @@ -0,0 +1,26 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Mock the MindSpore mindspore/train/serialization.py.""" + + +def save_checkpoint(trainable_weights, ckpt_file_name): + """ + Mock save_checkpoint. + + Args: + trainable_weights (list): List of weights. + ckpt_file_name (str): Path to save checkpoint file. + """ + return len(trainable_weights), ckpt_file_name