You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

primitive.py 19 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """primitive"""
  16. import inspect
  17. import copy
  18. from mindspore.common.api import _wrap_func
  19. from mindspore import context
  20. from .._c_expression import Primitive_, real_run_op, prim_type
  21. from .._checkparam import Validator
  22. from . import signature as sig
  23. class Primitive(Primitive_):
  24. """
  25. Primitive is the base class of primitives in python.
  26. Args:
  27. name (str): Name for the current Primitive.
  28. Examples:
  29. >>> add = Primitive('add')
  30. >>>
  31. >>> # or work with prim_attr_register:
  32. >>> # init a Primitive class with attr1 and attr2
  33. >>> class Add(Primitive):
  34. ... @prim_attr_register
  35. ... def __init__(self, attr1, attr2):
  36. ... # check attr1 and attr2 or do some initializations
  37. >>> # init a Primitive obj with attr1=1 and attr2=2
  38. >>> add = Add(attr1=1, attr2=2)
  39. """
  40. _repr_ignore_list = ['input_names', 'output_names']
  41. def __init__(self, name):
  42. self.name = name
  43. self.attrs = {}
  44. self.init_attrs = {"name": name}
  45. self._update_parameter = False
  46. Primitive_.__init__(self, name, self)
  47. if hasattr(self.__class__, '__mindspore_signature__'):
  48. out = self._fill_signature(self.__class__.__mindspore_signature__)
  49. self.set_signatures(out)
  50. def _fill_signature(self, signatures):
  51. """fills signature."""
  52. signatures_new = []
  53. for signature in signatures:
  54. if isinstance(signature, sig.Signature):
  55. signatures_new.append(signature)
  56. elif isinstance(signature, sig.sig_dtype):
  57. signatures_new.append(sig.make_sig(dtype=signature))
  58. else:
  59. if len(signature) < 3:
  60. raise ValueError(f"[Internal Error]Signature for one parameter len must > 3, but {signature}")
  61. signatures_new.append(sig.make_sig(*signature))
  62. return tuple(signatures_new)
  63. def _clone(self):
  64. """
  65. Deeply clones the primitive object.
  66. Calls the __init__() method with the same arguments. This method is called in parser if the
  67. flag self.__setattr_flag__ is True.
  68. """
  69. cloned = copy.deepcopy(self)
  70. init_params = inspect.getfullargspec(cloned.__init__.decorated_func).args[1:]
  71. init_args = {}
  72. for name in init_params:
  73. value = self.attrs[name]
  74. init_args[name] = value
  75. # __init__ should be called to construct cpp object.
  76. cloned.__init__(**init_args)
  77. for name in self.attrs:
  78. value = self.attrs[name]
  79. cloned.add_prim_attr(name, value)
  80. if hasattr(self, 'instance_name'):
  81. cloned.set_prim_instance_name(self.instance_name)
  82. return cloned
  83. def add_prim_attr(self, name, value):
  84. """
  85. Adds primitive attribute.
  86. Args:
  87. name (str): Attribute Name.
  88. value (Any): Attribute value.
  89. """
  90. self.__dict__[name] = value
  91. self.attrs[name] = value
  92. self.add_attr(name, value)
  93. return self
  94. def del_prim_attr(self, name):
  95. """
  96. Del primitive attribute.
  97. Args:
  98. name (str): Attribute Name.
  99. """
  100. if name in self.__dict__ and name in self.attrs:
  101. del self.__dict__[name]
  102. del self.attrs[name]
  103. self.del_attr(name)
  104. return self
  105. def set_stage(self, stage):
  106. """
  107. Add stage id to primitive attribute.
  108. Note:
  109. It is valid only in semi auto parallel.
  110. In other parallel modes, please set it to be 0.
  111. Args:
  112. stage (int): The stage id for the current operation
  113. """
  114. self.add_prim_attr("stage", stage)
  115. return self
  116. def shard(self, strategy):
  117. """
  118. Add strategies to primitive attribute.
  119. Note:
  120. It is valid only in semi auto parallel or auto parallel mode.
  121. In other parallel modes, strategies set here will be ignored.
  122. Args:
  123. strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
  124. """
  125. self.add_prim_attr("strategy", strategy)
  126. return self
  127. def set_prim_instance_name(self, instance_name):
  128. """
  129. Set instance name to primitive operator.
  130. Note:
  131. It will be called by default when user defines primitive operator.
  132. Args:
  133. instance_name (str): Instance name of primitive operator set by user.
  134. """
  135. self.set_instance_name(instance_name)
  136. self.instance_name = instance_name
  137. return self
  138. def __getattr__(self, item):
  139. if item == 'infer_dynamic_shape':
  140. return None
  141. if item in super().get_attr_dict():
  142. return super().get_attr_dict()[item]
  143. if item in self.attrs:
  144. return self.attrs[item]
  145. raise AttributeError(item)
  146. def check_elim(self, *args):
  147. """
  148. Check if certain inputs should go to the backend. Subclass in need should override this method.
  149. Args:
  150. args(Primitive args): Same as arguments of current Primitive.
  151. Returns:
  152. A tuple consisting of two elements. The first element indicates whether we should filter out current
  153. arguments; the seconde element is the output if we need to filter out the arguments.
  154. """
  155. return (False, None)
  156. def __call__(self, *args):
  157. should_elim, output = self.check_elim(*args)
  158. if should_elim:
  159. return output
  160. return _run_op(self, self.name, args)
  161. def __getstate__(self):
  162. return self.__dict__
  163. def __setstate__(self, d):
  164. self.__dict__.update(d)
  165. def __deepcopy__(self, memo):
  166. return type(self)(**self.init_attrs)
  167. def __repr__(self):
  168. attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if not k in Primitive._repr_ignore_list])
  169. info_str = f'Prim[{self.name}]'
  170. if attr:
  171. info_str += f'<{attr}>'
  172. return info_str
  173. def init_prim_io_names(self, inputs, outputs):
  174. """
  175. Initializes the name of inputs and outputs of Tensor or attributes.
  176. Args:
  177. inputs (list[str]): list of inputs names.
  178. outputs (list[str]): list of outputs names.
  179. """
  180. # for checking para names with kernel implementation
  181. self.add_prim_attr("input_names", inputs)
  182. # for checking output number with kernel implementation
  183. self.add_prim_attr("output_names", outputs)
  184. @property
  185. def update_parameter(self):
  186. """ Whether the primitive will update the value of parameter."""
  187. return self._update_parameter
  188. def recompute(self, mode=True):
  189. """
  190. Set the primitive recomputed. If a primitive set recomputed feeds into some backward nodes
  191. for computing gradient, rather than storing the intermediate activation computed in forward
  192. pass, we will recompute it in backward pass.
  193. Note:
  194. - If the computation involves something like randomization or global variable, the equivalence
  195. is not guaranteed currently.
  196. Args:
  197. mode (bool): Specifies whether the primitive is recomputed. Default: True.
  198. """
  199. if context.get_context("mode") == context.PYNATIVE_MODE:
  200. raise TypeError("Recompute is not supported in pynative mode currently.")
  201. Validator.check_bool(mode)
  202. self.add_prim_attr("recompute", mode)
  203. return self
  204. class PrimitiveWithCheck(Primitive):
  205. """
  206. PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator input arguments
  207. but used the infer method registered in c++ source codes.
  208. There are three methods can be override to define the check logic of the primitive: __check__(), check_shape(),
  209. check_dtype(). If __check__() is defined in primitive, the __check__() has highest priority to be called.
  210. If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of
  211. the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation.
  212. Args:
  213. name (str): Name of the current Primitive.
  214. Examples:
  215. >>> # init a Primitive class with check
  216. >>> class Flatten(PrimitiveWithCheck):
  217. >>> @prim_attr_register
  218. >>> def __init__(self):
  219. >>> pass
  220. >>> def check_shape(self, input_x):
  221. >>> validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name)
  222. >>>
  223. >>> def check_dtype(self, input_x):
  224. >>> validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
  225. >>>
  226. >>> # init a Primitive obj
  227. >>> add = Flatten()
  228. """
  229. def __init__(self, name):
  230. Primitive.__init__(self, name)
  231. self.set_prim_type(prim_type.py_infer_check)
  232. def _clone(self):
  233. """
  234. Deeply clones the primitive object.
  235. Calls the __init__() method with the same arguments. This method is called in parser if the
  236. flag self.__setattr_flag__ is True.
  237. """
  238. cloned_prim = Primitive._clone(self)
  239. return cloned_prim
  240. def check_shape(self, *args):
  241. """
  242. Check shapes of input args.
  243. Note:
  244. The shape of scalar is an empty tuple.
  245. Args:
  246. args (tuple(int)): shapes of input tensors.
  247. Return:
  248. None.
  249. """
  250. return None
  251. def check_dtype(self, *args):
  252. """
  253. Check data types of input args.
  254. Args:
  255. args (:class:`mindspore.dtype`): data type of inputs.
  256. Return:
  257. None.
  258. """
  259. return None
  260. def __check__(self, *args):
  261. """Check shape, type, and value at the same time by using dictionary as arguments."""
  262. tracks = ['dtype', 'shape']
  263. for track in tracks:
  264. fn = getattr(self, 'check_' + track)
  265. fn(*(x[track] for x in args))
  266. class PrimitiveWithInfer(Primitive):
  267. """
  268. PrimitiveWithInfer is the base class of primitives in python and defines functions for tracking inference in python.
  269. There are four method can be override to define the infer logic of the primitive: __infer__(), infer_shape(),
  270. infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority
  271. to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer
  272. logic of the shape and type. The infer_value() is used for constant propagation.
  273. Args:
  274. name (str): Name of the current Primitive.
  275. Examples:
  276. >>> # init a Primitive class with infer
  277. >>> class Add(PrimitiveWithInfer):
  278. >>> @prim_attr_register
  279. >>> def __init__(self):
  280. >>> pass
  281. >>>
  282. >>> def infer_shape(self, x, y):
  283. >>> return x # output shape same as first input 'x'
  284. >>>
  285. >>> def infer_dtype(self, x, y):
  286. >>> return x # output type same as first input 'x'
  287. >>>
  288. >>> # init a Primitive obj
  289. >>> add = Add()
  290. """
  291. def __init__(self, name):
  292. Primitive.__init__(self, name)
  293. self.set_prim_type(prim_type.py_infer_shape)
  294. def _clone(self):
  295. """
  296. Deeply clones the primitive object.
  297. Calls the __init__() method with the same arguments. This method is called in parser if the
  298. flag self.__setattr_flag__ is True.
  299. """
  300. cloned_prim = Primitive._clone(self)
  301. return cloned_prim
  302. def infer_shape(self, *args):
  303. """
  304. Infer output shape based on input shape.
  305. Note:
  306. The shape of scalar is an empty tuple.
  307. Args:
  308. args (tuple(int)): shapes of input tensors.
  309. Return:
  310. `tuple(int)`, shapes of output tensors.
  311. """
  312. return None
  313. def infer_dtype(self, *args):
  314. """
  315. Infer output dtype based on input dtype.
  316. Args:
  317. args (:class:`mindspore.dtype`): data type of inputs.
  318. Return:
  319. :class:`mindspore.dtype`, data type of outputs.
  320. """
  321. return None
  322. def infer_value(self, *args):
  323. """
  324. Infer output value based on input value at compile time.
  325. Args:
  326. args (Any): value of inputs.
  327. Return:
  328. Value of outputs. Return `None`, the value can not be inferred at compile time in this case.
  329. """
  330. return None
  331. def __infer__(self, *args):
  332. """Infer shape, type, and value at the same time by using dictionary as arguments."""
  333. is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
  334. fn_infer_dynamic_shape = getattr(self, 'infer_dynamic_shape', None)
  335. if is_graph_mode and fn_infer_dynamic_shape is not None:
  336. out = fn_infer_dynamic_shape(*args)
  337. tracks = ['dtype', 'value']
  338. for track in tracks:
  339. fn = getattr(self, 'infer_' + track)
  340. # fn may return None
  341. out[track] = fn(*(x[track] for x in args))
  342. return out
  343. tracks = ['dtype', 'shape', 'value']
  344. out = {}
  345. for track in tracks:
  346. fn = getattr(self, 'infer_' + track)
  347. # fn may return None
  348. out[track] = fn(*(x[track] for x in args))
  349. # in non-graph_mode, it is not necessary to infer min/max shape
  350. if not is_graph_mode:
  351. return out
  352. # output does not contain dynamic shape, no need to calculate min/max shape
  353. def has_dynamic_shape(shp):
  354. if isinstance(shp, int):
  355. return shp < 0
  356. if isinstance(shp, (list, tuple)):
  357. return any(has_dynamic_shape(e) for e in shp)
  358. return False
  359. if not has_dynamic_shape(out['shape']):
  360. return out
  361. # calculate min/max shape for output
  362. def get_specified_shape(elems, attr):
  363. has_specified_shape = False
  364. ret_vals = []
  365. for elem in elems:
  366. if attr in elem:
  367. has_specified_shape = True
  368. ret_vals.append(elem[attr])
  369. else:
  370. ret_vals.append(elem['shape'])
  371. return has_specified_shape, tuple(ret_vals)
  372. has_min_shape, min_shapes = get_specified_shape(args, 'min_shape')
  373. has_max_shape, max_shapes = get_specified_shape(args, 'max_shape')
  374. if not (has_min_shape or has_max_shape):
  375. return out
  376. if has_min_shape and has_max_shape:
  377. fn_infer_shape = getattr(self, 'infer_shape')
  378. out['min_shape'] = fn_infer_shape(*min_shapes)
  379. out['max_shape'] = fn_infer_shape(*max_shapes)
  380. return out
  381. raise ValueError('Input args has invalid dynamic shape, args info: {args}')
  382. def prim_attr_register(fn):
  383. """
  384. Primitive attributes register.
  385. Register the decorator of the built-in operator primitive '__init__'.
  386. The function will add all the parameters of '__init__' as operator attributes.
  387. Args:
  388. fn (function): __init__ function of primitive.
  389. Returns:
  390. function, original function.
  391. """
  392. def deco(self, *args, **kwargs):
  393. class_name = self.__class__.__name__
  394. if hasattr(self.__class__, "substitute_name"):
  395. class_name = self.__class__.substitute_name
  396. if isinstance(self, PrimitiveWithInfer):
  397. PrimitiveWithInfer.__init__(self, class_name)
  398. elif isinstance(self, PrimitiveWithCheck):
  399. PrimitiveWithCheck.__init__(self, class_name)
  400. else:
  401. Primitive.__init__(self, self.__class__.__name__)
  402. bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
  403. bound_args.apply_defaults()
  404. arguments = bound_args.arguments
  405. del arguments['self']
  406. del self.init_attrs['name']
  407. for name in arguments:
  408. value = arguments[name]
  409. self.add_prim_attr(name, value)
  410. self.init_attrs[name] = value
  411. fn(self, *args, **kwargs)
  412. deco.decorated_func = fn
  413. return deco
  414. def constexpr(fn=None, get_instance=True, name=None):
  415. """
  416. Creates a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function
  417. to compute constant value using the constants in the constructor.
  418. Args:
  419. fn (function): A `fn` use as the infer_value of the output operator.
  420. get_instance (bool): If true, return the instance of operator, otherwise return the operator class.
  421. name (str): Defines the operator name. If `name` is None, use the function name as op name.
  422. Examples:
  423. >>> a = (1, 2)
  424. >>> # make an operator to calculate tuple len
  425. >>> @constexpr
  426. >>> def tuple_len(x):
  427. ... return len(x)
  428. >>> assert tuple_len(a) == 2
  429. ...
  430. >>> # make an operator class to calculate tuple len
  431. >>> @constexpr(get_instance=False, name="TupleLen")
  432. >>> def tuple_len_class(x):
  433. ... return len(x)
  434. >>> assert tuple_len_class()(a) == 2
  435. """
  436. def deco(fn):
  437. class CompileOp(PrimitiveWithInfer):
  438. def __init__(self):
  439. op_name = name if name else fn.__name__
  440. PrimitiveWithInfer.__init__(self, op_name)
  441. self.set_const_prim(True)
  442. def infer_value(self, *args):
  443. return fn(*args)
  444. if get_instance:
  445. return CompileOp()
  446. return CompileOp
  447. if fn is not None:
  448. return deco(fn)
  449. return deco
  450. @_wrap_func
  451. def _run_op(obj, op_name, args):
  452. """Single op execution function supported by ge in PyNative mode."""
  453. output = real_run_op(obj, op_name, args)
  454. return output