You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

handler.py 2.4 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. from __future__ import unicode_literals
  5. import collections
  6. import inspect
  7. from hetu.onnx import constants
  8. class hetu_op:
  9. _OPSETS = collections.OrderedDict()
  10. _MAPPING = None
  11. def __init__(self, name, onnx_op=None, domain=constants.ONNX_DOMAIN, **kwargs):
  12. if not isinstance(name, list):
  13. name = [name]
  14. self.name = name
  15. if not isinstance(onnx_op, list):
  16. onnx_op = [onnx_op]*len(name)
  17. self.onnx_op = onnx_op
  18. self.domain = domain
  19. self.kwargs = kwargs
  20. def __call__(self, func):
  21. opset = hetu_op._OPSETS.get(self.domain)
  22. if not opset:
  23. opset = []
  24. hetu_op._OPSETS[self.domain] = opset
  25. for k, v in inspect.getmembers(func, inspect.ismethod):
  26. if k.startswith("version_"):
  27. version = int(k.replace("version_", ""))
  28. while version >= len(opset):
  29. opset.append({})
  30. opset_dict = opset[version]
  31. for i, name in enumerate(self.name):
  32. opset_dict[name] = (v, self.onnx_op[i], self.kwargs)
  33. return func
  34. @staticmethod
  35. def get_opsets():
  36. return hetu_op._OPSETS
  37. @staticmethod
  38. def create_mapping(max_onnx_opset_version):
  39. mapping = {constants.ONNX_DOMAIN: max_onnx_opset_version}
  40. ops_mapping = {}
  41. for domain, opsets in hetu_op.get_opsets().items():
  42. for target_opset, op_map in enumerate(opsets):
  43. m = mapping.get(domain)
  44. if m:
  45. if target_opset <= m and op_map:
  46. ops_mapping.update(op_map)
  47. hetu_op._MAPPING = ops_mapping
  48. return ops_mapping
  49. @staticmethod
  50. def find_effective_op(name):
  51. """Find the effective version of an op create_mapping.
  52. This is used if we need to compose ops from other ops where we'd need to find the
  53. op that is doing to be used in the final graph, for example there is a custom op
  54. that overrides a onnx op ...
  55. :param name: The operator name.
  56. """
  57. map_info = hetu_op._MAPPING.get(name)
  58. if map_info is None:
  59. return None
  60. return map_info