|
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- from __future__ import unicode_literals
-
- import collections
- import inspect
-
- from hetu.onnx import constants
-
-
- class hetu_op:
-
- _OPSETS = collections.OrderedDict()
- _MAPPING = None
-
- def __init__(self, name, onnx_op=None, domain=constants.ONNX_DOMAIN, **kwargs):
-
- if not isinstance(name, list):
- name = [name]
- self.name = name
- if not isinstance(onnx_op, list):
- onnx_op = [onnx_op]*len(name)
- self.onnx_op = onnx_op
- self.domain = domain
- self.kwargs = kwargs
-
- def __call__(self, func):
- opset = hetu_op._OPSETS.get(self.domain)
- if not opset:
- opset = []
- hetu_op._OPSETS[self.domain] = opset
- for k, v in inspect.getmembers(func, inspect.ismethod):
- if k.startswith("version_"):
- version = int(k.replace("version_", ""))
- while version >= len(opset):
- opset.append({})
- opset_dict = opset[version]
- for i, name in enumerate(self.name):
- opset_dict[name] = (v, self.onnx_op[i], self.kwargs)
- return func
-
- @staticmethod
- def get_opsets():
- return hetu_op._OPSETS
-
- @staticmethod
- def create_mapping(max_onnx_opset_version):
- mapping = {constants.ONNX_DOMAIN: max_onnx_opset_version}
- ops_mapping = {}
- for domain, opsets in hetu_op.get_opsets().items():
- for target_opset, op_map in enumerate(opsets):
- m = mapping.get(domain)
- if m:
- if target_opset <= m and op_map:
- ops_mapping.update(op_map)
-
- hetu_op._MAPPING = ops_mapping
- return ops_mapping
-
- @staticmethod
- def find_effective_op(name):
- """Find the effective version of an op create_mapping.
- This is used if we need to compose ops from other ops where we'd need to find the
- op that is doing to be used in the final graph, for example there is a custom op
- that overrides a onnx op ...
-
- :param name: The operator name.
- """
- map_info = hetu_op._MAPPING.get(name)
- if map_info is None:
- return None
- return map_info
|