You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_pattern.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Patterns for describing graphs"""
  16. from mindspore.ops import Primitive
  17. from mindspore.common.tensor import Tensor
  18. from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_,\
  19. NewParameter_, Imm
  20. __all__ = [
  21. "IsIn",
  22. "IsPrimTypeOf",
  23. "CallWith",
  24. "IsNot",
  25. "AnyPattern",
  26. "NewTensor",
  27. "NewParameter",
  28. "Imm"
  29. ]
  30. class IsIn(IsIn_):
  31. r"""
  32. Express a pattern which allows a list of patterns.
  33. """
  34. def __init__(self, patterns=None, should_replace=True):
  35. r"""
  36. Args:
  37. patterns(Union[tuple[:class:`mindspore.graph_utils.graph_pattern`],
  38. list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns,
  39. each element should be one of the exposed Pattern instance.
  40. should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
  41. Raises:
  42. ValueError: raise if should_replace is False
  43. TypeError: raise type error for invalid inputs.
  44. """
  45. if not should_replace:
  46. raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \
  47. its sub-pattern instead.")
  48. self.patterns = patterns
  49. if patterns is None:
  50. IsIn_.__init__(self, ())
  51. elif isinstance(patterns, Pattern):
  52. IsIn_.__init__(self, [patterns])
  53. elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
  54. IsIn_.__init__(self, patterns)
  55. else:
  56. raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
  57. class IsPrimTypeOf(IsPrimTypeOf_):
  58. r"""
  59. Express a pattern of certain primitive type(s).
  60. NOTE:
  61. This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
  62. please refer to CallWith pattern.
  63. """
  64. def __init__(self, types, name=None, should_replace=True):
  65. r"""
  66. Args:
  67. types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`],
  68. tuple[:class:`mindspore.ops.Primitive`]):
  69. Specify allowed types.
  70. If it is a string, the form could be
  71. 1) a single primitive type, e.g. 'Conv2D'
  72. 2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
  73. It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)]
  74. name (str): name of the pattern, optional. Default: None.
  75. should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
  76. used when building the replacement target node. Use captured node if True, build from scratch otherwise.
  77. Default: True.
  78. Raises:
  79. TypeError: raise type error for invalid argument.
  80. """
  81. if name is not None and not isinstance(name, str):
  82. raise TypeError(f"Expect string, got : {name}")
  83. self.name = name
  84. if isinstance(types, str):
  85. if self.name is None:
  86. self.name = types
  87. self.types = types.split('|')
  88. elif isinstance(types, Primitive):
  89. if self.name is None:
  90. self.name = types.name
  91. self.types = [types]
  92. elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types):
  93. if self.name is None:
  94. self.name = ""
  95. for prim in types:
  96. self.name += prim.name
  97. self.types = types
  98. else:
  99. raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
  100. IsPrimTypeOf_.__init__(self, self.types, self.name, should_replace)
  101. class CallWith(CallWith_):
  102. r"""
  103. Express a primitive CNode.
  104. """
  105. def __init__(self, prim_pattern, inputs=None, should_replace=True):
  106. r"""
  107. Args:
  108. prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.IsPrimTypeOf`,
  109. :class:`mindspore.ops.Primitive`]): Primitive ValueNode in the Primitive CNode.
  110. inputs (Union[list[:class:`mindspore.graph_utils.graph_pattern`],
  111. tuple[:class:`mindspore.graph_utils.graph_pattern`]]):
  112. Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified, input
  113. patterns should be of right order and each element should be one of the exposed Pattern instance.
  114. should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
  115. used when building the replacement target node. Use captured node if True, build from scratch otherwise.
  116. Default: True.
  117. Raises:
  118. TypeError: raise type error for invalid argument.
  119. """
  120. if not isinstance(prim_pattern, (Pattern, str, Primitive)):
  121. raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}")
  122. self.prim_pattern = prim_pattern
  123. self.inputs = []
  124. if inputs is None:
  125. pass
  126. elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs):
  127. self.inputs = inputs
  128. else:
  129. raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
  130. CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace)
  131. class IsNot(IsNot_):
  132. r"""
  133. Express a pattern which forbids a list of patterns.
  134. NOTE:
  135. IsNot pattern should not be the root pattern.
  136. """
  137. def __init__(self, patterns=None, should_replace=True):
  138. r"""
  139. Args:
  140. patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element
  141. should be one of the exposed Pattern instance.
  142. should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
  143. Raises:
  144. ValueError: raise if should_replace is False.
  145. TypeError: raise type error for invalid argument.
  146. """
  147. if not should_replace:
  148. raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \
  149. its sub-pattern instead.")
  150. self.patterns = patterns
  151. if patterns is None:
  152. IsNot_.__init__(self, ())
  153. elif isinstance(patterns, Pattern):
  154. IsNot_.__init__(self, [patterns])
  155. elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
  156. IsNot_.__init__(self, patterns)
  157. else:
  158. raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
  159. class NewTensor(NewTensor_):
  160. r"""
  161. New Tensor to be used in the target.
  162. """
  163. def __init__(self, input_tensor, should_replace=False):
  164. r"""
  165. Args:
  166. input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target
  167. should_replace(bool): added this for interface consistency. NewTensor should only appear in the target.
  168. Raises:
  169. ValueError: raise if should_replace is True
  170. TypeError: raise type error for invalid argument.
  171. """
  172. if should_replace:
  173. raise ValueError("NewTensor should only appear in the target, thus should_replace can only be False.")
  174. self.input_tensor = input_tensor
  175. if isinstance(input_tensor, Tensor):
  176. NewTensor_.__init__(self, input_tensor)
  177. else:
  178. raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}")
  179. class NewParameter(NewParameter_):
  180. r"""
  181. New Parameter to be used in the target.
  182. """
  183. def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False, should_replace=False):
  184. r"""
  185. Args:
  186. para_name(str): name for the new Parameter
  187. default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
  188. requires_grad(bool): True if the parameter requires gradient. Default: True
  189. layerwise_parallel(bool): switch for layerwise parallel mode. Default: False
  190. should_replace(bool): gen new parameter once and replace after if set to be true; otherwise build a new
  191. parameter everytime a pass target got built. Default: False
  192. Raises:
  193. TypeError: raise type error for invalid argument.
  194. """
  195. self.para_name = para_name
  196. self.default_tensor = default_tensor
  197. self.requires_grad = requires_grad
  198. self.layerwise_parallel = layerwise_parallel
  199. self.should_replace = should_replace
  200. if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
  201. isinstance(layerwise_parallel, bool) and isinstance(should_replace, bool):
  202. NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
  203. self.layerwise_parallel, self.should_replace)
  204. else:
  205. raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
  206. layerwise_parallel(bool) should_replace(bool), got : {para_name}, {default_tensor}, \
  207. {requires_grad}, {layerwise_parallel}, {should_replace}")