You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

primitive.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """primitive"""
  16. import inspect
  17. import copy
  18. from mindspore.common.api import _wrap_func
  19. from .._c_expression import Primitive_, real_run_op, prim_type
  20. from .._c_expression import signature_rw as sig_rw
  21. from .._c_expression import signature_kind as sig_kind
  22. from .._c_expression import signature_dtype as sig_dtype
  23. class Primitive(Primitive_):
  24. """
  25. Primitive is base class for primitives in python.
  26. Args:
  27. name (str): Name for current Primitive.
  28. Examples:
  29. >>> add = Primitive('add')
  30. >>>
  31. >>> # or work with prim_attr_register:
  32. >>> # init a Primitive class with attr1 and attr2
  33. >>> class Add(Primitive):
  34. >>> @prim_attr_register
  35. >>> def __init__(self, attr1, attr2):
  36. >>> # check attr1 and attr2 or do some initializations
  37. >>> # init a Primitive obj with attr1=1 and attr2=2
  38. >>> add = Add(attr1=1, attr2=2)
  39. """
  40. def __init__(self, name):
  41. self.name = name
  42. self.attrs = {}
  43. self.init_attrs = {}
  44. Primitive_.__init__(self, name, self)
  45. if hasattr(self.__class__, '__mindspore_signature__'):
  46. sig = self._fill_signature(self.__class__.__mindspore_signature__)
  47. self.set_signatures(sig)
  48. def _fill_signature(self, signatures):
  49. """fills signature."""
  50. signatures_new = []
  51. for signature in signatures:
  52. if isinstance(signature, sig_dtype):
  53. signatures_new.append(("argument", sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD,
  54. sig_kind.KIND_EMPTY_DEFAULT_VALUE, signature))
  55. else:
  56. if len(signature) < 3:
  57. raise ValueError(f"[Internal Error]Signature for one parameter len must > 3, but {signature}")
  58. if len(signature) == 3:
  59. signature += (sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T_EMPTY_DEFAULT_VALUE)
  60. if len(signature) == 4:
  61. signature += (sig_dtype.T_EMPTY_DEFAULT_VALUE,)
  62. signatures_new.append(signature)
  63. return tuple(signatures_new)
  64. def _clone(self):
  65. """
  66. Deeply clones the primitive object.
  67. Calls the __init__() method with the same arguments. This method is called in parser if the
  68. flag self.__setattr_flag__ is True.
  69. """
  70. cloned = copy.deepcopy(self)
  71. init_params = inspect.getfullargspec(cloned.__init__.decorated_func).args[1:]
  72. init_args = {}
  73. for name in init_params:
  74. value = self.attrs[name]
  75. init_args[name] = value
  76. # __init__ should be called to construct cpp object.
  77. cloned.__init__(**init_args)
  78. for name in self.attrs:
  79. value = self.attrs[name]
  80. cloned.add_prim_attr(name, value)
  81. if hasattr(self, 'instance_name'):
  82. cloned.set_prim_instance_name(self.instance_name)
  83. return cloned
  84. def add_prim_attr(self, name, value):
  85. """
  86. Adds primitive attribute.
  87. Args:
  88. name (str): Attribute Name.
  89. value (Any): Attribute value.
  90. """
  91. self.__dict__[name] = value
  92. self.attrs[name] = value
  93. self.add_attr(name, value)
  94. return self
  95. def set_strategy(self, strategy):
  96. """
  97. Adds strategy to primitive attribute.
  98. Note:
  99. Valid only in semi auto parallel or auto parallel mode.
  100. In other parallel modes, strategies will be ignored if set.
  101. Args:
  102. strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
  103. """
  104. self.add_prim_attr("strategy", strategy)
  105. return self
  106. def set_prim_instance_name(self, instance_name):
  107. """
  108. Sets instance name to primitive operator.
  109. Note:
  110. Will be called by default when user defines primitive operator.
  111. Args:
  112. instance_name (str): Instance name of primitive operator set by user.
  113. """
  114. self.set_instance_name(instance_name)
  115. self.instance_name = instance_name
  116. return self
  117. def __getattr__(self, item):
  118. if item in super().get_attr_dict():
  119. return super().get_attr_dict()[item]
  120. if item in self.attrs:
  121. return self.attrs[item]
  122. raise AttributeError(item)
  123. def __call__(self, *args):
  124. output = _run_op(self, self.name, args)
  125. return output
  126. def __getstate__(self):
  127. return self.__dict__
  128. def __setstate__(self, d):
  129. self.__dict__.update(d)
  130. def init_prim_io_names(self, inputs, outputs):
  131. """
  132. Initializes inputs and outpus name of Tensor or attributes.
  133. Args:
  134. inputs (list[str]): list of inputs names.
  135. outputs (list[str]): list of outputs names.
  136. """
  137. # for checking para names with kernel implementation
  138. self.add_prim_attr("input_names", inputs)
  139. # for checking output number with kernel implementation
  140. self.add_prim_attr("output_names", outputs)
  141. class PrimitiveWithInfer(Primitive):
  142. """
  143. PrimitiveWithInfer is base class for primitives in python and defines functions for infer of tracks in python.
  144. There are four method can be overide to define the infer logic of the primitive: __infer__(), infer_shape(),
  145. infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority
  146. to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describle shape
  147. and type infer logic. The infer_value() is used for constant propogation.
  148. Args:
  149. name (str): Name for current Primitive.
  150. Examples:
  151. >>> # init a Primitive class with infer
  152. >>> class Add(PrimitiveWithInfer):
  153. >>> @prim_attr_register
  154. >>> def __init__(self):
  155. >>> pass
  156. >>>
  157. >>> def infer_shape(self, x, y):
  158. >>> return x # output shape same as first input 'x'
  159. >>>
  160. >>> def infer_dtype(self, x, y):
  161. >>> return x # output type same as first input 'x'
  162. >>>
  163. >>> # init a Primitive obj
  164. >>> add = Add()
  165. """
  166. def __init__(self, name):
  167. Primitive.__init__(self, name)
  168. self.set_prim_type(prim_type.py_infer_shape)
  169. def _clone(self):
  170. """
  171. Deeply clones the primitive object.
  172. Calls the __init__() method with the same arguments. This method is called in parser if the
  173. flag self.__setattr_flag__ is True.
  174. """
  175. cloned_prim = Primitive._clone(self)
  176. return cloned_prim
  177. def infer_shape(self, *args):
  178. """
  179. Infer output shape based on input shape.
  180. Args:
  181. inputs (tuple(int)): dimensions of input tensors.
  182. outputs (tuple(int)): dimensions of output tensors.
  183. Note:
  184. The shape of scalar is an empty tuple.
  185. """
  186. return None
  187. def infer_dtype(self, *args):
  188. """
  189. Infer output dtype based on input dtype.
  190. Args:
  191. inputs (mstype): data type of inputs.
  192. outputs (mstype): data type of outputs.
  193. """
  194. return None
  195. def infer_value(self, *args):
  196. """
  197. Infer output value based on input value at compile time.
  198. Args:
  199. inputs (any): value of inputs.
  200. outputs (any): value of outputs.
  201. """
  202. return None
  203. def __infer__(self, *args):
  204. """Infer shape, type, and value at the same time by using dictionary as arguments."""
  205. tracks = ['dtype', 'shape', 'value']
  206. out = {}
  207. for track in tracks:
  208. fn = getattr(self, 'infer_' + track)
  209. # fn may return None
  210. out[track] = fn(*(x[track] for x in args))
  211. return out
  212. def prim_attr_register(fn):
  213. """
  214. Primitive attributes register.
  215. Registering the decorator of the built-in operator primitive __init__
  216. function will add all the parameters of __init__ as operator attributes.
  217. Args:
  218. fn (function): __init__ function of primitive.
  219. Returns:
  220. function, original function.
  221. """
  222. def deco(self, *args, **kwargs):
  223. if isinstance(self, PrimitiveWithInfer):
  224. PrimitiveWithInfer.__init__(self, self.__class__.__name__)
  225. else:
  226. Primitive.__init__(self, self.__class__.__name__)
  227. bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
  228. bound_args.apply_defaults()
  229. arguments = bound_args.arguments
  230. del arguments['self']
  231. for name in arguments:
  232. value = arguments[name]
  233. self.add_prim_attr(name, value)
  234. self.init_attrs[name] = value
  235. fn(self, *args, **kwargs)
  236. deco.decorated_func = fn
  237. return deco
  238. def constexpr(fn=None, get_instance=True, name=None):
  239. """
  240. Makes a PrimitiveWithInfer operator, which infer the value while compiling. We can define a function
  241. to compute between constant variable and used in constructß.
  242. Args:
  243. fn (function): A `fn` use as the infer_value of the output operator.
  244. get_instance (bool): If true, returns the instance of operator, else returns the operator class.
  245. name (str): Defines the operator name. If `name` is None, use the function name as op name.
  246. Examples:
  247. >>> a = (1, 2)
  248. >>> # make a operator to calculate tuple len
  249. >>> @constexpr
  250. >>> def tuple_len(x):
  251. >>> return len(x)
  252. >>> assert tuple_len(a) == 2
  253. >>>
  254. >>> # make a operator class to calculate tuple len
  255. >>> @constexpr(get_instance=False, name="TupleLen")
  256. >>> def tuple_len_class(x):
  257. >>> return len(x)
  258. >>> assert tuple_len_class()(a) == 2
  259. """
  260. def deco(fn):
  261. class CompileOp(PrimitiveWithInfer):
  262. def __init__(self):
  263. op_name = name if name else fn.__name__
  264. PrimitiveWithInfer.__init__(self, op_name)
  265. def infer_value(self, *args):
  266. return fn(*args)
  267. if get_instance:
  268. return CompileOp()
  269. return CompileOp
  270. if fn is not None:
  271. return deco(fn)
  272. return deco
  273. @_wrap_func
  274. def _run_op(obj, op_name, args):
  275. """Single op execution function supported by ge in PyNative mode."""
  276. op_mask = [0] * len(args)
  277. op_inputs = []
  278. for i, arg in enumerate(args):
  279. if hasattr(arg, '__parameter__'):
  280. op_inputs.append(arg.default_input)
  281. op_mask[i] = 1
  282. elif isinstance(arg, tuple):
  283. convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
  284. args_ = tuple(convert(x) for x in arg)
  285. op_inputs.append(args_)
  286. else:
  287. op_inputs.append(arg)
  288. output = real_run_op(obj, op_name, tuple(op_inputs), tuple(op_mask))
  289. if not output:
  290. raise RuntimeError("Pynative run op %s failed!" % op_name)
  291. if len(output) == 1:
  292. output = output[0]
  293. return output