You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_pattern.py 8.3 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Patterns for describing graphs"""
  16. from mindspore.ops import Primitive
  17. from mindspore.common.tensor import Tensor
  18. from mindspore._c_expression import Pattern, OneOf_, Prim_, Call_, NoneOf_, Any, NewTensor_, NewParameter_, Imm
  19. __all__ = [
  20. "OneOf",
  21. "Prim",
  22. "Call",
  23. "NoneOf",
  24. "Any",
  25. "NewTensor",
  26. "NewParameter",
  27. "Imm"
  28. ]
  29. class OneOf(OneOf_):
  30. r"""
  31. Express a pattern which allows a list of patterns.
  32. """
  33. def __init__(self, patterns=None):
  34. r"""
  35. Args:
  36. patterns(Union[:class:`mindspore.graph_utils.graph_pattern`,
  37. tuple[:class:`mindspore.graph_utils.graph_pattern`],
  38. list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns,
  39. each element should be one of the exposed Pattern instance.
  40. Raises:
  41. TypeError: raise type error for invalid inputs.
  42. """
  43. self.patterns = patterns
  44. if isinstance(patterns, Pattern):
  45. OneOf_.__init__(self, [patterns])
  46. elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
  47. OneOf_.__init__(self, patterns)
  48. else:
  49. raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
  50. class Prim(Prim_):
  51. r"""
  52. Express a pattern of certain primitive type(s).
  53. NOTE:
  54. This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
  55. please refer to CallWith pattern.
  56. """
  57. def __init__(self, types, name=None):
  58. r"""
  59. Args:
  60. types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`],
  61. tuple[:class:`mindspore.ops.Primitive`]):
  62. Specify allowed types.
  63. If it is a string, the form could be
  64. 1) a single primitive type, e.g. 'Conv2D'
  65. 2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
  66. It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)]
  67. name (str): name of the pattern, optional. Default: None.
  68. Raises:
  69. TypeError: raise type error for invalid argument.
  70. """
  71. if name is not None and not isinstance(name, str):
  72. raise TypeError(f"Expect string, got : {name}")
  73. self.name = name
  74. if isinstance(types, str):
  75. if self.name is None:
  76. self.name = types
  77. self.types = types.split('|')
  78. elif isinstance(types, Primitive):
  79. if self.name is None:
  80. self.name = types.name
  81. self.types = [types]
  82. elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types):
  83. if self.name is None:
  84. self.name = ""
  85. for prim in types:
  86. self.name += prim.name
  87. self.types = types
  88. else:
  89. raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
  90. Prim_.__init__(self, self.types, self.name)
  91. class Call(Call_):
  92. r"""
  93. Express a primitive CNode.
  94. """
  95. def __init__(self, prim_pattern, inputs=None):
  96. r"""
  97. Args:
  98. prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.IsPrimTypeOf`,
  99. :class:`mindspore.ops.Primitive`]): Primitive ValueNode in the Primitive CNode.
  100. inputs (Union[list[:class:`mindspore.graph_utils.graph_pattern`],
  101. tuple[:class:`mindspore.graph_utils.graph_pattern`]]):
  102. Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified, input
  103. patterns should be of right order and each element should be one of the exposed Pattern instance.
  104. Raises:
  105. TypeError: raise type error for invalid argument.
  106. """
  107. if not isinstance(prim_pattern, (Pattern, str, Primitive)):
  108. raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}")
  109. self.prim_pattern = prim_pattern
  110. self.inputs = []
  111. if inputs is None:
  112. pass
  113. elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs):
  114. self.inputs = inputs
  115. else:
  116. raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
  117. Call_.__init__(self, self.prim_pattern, self.inputs)
  118. class NoneOf(NoneOf_):
  119. r"""
  120. Express a pattern which forbids a list of patterns.
  121. NOTE:
  122. NoneOf pattern should not be the root pattern.
  123. """
  124. def __init__(self, patterns=None):
  125. r"""
  126. Args:
  127. patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbidden patterns, each
  128. element should be one of the exposed Pattern instance.
  129. Raises:
  130. TypeError: raise type error for invalid argument.
  131. """
  132. self.patterns = patterns
  133. if patterns is None:
  134. NoneOf_.__init__(self, ())
  135. elif isinstance(patterns, Pattern):
  136. NoneOf_.__init__(self, [patterns])
  137. elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
  138. NoneOf_.__init__(self, patterns)
  139. else:
  140. raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
  141. class NewTensor(NewTensor_):
  142. r"""
  143. New Tensor to be used in the target.
  144. """
  145. def __init__(self, input_tensor):
  146. r"""
  147. Args:
  148. input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target.
  149. Raises:
  150. TypeError: raise type error for invalid argument.
  151. """
  152. self.input_tensor = input_tensor
  153. if isinstance(input_tensor, Tensor):
  154. NewTensor_.__init__(self, input_tensor)
  155. else:
  156. raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}")
  157. class NewParameter(NewParameter_):
  158. r"""
  159. New Parameter to be used in the target.
  160. """
  161. def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False):
  162. r"""
  163. Args:
  164. para_name(str): name for the new Parameter.
  165. default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter.
  166. requires_grad(bool): True if the parameter requires gradient. Default: True.
  167. layerwise_parallel(bool): switch for layerwise parallel mode. Default: False.
  168. Raises:
  169. TypeError: raise type error for invalid argument.
  170. """
  171. self.para_name = para_name
  172. self.default_tensor = default_tensor
  173. self.requires_grad = requires_grad
  174. self.layerwise_parallel = layerwise_parallel
  175. if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
  176. isinstance(layerwise_parallel, bool):
  177. NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
  178. self.layerwise_parallel)
  179. else:
  180. raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
  181. layerwise_parallel(bool), got : {para_name}, {default_tensor}, \
  182. {requires_grad}, {layerwise_parallel}")