You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 4.7 kB

Lines 1–133
  1. # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Mapper module."""
  16. import abc
  17. import importlib
  18. import json
  19. import os
  20. from typing import Dict
  21. from mindinsight.mindconverter.common.log import logger as log
  CONFIG_JSON = "onnx_to_ms.json"
  # Absolute path of the mapping table, located in this module's directory.
  OPERATION_TABLE = os.path.join(
      os.path.abspath(os.path.dirname(__file__)),
      CONFIG_JSON
  )

  # Load the mapping table: each key is an operation name in ONNX and each
  # value is the dotted path ("package.module.ClassName") of the mapper class
  # that converts it.
  # JSON text is UTF-8 by specification, so decode it explicitly as UTF-8
  # rather than relying on the platform's locale-dependent default encoding.
  with open(OPERATION_TABLE, encoding="utf-8") as file:
      TABLE = json.load(file)

  # Names of the static hooks looked up (via getattr) on each mapper class
  # during conversion.
  GET_OP_NAME = "_operation_name_in_ms"
  GET_OP_PARAMS = "_convert_params"
  GET_OP_WEIGHTS = "_convert_trained_weights"
  GET_OP_SETTINGS = "_convert_settings"
  class Mapper(metaclass=abc.ABCMeta):
      """Abstract contract for translating a third-party operation to MindSpore.

      `convert` is the public entry point. The four static hooks declared
      after it are the per-operation conversion steps; all five members are
      abstract, so a concrete mapper must supply every one of them.
      """

      @classmethod
      @abc.abstractmethod
      def convert(cls, op_name: str, params: Dict, weights: Dict = None):
          """Translate a third-party operation (name, params, weights) into MindSpore form."""

      @staticmethod
      @abc.abstractmethod
      def _operation_name_in_ms(*args, **kwargs):
          """Give the name of the matching MindSpore operation."""

      @staticmethod
      @abc.abstractmethod
      def _convert_params(**kwargs):
          """Map the third-party operation's params onto the MindSpore operation."""

      @staticmethod
      @abc.abstractmethod
      def _convert_trained_weights(**kwargs):
          """Map the third-party operation's trained weights onto the MindSpore operation."""

      @staticmethod
      @abc.abstractmethod
      def _convert_settings(**kwargs):
          """Derive the MindSpore operator settings from the third-party params."""
  class ONNXToMindSporeMapper(Mapper, abc.ABC):
      """ONNX operation to MindSpore."""

      @classmethod
      def convert(cls, op_name: str, params: Dict, weights: Dict = None):
          """
          Convert third party operation's param into MindSpore operation.

          The mapper for `op_name` is looked up in the module-level TABLE as a
          dotted path ("package.module.ClassName") and imported on demand. Its
          four static hooks (operation name, params, trained weights,
          settings) are then called in turn.

          Args:
              op_name (str): Operation name in ONNX.
              params (dict): Params in onnx.
              weights (dict): Weights in onnx. When empty or None, the
                  trained-weights hook is skipped.

          Returns:
              Tuple[str, dict, dict], operation name and params and settings.
              The converted weights are merged into the returned params dict.
              (None, {}, {}) is returned, and the failure logged where
              applicable, when `op_name` has no table entry, the mapper class
              or one of its hooks cannot be resolved, or a hook fails with
              AttributeError, KeyError, ValueError, TypeError or IndexError.
          """
          # TABLE is only read here, so no `global` declaration is needed.
          module_name = TABLE.get(op_name)
          if not module_name:
              return None, {}, {}

          # Split "package.module.ClassName" into module path and class name.
          # A dot-less entry yields an empty module path. The old slice-based
          # split mangled such names, and import_module("") raises ValueError.
          module_path, _, class_name = module_name.rpartition(".")
          if not module_path:
              err_msg = f"Converting {op_name} failed, see invalid mapper path {module_name!r}"
              log.error(err_msg)
              return None, {}, {}

          try:
              converter = getattr(importlib.import_module(module_path), class_name)
              op_name_converter = getattr(converter, GET_OP_NAME)
              params_converter = getattr(converter, GET_OP_PARAMS)
              weights_converter = getattr(converter, GET_OP_WEIGHTS)
              settings_converter = getattr(converter, GET_OP_SETTINGS)
          except (ImportError, AttributeError) as e:
              # If the mapper module, class or one of its hooks cannot be
              # found, skip this operation instead of aborting the conversion.
              # ImportError also covers ModuleNotFoundError.
              err_msg = f"Converting {op_name} failed, see {str(e)}"
              log.error(err_msg)
              return None, {}, {}

          try:
              converter_name = op_name_converter(params=params, weights=weights, op_name=op_name)
              converted_params = params_converter(params=params, weights=weights)
              # Only convert weights when some were provided.
              converted_weights = weights_converter(weights=weights) if weights else {}
              converted_params.update(converted_weights)
              converted_settings = settings_converter(params=params)
          except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e:
              err_msg = f"Converting {op_name} failed, see {str(e)}"
              log.error(err_msg)
              return None, {}, {}
          return converter_name, converted_params, converted_settings

      @staticmethod
      def _operation_name_in_ms(*args, **kwargs):
          """Not implemented at this level; raises NotImplementedError."""
          raise NotImplementedError

      @staticmethod
      def _convert_params(**kwargs):
          """Not implemented at this level; raises NotImplementedError."""
          raise NotImplementedError

      @staticmethod
      def _convert_trained_weights(**kwargs):
          """Not implemented at this level; raises NotImplementedError."""
          raise NotImplementedError

      @staticmethod
      def _convert_settings(**kwargs):
          """Not implemented at this level; raises NotImplementedError."""
          raise NotImplementedError