You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

op_selector.py 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. A factory method that creates op selector instances used to configure a switch on a class,
  17. which controls which op type is used: GraphKernel or Primitive.
  18. """
  19. import importlib
  20. import inspect
  21. from mindspore import context
  22. class _OpSelector:
  23. """
  24. A helper class, which can be used to choose different type of operator.
  25. When an instance of this class is called, we return the right operator
  26. according to the context['enable_graph_kernel'] and the name of the
  27. parameter. returned operator will be a GraphKernel op ora Primitive op.
  28. Args:
  29. op (class): an empty class has an operator name as its class name
  30. config_optype (str): operator type, which must be either 'GraphKernel'
  31. or 'Primitive'
  32. graph_kernel_pkg (str): real operator's package name
  33. primitive_pkg (str): graph kernel operator's package name
  34. Examples:
  35. >>> class A: pass
  36. >>> selected_op = _OpSelector(A, "GraphKernel",
  37. >>> "graph_kernel.ops.pkg", "primitive.ops.pkg")
  38. >>> # selected_op() will call graph_kernel.ops.pkg.A()
  39. """
  40. GRAPH_KERNEL = "GraphKernel"
  41. PRIMITIVE = "Primitive"
  42. DEFAULT_OP_TYPE = PRIMITIVE
  43. KW_STR = "op_type"
  44. def __init__(self, op, config_optype, primitive_pkg, graph_kernel_pkg):
  45. self.op_name = op.__name__
  46. self.config_optype = config_optype
  47. self.graph_kernel_pkg = graph_kernel_pkg
  48. self.primitive_pkg = primitive_pkg
  49. def __call__(self, *args, **kwargs):
  50. _op_type = _OpSelector.DEFAULT_OP_TYPE
  51. if context.get_context("enable_graph_kernel"):
  52. if _OpSelector.KW_STR in kwargs:
  53. _op_type = kwargs.get(_OpSelector.KW_STR)
  54. kwargs.pop(_OpSelector.KW_STR, None)
  55. elif self.config_optype is not None:
  56. _op_type = self.config_optype
  57. if _op_type == _OpSelector.GRAPH_KERNEL:
  58. pkg = self.graph_kernel_pkg
  59. else:
  60. pkg = self.primitive_pkg
  61. op = getattr(importlib.import_module(pkg, __package__), self.op_name)
  62. return op(*args, **kwargs)
  def new_ops_selector(primitive_pkg, graph_kernel_pkg):
      """
      A factory method to return an op selector.

      When the GraphKernel switch is on
      (`context.get_context('enable_graph_kernel') == True`), there are two ways
      to control the op type, plus a default:

      (1). call the real op with an extra parameter `op_type='Primitive'` or
           `op_type='GraphKernel'`;
      (2). pass a parameter to the op selector, like `@op_selector('Primitive')`
           or `@op_selector('GraphKernel')`;
      (3). otherwise the default op type is PRIMITIVE.

      The order from the highest priority to the lowest priority is (1), (2), (3).
      If the GraphKernel switch is off, then op_type will always be PRIMITIVE.

      Args:
          primitive_pkg (str): primitive op's package name.
          graph_kernel_pkg (str): graph kernel op's package name.

      Returns:
          An op selector, which controls what operator is actually called. It can
          be applied directly to a class (`@op_selector`), or first called with an
          op type (`@op_selector('GraphKernel')`) to obtain the class decorator;
          `@op_selector(None)` behaves like the bare form.

      Examples:
          >>> op_selector = new_ops_selector("primitive_pkg.some.path",
          ...                                "graph_kernel_pkg.some.path")
          >>> @op_selector
          ... class ReduceSum: pass
      """
      def op_selector(cls_or_optype):
          if inspect.isclass(cls_or_optype):
              # Bare decorator form: `@op_selector` applied directly to a class,
              # with no configured op type.
              return _OpSelector(cls_or_optype, None, primitive_pkg, graph_kernel_pkg)

          # Parameterized form: `@op_selector(<op type or None>)` returns the real
          # class decorator. Routing None here (instead of treating it as the class)
          # avoids a crash on `None.__name__`.
          config_optype = cls_or_optype

          def deco_cls(_real_cls):
              return _OpSelector(_real_cls, config_optype, primitive_pkg, graph_kernel_pkg)
          return deco_cls
      return op_selector