# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
import abc
import importlib
import json
import os
from typing import Dict

from mindinsight.mindconverter.common.log import logger as log

# File name of the ONNX -> MindSpore mapping table; it is resolved relative
# to the directory that contains this module.
CONFIG_JSON = "onnx_to_ms.json"
OPERATION_TABLE = os.path.join(
    os.path.abspath(os.path.dirname(__file__)),
    CONFIG_JSON
)

# Read the table with an explicit encoding so that loading does not depend
# on the locale's default encoding of the host platform.
with open(OPERATION_TABLE, encoding="utf-8") as file:
    # Load the mapping table: each key is an operation name in ONNX and
    # each value is the dotted path of the corresponding mapper.
    TABLE = json.load(file)

# Define global func name.
# Attribute names of the hooks that ``ONNXToMindSporeMapper.convert`` looks
# up (via ``getattr``) on the mapper class resolved from ``TABLE``.
GET_OP_NAME = "_operation_name_in_ms"
GET_OP_PARAMS = "_convert_params"
GET_OP_WEIGHTS = "_convert_trained_weights"
GET_OP_SETTINGS = "_convert_settings"


class Mapper(metaclass=abc.ABCMeta):
    """Mapper between third-party-operation and MindSpore."""

    @staticmethod
    @abc.abstractmethod
    def _operation_name_in_ms(*args, **kwargs):
        """Corresponding operation name in mindspore."""

    @staticmethod
    @abc.abstractmethod
    def _convert_params(**kwargs):
        """Convert third party operation's param into MindSpore operation."""

    @staticmethod
    @abc.abstractmethod
    def _convert_trained_weights(**kwargs):
        """Convert third party operation's weights into MindSpore operation."""

    @staticmethod
    @abc.abstractmethod
    def _convert_settings(**kwargs):
        """Convert third party operation's params into MindSpore OP operator."""

    @classmethod
    @abc.abstractmethod
    def convert(cls, op_name: str, params: Dict, weights: Dict = None):
        """Convert third party operation's param into MindSpore operation."""


class ONNXToMindSporeMapper(Mapper, abc.ABC):
    """ONNX operation to MindSpore."""

    @classmethod
    def convert(cls, op_name: str, params: Dict, weights: Dict = None):
        """
        Convert third party operation's param into MindSpore operation.

        The mapper for `op_name` is looked up in the module-level ``TABLE``,
        imported dynamically, and its four hooks (named by the ``GET_OP_*``
        constants) are called to build the result.

        Args:
            op_name (str): Operation name in ONNX.
            params (dict): Params in onnx.
            weights (dict): Weights in onnx.

        Returns:
            Tuple[str, dict, dict], operation name and params and settings.
            ``(None, {}, {})`` is returned when `op_name` has no entry in the
            table, when its mapper cannot be resolved, or when one of the
            mapper hooks raises AttributeError, KeyError, ValueError,
            TypeError or IndexError.
        """
        module_name = TABLE.get(op_name)
        if not module_name:
            return None, dict(), dict()

        # Split "package.module.ClassName" into module path and class name.
        # An entry without a dot yields an empty module path, which
        # `import_module` rejects with ValueError (handled below).
        module_path, _, class_name = module_name.rpartition(".")
        try:
            converter = getattr(importlib.import_module(module_path), class_name)
            op_name_converter = getattr(converter, GET_OP_NAME)
            params_converter = getattr(converter, GET_OP_PARAMS)
            weights_converter = getattr(converter, GET_OP_WEIGHTS)
            settings_converter = getattr(converter, GET_OP_SETTINGS)
        except (ModuleNotFoundError, AttributeError, ValueError) as e:
            # If mapper can not be found, then skip it. This covers a missing
            # module, a class missing from the module, a mapper that lacks
            # one of the hooks, and a malformed table entry.
            err_msg = f"Converting {op_name} failed, see {str(e)}"
            log.error(err_msg)
            return None, dict(), dict()

        try:
            converter_name = op_name_converter(params=params, weights=weights, op_name=op_name)
            converted_params = params_converter(params=params, weights=weights)
            # The weights hook is only invoked when weights are provided.
            converted_weights = weights_converter(weights=weights) if weights else dict()
            # Converted weights are merged into (and may override) the params.
            converted_params.update(converted_weights)
            converted_settings = settings_converter(params=params)
        except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e:
            err_msg = f"Converting {op_name} failed, see {str(e)}"
            log.error(err_msg)
            return None, dict(), dict()

        return converter_name, converted_params, converted_settings

    @staticmethod
    def _operation_name_in_ms(*args, **kwargs):
        """Placeholder hook; always raises NotImplementedError in this class."""
        raise NotImplementedError

    @staticmethod
    def _convert_params(**kwargs):
        """Placeholder hook; always raises NotImplementedError in this class."""
        raise NotImplementedError

    @staticmethod
    def _convert_trained_weights(**kwargs):
        """Placeholder hook; always raises NotImplementedError in this class."""
        raise NotImplementedError

    @staticmethod
    def _convert_settings(**kwargs):
        """Placeholder hook; always raises NotImplementedError in this class."""
        raise NotImplementedError