You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

standard_method.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """standard_method"""
  18. from dataclasses import dataclass
  19. from mindspore import Tensor, Parameter
  20. from mindspore import dtype as mstype
  21. from ..._checkparam import Validator as validator
  22. from ...ops import functional as F
  23. from ...ops import operations as P
  24. from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
  25. zeros_like, ones_like
  26. from ...ops.composite.base import _append
  27. from ...ops.primitive import constexpr
  28. __all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
  29. shape_ = P.Shape()
  30. dtype_ = P.DType()
  31. abs_ = P.Abs()
  32. ndim_ = P.Rank()
  33. size_ = P.Size()
  34. itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1,
  35. mstype.float16: 2, mstype.int16: 2, mstype.uint16: 2,
  36. mstype.float32: 4, mstype.int32: 4, mstype.uint32: 4,
  37. mstype.float64: 8, mstype.int64: 8, mstype.uint64: 8}
  38. def mean(x, axis=(), keep_dims=False):
  39. """
  40. Reduces a dimension of a tensor by averaging all elements in the dimension.
  41. Args:
  42. axis (Union[None, int, tuple(int)]): Dimensions of reduction,
  43. when axis is None or empty tuple, reduce all dimensions.
  44. Default: (), reduce all dimensions.
  45. keep_dims (bool): Whether to keep the reduced dimensions.
  46. Default : False, don't keep these reduced dimensions.
  47. Returns:
  48. Tensor, has the same data type as x.
  49. """
  50. if axis is None:
  51. axis = ()
  52. reduce_mean = P.ReduceMean(keep_dims)
  53. return reduce_mean(x, axis)
  54. def all_(x, axis=(), keep_dims=False):
  55. """
  56. Check all array elements along a given axis evaluate to True.
  57. Args:
  58. x (Tensor): A Tensor to be reduced.
  59. axis (Union[None, int, tuple(int)): Dimensions of reduction.
  60. keep_dims (bool): Whether to keep the reduced dimensions.
  61. Returns:
  62. Tensor, has the same data type as x.
  63. """
  64. if axis is None:
  65. axis = ()
  66. reduce_all = P.ReduceAll(keep_dims)
  67. return reduce_all(x, axis)
  68. def any_(x, axis=(), keep_dims=False):
  69. """
  70. Check any array element along a given axis evaluate to True.
  71. Args:
  72. x (Tensor): A Tensor to be reduced.
  73. axis (Union[None, int, tuple(int)): Dimensions of reduction.
  74. keep_dims (bool): Whether to keep the reduced dimensions.
  75. Returns:
  76. Tensor, has the same data type as x.
  77. """
  78. if axis is None:
  79. axis = ()
  80. reduce_any = P.ReduceAny(keep_dims)
  81. return reduce_any(x, axis)
  82. def itemsize_(x):
  83. """
  84. Return length of one tensor element in bytes.
  85. Args:
  86. x (Tensor): Input tensor.
  87. Returns:
  88. itemsize(int).
  89. """
  90. return get_itemsize(x.dtype)
  91. def nbytes_(x):
  92. """
  93. Return total number of bytes taken by the tensor.
  94. Args:
  95. x (Tensor): Input tensor.
  96. Returns:
  97. nbytes(int).
  98. """
  99. return itemsize_(x) * F.shape_mul(shape_(x))
  100. def strides_(x):
  101. """
  102. Return the tuple of bytes to step in each dimension when traversing a tensor.
  103. Args:
  104. x (Tensor): Input tensor.
  105. Returns:
  106. strides (tuple[int]).
  107. """
  108. strides = ()
  109. ndim = P.Rank()(x)
  110. tensor_shape = shape_(x)
  111. for i in F.make_range(0, ndim):
  112. stride = itemsize_(x)
  113. for j in F.make_range(i + 1, ndim):
  114. stride *= tensor_shape[j]
  115. strides += (stride,)
  116. return strides
  117. def astype(x, dtype, copy=True):
  118. """Implementation of `astype`."""
  119. dtype = check_astype_dtype_const(dtype)
  120. if not copy and dtype == x.dtype:
  121. return x
  122. return F.cast(x, dtype)
  123. def transpose(x, *axis):
  124. """Implementation of `transpose`."""
  125. ndim = F.rank(x)
  126. perm = check_transpose_axis_const(axis, ndim)
  127. return F.transpose(x, perm)
  128. # `tensor.T` is used as a property in graph mode
  129. T_ = transpose
  130. def reshape(x, *shape):
  131. """Implementation of `reshape`."""
  132. new_shape = check_reshape_shp_const(shape)
  133. return F.reshape(x, new_shape)
  134. def ravel(x):
  135. """Implementation of `ravel`."""
  136. return reshape(x, (-1,))
  137. def flatten(x, order='C'):
  138. """
  139. Returns a copy of the array collapsed into one dimension.
  140. Args:
  141. order (str, optional): Can choose between `C` and `F`. `C` means to
  142. flatten in row-major (C-style) order. ‘F’ means to flatten in column-major
  143. (Fortran- style) order. Only `C` and `F` are supported.
  144. Returns:
  145. Tensor, has the same data type as x.
  146. """
  147. order = check_flatten_order_const(order)
  148. if order == 'C':
  149. return F.reshape(x, (-1,))
  150. perm = F.make_range(0, F.rank(x))
  151. new_order = F.tuple_reversed(perm)
  152. return F.reshape(F.transpose(x, new_order), (-1,))
  153. def swapaxes(x, axis1, axis2):
  154. """
  155. Interchanges two axes of a tensor.
  156. Args:
  157. axis1 (int): First axis.
  158. axis2 (int): Second axis.
  159. Returns:
  160. Transposed tensor, has the same data type as the original tensor x.
  161. """
  162. axis1, axis2 = check_swapaxes_axis_const((axis1, axis2), x.ndim)
  163. if axis1 == axis2:
  164. return x
  165. if axis1 > axis2:
  166. axis1, axis2 = axis2, axis1
  167. perm = F.make_range(0, x.ndim)
  168. new_perm = None
  169. if axis2 + 1 < x.ndim:
  170. new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
  171. perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:]
  172. else:
  173. new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
  174. perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1]
  175. return F.transpose(x, new_perm)
  176. def squeeze(x, axis=None):
  177. """
  178. Removes single-dimensional entries from the shape of an tensor.
  179. Args:
  180. axis: Union[None, int, list(int), tuple(list)]. Default is None.
  181. Returns:
  182. Tensor, with all or a subset of the dimensions of length 1 removed.
  183. """
  184. shape = F.shape(x)
  185. if axis is None:
  186. return F.squeeze(x)
  187. # yield squeezed shape based on the axes
  188. new_shape = prepare_shape_for_squeeze_const(shape, axis)
  189. return F.reshape(x, new_shape)
def getitem(data, item):
    """
    Implementation of `getitem`.

    Delegates to ``data.__getitem__(item)`` and returns its result.
    """
    # NOTE(review): the dunder is called explicitly rather than via `data[item]`;
    # presumably so the subscript is not routed back into this hook by the
    # graph parser — confirm before simplifying.
    return data.__getitem__(item)


def setitem(data, item, value):
    """
    Implementation of `setitem`.

    Delegates to ``data.__setitem__(item, value)`` and returns whatever that
    call returns.
    """
    return data.__setitem__(item, value)


def ms_iter(xs):
    """
    Implementation of `iter`.

    Returns ``xs.__ms_iter__()``, the MindSpore-specific iteration hook.
    """
    return xs.__ms_iter__()


def ms_next(it):
    """
    Implementation of `next`.

    Returns ``it.__ms_next__()``; for `SequenceIterator` in this module that is
    a pair ``(element, advanced_iterator)``.
    """
    return it.__ms_next__()


def hasnext(it):
    """
    Implementation of `hasnext`.

    Returns ``it.__ms_hasnext__()``, i.e. whether the iterator has elements left.
    """
    return it.__ms_hasnext__()
def ms_len(data):
    """Implementation of `len`: returns ``data.__len__()``."""
    # NOTE(review): dunders are called explicitly throughout these hooks,
    # presumably so the graph parser does not dispatch back to them — confirm.
    return data.__len__()


def floor(x):
    """Implementation of `floor`: returns ``x.__floor__()``."""
    return x.__floor__()


def trunc(x):
    """Implementation of `trunc`: returns ``x.__trunc__()``."""
    return x.__trunc__()


def uadd(x):
    """Implementation of `uadd` (unary ``+x``): returns ``x.__pos__()``."""
    return x.__pos__()


def usub(x):
    """Implementation of `usub` (unary ``-x``): returns ``x.__neg__()``."""
    return x.__neg__()


def scalar_truediv(x, y):
    """Implementation of `scalar_truediv` (``x / y``): returns ``x.__truediv__(y)``."""
    return x.__truediv__(y)


def scalar_floordiv(x, y):
    """Implementation of `scalar_floordiv` (``x // y``): returns ``x.__floordiv__(y)``."""
    return x.__floordiv__(y)


def bool_(x):
    """Implementation of `bool`: returns ``x.__bool__()``."""
    return x.__bool__()
  229. def enumerate_(x, start=0):
  230. """Enumerate list or tuple or tensor."""
  231. x_type = F.typeof(x)
  232. ret = ()
  233. op_name = "enumerate"
  234. if check_is_tuple_or_list_or_tensor(x_type, op_name, "first input") and check_is_const_int(start, op_name, "start"):
  235. if check_is_tensor(x_type):
  236. for i in range(x.shape[0]):
  237. ret += ((start + i, x[i]),)
  238. else:
  239. ret = zip(range(start, start + len(x)), x)
  240. return ret
  241. def expand_tensor_as(x, y):
  242. """Expand tensor"""
  243. broadcast_to = P.BroadcastTo(shape_(y))
  244. return broadcast_to(x)
  245. def view(x, *shape):
  246. """Reshape tensor, if shape is -1, reshape tensor into one dimension"""
  247. shape = check_view_shape(shape)
  248. return F.reshape(x, shape)
  249. def isinstance_(x, base_type):
  250. """Determine whether x is an instance of base_type."""
  251. x_type = F.typeof(x)
  252. return check_type_same(x_type, base_type)
  253. def while_cond(x):
  254. """For while condition, if the condition is a tensor, the loop will not be unrolled"""
  255. if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
  256. is_cond = check_is_tensor_bool_cond(F.shape(x))
  257. if is_cond:
  258. return F.cast(x, mstype.bool_)
  259. return x
  260. @constexpr
  261. def check_type_same(x_type, base_type):
  262. """Check x_type is same as base_type."""
  263. pytype_to_mstype = {
  264. bool: mstype.Bool,
  265. int: mstype.Int,
  266. float: mstype.Float,
  267. str: mstype.String,
  268. list: mstype.List,
  269. tuple: mstype.Tuple,
  270. dict: mstype.Dict,
  271. Tensor: mstype.tensor_type,
  272. Parameter: mstype.ref_type
  273. }
  274. has_int = False
  275. has_tensor = False
  276. def to_target_type(origin_type):
  277. try:
  278. if isinstance(origin_type, type):
  279. ret_type = pytype_to_mstype[origin_type]
  280. if ret_type == mstype.Int:
  281. nonlocal has_int
  282. has_int = True
  283. if ret_type == mstype.tensor_type:
  284. nonlocal has_tensor
  285. has_tensor = True
  286. return (ret_type,)
  287. if isinstance(origin_type, tuple):
  288. return tuple(to_target_type(i) for i in origin_type)
  289. raise TypeError(f"The second arg of 'isinstance' must be a type or a tuple of types, "
  290. f"but got a {type(origin_type).__name__}")
  291. except KeyError:
  292. raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, "
  293. f"Tensor, Parameter, or a tuple containing only these types, but got {origin_type}")
  294. target_type = to_target_type(base_type)
  295. if (isinstance(x_type, mstype.Bool) and has_int) or (isinstance(x_type, mstype.ref_type) and has_tensor):
  296. return True
  297. return isinstance(x_type, target_type)
  298. @constexpr
  299. def get_itemsize(x_type):
  300. """get itemsize from tensor's dtype."""
  301. return itemsize_map[x_type]
  302. @constexpr
  303. def check_is_tensor(x):
  304. """check whether x is tensor."""
  305. if isinstance(x, mstype.tensor_type):
  306. return True
  307. return False
  308. @constexpr
  309. def check_is_tuple_or_list_or_tensor(x, op_name, arg_name):
  310. """check whether x is list or tuple or tensor."""
  311. if isinstance(x, (mstype.List, mstype.Tuple, mstype.tensor_type)):
  312. return True
  313. raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.")
  314. @constexpr
  315. def check_is_const_int(x, op_name, arg_name):
  316. """check whether x is const int."""
  317. if x is None:
  318. raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got not const.")
  319. if not isinstance(x, int):
  320. raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got {x}.")
  321. return True
  322. @constexpr
  323. def check_is_tensor_bool_cond(shp):
  324. """check if tensor is a bool condition"""
  325. if shp in ((), (1,)):
  326. return True
  327. raise ValueError("The truth value of an array with several elements is ambiguous.")
  328. @constexpr
  329. def const_tensor_to_bool(x):
  330. """convert bool tensor to bool condition"""
  331. if x is None:
  332. raise ValueError("Only constant tensor bool can be converted to bool")
  333. x = x.asnumpy()
  334. if x.shape == ():
  335. return bool(x)
  336. if x.shape == (1,):
  337. return bool(x[0])
  338. raise ValueError("The truth value of an array with several elements is ambiguous.")
  339. @constexpr
  340. def check_view_shape(x):
  341. """Check view function input shape"""
  342. if not x:
  343. raise ValueError("The shape variable should not be empty")
  344. if isinstance(x[0], tuple):
  345. if len(x) != 1:
  346. raise ValueError(f"Only one tuple is needed, but got {x}")
  347. x = x[0]
  348. return x
  349. # convert normal param_check functions to constexpr functions
  350. check_astype_dtype_const = constexpr(validator.check_astype_dtype)
  351. check_transpose_axis_const = constexpr(validator.check_transpose_axis)
  352. check_reshape_shp_const = constexpr(validator.check_reshape_shp)
  353. check_flatten_order_const = constexpr(validator.check_flatten_order)
  354. check_swapaxes_axis_const = constexpr(validator.check_swapaxes_axis)
  355. prepare_shape_for_squeeze_const = constexpr(validator.prepare_shape_for_squeeze)
  356. def tensor_bool(x):
  357. """tensor as condition, if is constant, return immediate bool value"""
  358. is_cond = check_is_tensor_bool_cond(F.shape(x))
  359. if is_cond and F.isconstant(x):
  360. return const_tensor_to_bool(x)
  361. return F.cast(x, mstype.bool_)
def and_(x, y):
    """Implementation of `and` (`&`): returns ``x.__and__(y)``."""
    # NOTE(review): called via the dunder rather than `x & y`, presumably so the
    # graph parser does not dispatch back to this hook — confirm before changing.
    return x.__and__(y)


def or_(x, y):
    """Implementation of `or` (`|`): returns ``x.__or__(y)``."""
    return x.__or__(y)


def matmul(x, y):
    """Implementation of `matmul` (`@`): returns ``x.__matmul__(y)``."""
    return x.__matmul__(y)
  371. def float_bool(x):
  372. """Implementation of `float_bool`."""
  373. return x != 0.0
  374. def int_bool(x):
  375. """Implementation of `int_bool`."""
  376. return x != 0
  377. def str_bool(x):
  378. """Implementation of `str_bool`."""
  379. if x == "":
  380. return False
  381. return True
  382. def list_bool(x):
  383. """Implementation of `tuple_bool`."""
  384. return len(x) != 0
  385. def tuple_bool(x):
  386. """Implementation of `tuple_bool`."""
  387. return len(x) != 0
  388. def dict_bool(x):
  389. """Implementation of `dict_bool`."""
  390. return len(x) != 0
  391. def none_bool(x):
  392. """Implementation of `none_bool`."""
  393. return False
  394. def func_bool(x):
  395. """Implementation of `func_bool`."""
  396. return True
  397. def float_floordiv(x, y):
  398. """Implementation of `float_floordiv`."""
  399. return floor(x / y)
  400. #############
  401. # Iteration #
  402. #############
  403. @dataclass(frozen=True)
  404. class SequenceIterator:
  405. """
  406. SequenceIterator is a util dataclass for iterating sequence object.
  407. Iterator to use for sequences like List, Array.
  408. """
  409. idx: int
  410. seq: list
  411. @core(ignore_values=True)
  412. def __ms_hasnext__(self):
  413. """Whether the index is past the length of the sequence."""
  414. return self.idx < ms_len(self.seq)
  415. @core(ignore_values=True)
  416. def __ms_next__(self):
  417. """Return the next element and a new iterator."""
  418. return self.seq[self.idx], SequenceIterator(self.idx + 1, self.seq)
  419. def list_iter(xs):
  420. """Iterator for List."""
  421. return SequenceIterator(0, xs)
  422. def array_iter(xs):
  423. """Iterator for Array."""
  424. return SequenceIterator(0, xs)
  425. def tuple_next(xs):
  426. """Next tuple."""
  427. return xs[0], tail(xs)
  428. def tuple_hasnext(xs):
  429. """Whether the tuple is empty or not."""
  430. return len(xs) > 0
  431. def list_next(xs):
  432. """Next list."""
  433. return xs[0], tail(xs)
  434. def list_hasnext(xs):
  435. """Whether the list is empty or not."""
  436. return len(xs) > 0
  437. def list_append(self_, item):
  438. return _append(self_, item)
  439. #################
  440. # Array methods #
  441. #################
def to_array(x):
    """Implementation of `to_array`: returns ``x.__ms_to_array__()``."""
    # NOTE(review): the MindSpore-specific hook is called directly, matching the
    # other dunder-delegation hooks in this module.
    return x.__ms_to_array__()