You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

utils_const.py 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """internal graph-compatible utility functions"""
  16. import math
  17. from itertools import zip_longest, accumulate
  18. from collections import deque
  19. import operator
  20. import mindspore.context as context
  21. from ..ops import functional as F
  22. from ..ops.primitive import constexpr
  23. from ..common import dtype as mstype
  24. from ..common import Tensor
  25. from .._c_expression import Tensor as Tensor_
  26. from .._c_expression import typing
  27. from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map, rule_for_trigonometric
  28. @constexpr
  29. def _check_shape(shape):
  30. """check the shape param to match the numpy style"""
  31. if not isinstance(shape, (int, tuple, list, typing.Tuple, typing.List)):
  32. raise TypeError(f"only int, tuple and list are allowed for shape, but got {type(shape)}")
  33. if isinstance(shape, int):
  34. shape = (shape,)
  35. if isinstance(shape, (list, typing.List)):
  36. shape = tuple(shape)
  37. for s in shape:
  38. if not isinstance(s, int):
  39. raise TypeError("each entry in shape should be int.")
  40. if s < 0:
  41. raise ValueError("each entry in shape should no less than 0.")
  42. return shape
  43. @constexpr
  44. def _check_dtype(dtype):
  45. """check the input dtype and make conversions"""
  46. # convert the string dtype to mstype.dtype
  47. if isinstance(dtype, str):
  48. dtype = dtype.lower()
  49. dtype = dtype_map[dtype]
  50. elif isinstance(dtype, type):
  51. if dtype is int:
  52. dtype = mstype.int32
  53. elif dtype is float:
  54. dtype = mstype.float32
  55. else:
  56. dtype = mstype.pytype_to_dtype(dtype)
  57. if dtype not in dtype_tuple:
  58. raise TypeError(f"only {all_types} are allowed for dtype, but got {type(dtype)}")
  59. return dtype
  60. @constexpr
  61. def _is_shape_empty(shp):
  62. """Check whether shape contains zero"""
  63. if isinstance(shp, int):
  64. return shp == 0
  65. return F.shape_mul(shp) == 0
  66. @constexpr
  67. def _check_start_normalize(start, ndim):
  68. """check and normalize start argument for rollaxis."""
  69. if start < -ndim or start > ndim:
  70. raise ValueError(f"For rollaxis, start {start} is out of bounds. Ranging from {-ndim} to {ndim} is allowed.")
  71. if start < 0:
  72. start = start + ndim
  73. return start
  74. @constexpr
  75. def _check_axes_range(axes, ndim):
  76. """
  77. Check axes type and normalize the negative axes.
  78. Args:
  79. axes: Axes of the tensor.
  80. ndim (int): The number of dimensions of the tensor.
  81. Return:
  82. Axes (Union[int, tuple(int)]). If input is integer, return integer, else tuple.
  83. Raises:
  84. TypeError: If the axes are not integer, tuple(int) or list(int).
  85. ValueError: If duplicate axes exists or some axis is out of bounds.
  86. """
  87. _check_axis_type(axes, True, True, True)
  88. if isinstance(axes, (list, tuple)):
  89. _check_element_int(axes)
  90. axes = _canonicalize_axis(axes, ndim)
  91. return axes
@constexpr
def _get_device():
    """Get the current device (`GPU`, `CPU`, `Ascend`)"""
    # Reads the MindSpore global context; @constexpr makes it usable in graph mode.
    return context.get_context('device_target')
  96. @constexpr
  97. def _infer_out_shape(*shapes):
  98. """
  99. Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
  100. """
  101. shape_out = deque()
  102. reversed_shapes = map(reversed, shapes)
  103. for items in zip_longest(*reversed_shapes, fillvalue=1):
  104. max_size = 0 if 0 in items else max(items)
  105. if any(item not in (1, max_size) for item in items):
  106. raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
  107. shape_out.appendleft(max_size)
  108. return tuple(shape_out)
  109. @constexpr
  110. def _can_broadcast(*shapes):
  111. """
  112. Returns Ture if shapes can broadcast, False if they cannot.
  113. """
  114. try:
  115. _infer_out_shape(*shapes)
  116. except ValueError:
  117. return False
  118. finally:
  119. pass
  120. return True
  121. @constexpr
  122. def _check_axis_in_range(axis, ndim):
  123. """Checks axes are with the bounds of ndim"""
  124. if not isinstance(axis, int):
  125. raise TypeError(f'axes should be integers, not {type(axis)}')
  126. if not -ndim <= axis < ndim:
  127. raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}')
  128. return axis % ndim
  129. @constexpr
  130. def _check_axis_valid(axes, ndim):
  131. """
  132. Checks axes are valid given ndim, and returns axes that can be passed
  133. to the built-in operator (non-negative, int or tuple)
  134. """
  135. if axes is None:
  136. axes = F.make_range(ndim)
  137. return axes
  138. if isinstance(axes, (tuple, list)):
  139. axes = tuple(map(lambda x: _check_axis_in_range(x, ndim), axes))
  140. if any(axes.count(el) > 1 for el in axes):
  141. raise ValueError('duplicate value in "axis"')
  142. return axes
  143. return (_check_axis_in_range(axes, ndim),)
@constexpr
def _check_shape_aligned(shape1, shape2):
    """Checks shape1 and shape2 are valid shapes to perform inner product"""
    # NOTE(review): the comparison is on the LAST dimension, but the message
    # labels both sides "(dim 0)" — confirm the intended wording.
    if shape1[-1] != shape2[-1]:
        raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)')
  149. @constexpr
  150. def _tile_size(shape, out_shape, ndim):
  151. """Returns tile_size such that shape*tile_size = out_shape"""
  152. size = [1]*ndim
  153. for idx, (i, j) in enumerate(zip(shape, out_shape)):
  154. if i != j:
  155. size[idx] = j
  156. return tuple(size)
  157. @constexpr
  158. def _raise_type_error(info, param=None):
  159. """
  160. Raise TypeError in both graph/pynative mode
  161. Args:
  162. info(str): info string to display
  163. param(python obj): any object that can be recognized by graph mode. If is
  164. not None, then param's type information will be extracted and displayed.
  165. Default is None.
  166. """
  167. if param is None:
  168. raise TypeError(info)
  169. raise TypeError(info + f"{type(param)}")
@constexpr
def _raise_value_error(info, param=None):
    """
    Raise ValueError in both graph/pynative mode

    Args:
        info(str): info string to display
        param(python obj): any object that can be recognized by graph mode. If is
            not None, then param's value information will be extracted and displayed.
            Default is None.
    """
    if param is None:
        raise ValueError(info)
    raise ValueError(info + f"{param}")
  183. @constexpr
  184. def _raise_runtime_error(info, param=None):
  185. """
  186. Raise RuntimeError in both graph/pynative mode
  187. Args:
  188. info(str): info string to display
  189. param(python obj): any object that can be recognized by graph mode. If is
  190. not None, then param's value information will be extracted and displayed.
  191. Default is None.
  192. """
  193. if param is None:
  194. raise RuntimeError(info)
  195. raise RuntimeError(info + f"{param}")
  196. @constexpr
  197. def _raise_unimplemented_error(info, param=None):
  198. """
  199. Raise NotImplementedError in both graph/pynative mode
  200. Args:
  201. info(str): info string to display
  202. param(python obj): any object that can be recognized by graph mode. If is
  203. not None, then param's value information will be extracted and displayed.
  204. Default is None.
  205. """
  206. if param is None:
  207. raise NotImplementedError(info)
  208. raise NotImplementedError(info + f"{param}")
@constexpr
def _empty(dtype, shape):
    """Returns an uninitialized array with dtype and shape."""
    # Tensor_ is the C++-backed tensor constructor; (dtype, shape) leaves data uninitialized.
    return Tensor_(dtype, shape)
  213. @constexpr
  214. def _promote(dtype1, dtype2):
  215. if dtype1 == dtype2:
  216. return dtype1
  217. if (dtype1, dtype2) in promotion_rule:
  218. return promotion_rule[dtype1, dtype2]
  219. return promotion_rule[dtype2, dtype1]
@constexpr
def _promote_for_trigonometric(dtype):
    """Return the output dtype for trigonometric ops, per `rule_for_trigonometric`."""
    return rule_for_trigonometric[dtype]
@constexpr
def _max(*args):
    """Returns the maximum value."""
    # Unpacking lets callers pass either several scalars or a single iterable.
    return max(*args)
@constexpr
def _min(*args):
    """Returns the minimum value."""
    # Unpacking lets callers pass either several scalars or a single iterable.
    return min(*args)
@constexpr
def _abs(arg):
    """Returns the absolute value."""
    # Graph-compatible wrapper around the builtin abs().
    return abs(arg)
@constexpr
def _check_same_type(dtype1, dtype2):
    """Return True when the two dtypes compare equal."""
    return dtype1 == dtype2
@constexpr
def _check_is_float(dtype):
    """Returns whether dtype is float16 or float32."""
    # NOTE(review): float64 is deliberately excluded here — confirm that is intended.
    return dtype in (mstype.float16, mstype.float32)
@constexpr
def _check_is_int(dtype):
    """Return whether dtype is a MindSpore integer type (instance of typing.Int)."""
    return isinstance(dtype, typing.Int)
  245. @constexpr
  246. def _check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
  247. """Check axis argument type."""
  248. if type_int and isinstance(axis, int):
  249. return True
  250. if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)):
  251. for ax in axis:
  252. if not isinstance(ax, int):
  253. raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axis}.")
  254. return True
  255. type_str = ""
  256. if type_int:
  257. type_str += "int, "
  258. if type_tuple:
  259. type_str += "tuple, "
  260. if type_list:
  261. type_str += "list, "
  262. raise TypeError(f"Axis should be {type_str}but got {type(axis)}.")
  263. @constexpr
  264. def _canonicalize_axis(axis, ndim):
  265. """
  266. Check axes are within the number of dimensions of tensor x and normalize the negative axes.
  267. Args:
  268. axis (Union[int, tuple(int), list(int)]): Axes of the tensor.
  269. ndim (int): The number of dimensions of the tensor.
  270. Return:
  271. Axis (Union[int, tuple(int)]). If input is integer, return integer, else tuple.
  272. """
  273. if isinstance(axis, int):
  274. axis = [axis]
  275. for ax in axis:
  276. _check_axis_in_range(ax, ndim)
  277. def canonicalizer(ax):
  278. return ax + ndim if ax < 0 else ax
  279. axis = tuple([canonicalizer(axis) for axis in axis])
  280. if all(axis.count(el) <= 1 for el in axis):
  281. return tuple(sorted(axis)) if len(axis) > 1 else axis[0]
  282. raise ValueError(f"duplicate axes in {axis}.")
  283. @constexpr
  284. def _broadcast_tuples(tup1, tup2):
  285. """
  286. Broadcast two 1D tuples to the same length, if inputs are ints, convert to
  287. tuples first.
  288. """
  289. tup1 = (tup1,) if isinstance(tup1, int) else tup1
  290. tup2 = (tup2,) if isinstance(tup2, int) else tup2
  291. if not isinstance(tup1, (tuple, list)) or not isinstance(tup2, (tuple, list)):
  292. raise TypeError("input shift and axis must be tuple or list or int.")
  293. if len(tup1) == len(tup2):
  294. return tup1, tup2
  295. if len(tup1) == 1:
  296. tup1 *= len(tup2)
  297. elif len(tup2) == 1:
  298. tup2 *= len(tup1)
  299. else:
  300. raise ValueError("shape mismatch: objects cannot be broadcast to a single shape")
  301. return tup1, tup2
  302. @constexpr
  303. def _expanded_shape(ndim, axis_size, axis):
  304. """
  305. Returns a shape with size = 1 for all dimensions
  306. except at axis.
  307. """
  308. return tuple([axis_size if i == axis else 1 for i in range(ndim)])
  309. @constexpr
  310. def _add_unit_axes(shape, ndim, append=False):
  311. """
  312. Prepends shape with 1s so that it has the number of dimensions ndim.
  313. If append is set to True, returns shape appended with 1s instead.
  314. """
  315. if isinstance(shape, int):
  316. shape = (shape,)
  317. ndim_diff = ndim - len(shape)
  318. if ndim_diff > 0:
  319. if append:
  320. shape = [i for i in shape] + [1]*ndim_diff
  321. else:
  322. shape = [1]*ndim_diff + [i for i in shape]
  323. return tuple(shape)
  324. @constexpr
  325. def _check_element_int(lst):
  326. """
  327. Check whether each element in `lst` is an integer.
  328. """
  329. for item in lst:
  330. if not isinstance(item, int):
  331. raise TypeError(f"Each element in {lst} should be integer, but got {type(item)}.")
  332. return True
@constexpr
def _type_convert(force, obj):
    """
    Convert type of `obj` to `force`.

    Args:
        force: the target type / constructor to apply (e.g. int, tuple).
        obj: the value to convert.

    Returns:
        The result of calling ``force(obj)``.
    """
    return force(obj)
  339. @constexpr
  340. def _list_comprehensions(obj, item=None, return_tuple=False, make_none=False):
  341. """
  342. Generates a new list/tuple by list comprehension.
  343. Args:
  344. obj (Union[int, list, tuple]):
  345. If integer, it will be the length of the returned tuple/list.
  346. item: The value to be filled. Default: None.
  347. If None, the values in the new list/tuple are the same as obj
  348. or range(obj) when obj is integer.
  349. return_tuple(bool): If true, returns tuple, else returns list.
  350. Returns:
  351. List or tuple.
  352. """
  353. res = []
  354. lst = obj
  355. if isinstance(obj, int):
  356. lst = range(obj)
  357. if make_none:
  358. res = [None for _ in lst]
  359. elif item is None:
  360. res = [i for i in lst]
  361. else:
  362. res = [item for i in lst]
  363. if return_tuple:
  364. return tuple(res)
  365. return res
  366. @constexpr
  367. def _tuple_setitem(tup, idx, value):
  368. """
  369. Returns a tuple with specified `idx` set to `value`.
  370. """
  371. tup = list(tup)
  372. tup[idx] = value
  373. return tuple(tup)
  374. @constexpr
  375. def _iota(dtype, num, increasing=True):
  376. """Creates a 1-D tensor with value: [0,1,...num-1] and dtype."""
  377. # Change to P.Linspace when the kernel is implemented on CPU.
  378. if num <= 0:
  379. raise ValueError("zero shape Tensor is not currently supported.")
  380. if increasing:
  381. return Tensor(list(range(int(num))), dtype)
  382. return Tensor(list(range(int(num)-1, -1, -1)), dtype)
@constexpr
def _ceil(number):
    """Ceils the number in graph mode."""
    # Wrapped as @constexpr so math.ceil can be evaluated at compile time in graphs.
    return math.ceil(number)
  387. @constexpr
  388. def _seq_prod(seq1, seq2):
  389. """Returns the element-wise product of seq1 and seq2."""
  390. return tuple(map(lambda x, y: x*y, seq1, seq2))
@constexpr
def _make_tensor(val, dtype):
    """Returns the tensor with value `val` and dtype `dtype`."""
    return Tensor(val, dtype)
@constexpr
def _tuple_slice(tup, start, end):
    """get sliced tuple from start and end."""
    # Standard half-open slice semantics: element at `end` is excluded.
    return tup[start:end]
  399. @constexpr
  400. def _isscalar(x):
  401. """Returns True if x is a scalar type"""
  402. return isinstance(x, (typing.Number, typing.Int, typing.UInt, typing.Float,
  403. typing.Bool, typing.String))
@constexpr
def _cumprod(x):
    """Return the running (cumulative) products of `x` as a tuple."""
    return tuple(accumulate(x, operator.mul))
@constexpr
def _in(x, y):
    """Return True when `x` is contained in `y` (graph-compatible `in` operator)."""
    return x in y
@constexpr
def _callable_const(x):
    """Returns true if x is a function in graph mode."""
    # typing.Function is the graph-mode representation of callables.
    return isinstance(x, typing.Function)
  414. @constexpr
  415. def _check_is_inf(x, negative=False):
  416. if not negative:
  417. return x == float('inf')
  418. return x == float('-inf')