array_ops.py

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""array operations, the function docs are adapted from Numpy API."""
import operator

from ..common import dtype as mstype
from ..common import Tensor
from ..ops import operations as P
from ..ops import functional as F
from ..ops import composite as C
from ..ops.primitive import constexpr
from ..nn import Cell

from .utils import _convert_list_tensor_to_tuple_tensor, _expand, _broadcast_to_shape, \
    _check_input_tensor, _broadcast_to
from .utils_const import _check_axes_range, _check_start_normalize, \
    _raise_type_error, _raise_value_error, _infer_out_shape, _empty, _promote, \
    _check_same_type, _check_axis_valid, _add_unit_axes, _broadcast_tuples, \
    _check_is_float, _check_axis_in_range, _check_axis_type, _canonicalize_axis, \
    _list_comprehensions, _check_element_int, _is_shape_empty, _type_convert, \
    _tuple_getitem, _expanded_shape

# According to the official numpy reference, the dimension of a numpy array
# must be less than 32.
MAX_NUMPY_DIMS = 32


@constexpr
def _prepare_shape_for_expand_dims(shape, axes):
    """
    Creates the expanded new shape based on the shape and given axes.

    Args:
        shape (tuple): the shape of the tensor.
        axes (Union[int, tuple(int), list(int)]): the axes with dimensions expanded.

    Returns:
        new_shape(tuple): the shape with dimensions expanded.
    """
    new_shape = []
    shape_idx = 0
    new_shape_length = len(shape)
    # Convert to set
    if isinstance(axes, int):
        new_shape_length += 1
        if axes >= new_shape_length or axes < -new_shape_length:
            raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {new_shape_length}")
        axes = {axes}
    elif isinstance(axes, (list, tuple)):
        new_shape_length += len(axes)
        for axis in axes:
            if axis >= new_shape_length or axis < -new_shape_length:
                raise ValueError(f"axis {axis} is out of bounds for tensor of dimension {new_shape_length}")
        axes = set(axes)
    else:
        raise TypeError(f"only int, tuple and list are allowed for axes, but got {type(axes)}")

    for new_shape_idx in range(new_shape_length):
        if new_shape_idx in axes or new_shape_idx - new_shape_length in axes:
            new_shape.append(1)
        else:
            new_shape.append(shape[shape_idx])
            shape_idx += 1
    return tuple(new_shape)
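
# Illustrative traces of the helper above; these are hand-worked values, not
# outputs from the MindSpore test suite:
#   _prepare_shape_for_expand_dims((2, 3), 1)      -> (2, 1, 3)
#   _prepare_shape_for_expand_dims((2, 3), (0, 3)) -> (1, 2, 3, 1)
#   _prepare_shape_for_expand_dims((2, 3), -1)     -> (2, 3, 1)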


def expand_dims(a, axis):
    """
    Expands the shape of a tensor.

    Inserts a new axis that will appear at the `axis` position in the expanded tensor shape.

    Args:
        a (Tensor): Input tensor array.
        axis (Union[int, list(int), tuple(int)]): Position in the expanded axes where
            the new axis is placed.

    Returns:
        View of `a` with the number of dimensions increased.

    Raises:
        TypeError: If input arguments have types not specified above.
        ValueError: If `axis` exceeds `a.ndim`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> x = np.ones((2,2))
        >>> x = np.expand_dims(x,0)
        >>> print(x.shape)
        (1, 2, 2)
    """
    _check_input_tensor(a)
    shape = F.shape(a)
    # yield expanded shape based on the axes
    new_shape = _prepare_shape_for_expand_dims(shape, axis)
    return F.reshape(a, new_shape)


def squeeze(a, axis=None):
    """
    Removes single-dimensional entries from the shape of a tensor.

    Args:
        a (Tensor): Input tensor array.
        axis (Union[None, int, list(int), tuple(int)]): Default is None.

    Returns:
        Tensor, with all or a subset of the dimensions of length :math:`1` removed.

    Raises:
        TypeError: If input arguments have types not specified above.
        ValueError: If the specified axis has a shape entry :math:`> 1`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> x = np.ones((1,2,2,1))
        >>> x = np.squeeze(x)
        >>> print(x.shape)
        (2, 2)
    """
    _check_input_tensor(a)
    return a.squeeze(axis)


def transpose(a, axes=None):
    """
    Reverses or permutes the axes of a tensor; returns the modified tensor.

    Args:
        a (Tensor): A tensor to be transposed.
        axes (Union[None, tuple, list]): The axes order. If `axes` is `None`, the
            axes of the entire tensor are reversed. Default is `None`.

    Returns:
        Tensor, the transposed tensor array.

    Raises:
        TypeError: If input arguments have types not specified above.
        ValueError: If the number of `axes` is not equal to `a.ndim`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> x = np.ones((1,2,3))
        >>> x = np.transpose(x)
        >>> print(x.shape)
        (3, 2, 1)
    """
    _check_input_tensor(a)
    return a.transpose(axes)


def rollaxis(x, axis, start=0):
    """
    Rolls the specified axis backwards, until it lies in the given position.
    The positions of the other axes do not change relative to one another.

    Args:
        x (Tensor): A Tensor to be transposed.
        axis (int): The axis to be rolled.
        start (int): Default: 0.
            If :math:`start <= axis`, the axis is rolled back until it lies in this position (`start`).
            If :math:`start > axis`, the axis is rolled until it lies before this position (`start`).
            If :math:`start < 0`, `start` is normalized as shown in the table below.

            .. table::

                +-------------+------------------+
                | start       | Normalized start |
                +=============+==================+
                | -(x.ndim+1) | raise ValueError |
                +-------------+------------------+
                | -x.ndim     | 0                |
                +-------------+------------------+
                | ...         | ...              |
                +-------------+------------------+
                | -1          | x.ndim-1         |
                +-------------+------------------+
                | ...         | ...              |
                +-------------+------------------+
                | x.ndim      | x.ndim           |
                +-------------+------------------+
                | x.ndim+1    | raise ValueError |
                +-------------+------------------+

    Returns:
        Transposed Tensor. Has the same data type as the original tensor `x`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Raises:
        TypeError: If `axis` or `start` is not an integer, or `x` is not a tensor.
        ValueError: If `axis` is not in the range of :math:`[-ndim, ndim-1]` or
            `start` is not in the range of :math:`[-ndim, ndim]`.

    Examples:
        >>> import mindspore.numpy as np
        >>> x = np.ones((2,3,4))
        >>> output = np.rollaxis(x, 0, 2)
        >>> print(output.shape)
        (3, 2, 4)
    """
    _check_input_tensor(x)
    if not isinstance(axis, int):
        _raise_type_error("integer argument expected, but got ", axis)
    if not isinstance(start, int):
        _raise_type_error("integer argument expected, but got ", start)

    shape = F.shape(x)
    ndim = F.tuple_len(shape)

    axis = _check_axes_range(axis, ndim)
    start = _check_start_normalize(start, ndim)

    if start - axis >= 0 and start - axis <= 1:
        return x
    perm = F.make_range(0, ndim)
    new_perm = None
    if start < axis:
        if axis + 1 < ndim:
            new_perm = perm[0:start] + perm[axis:axis+1] + \
                perm[start:axis] + perm[axis+1:]
        else:
            new_perm = perm[0:start] + perm[axis:axis+1] + perm[start:axis]
    if start > axis:
        if start < ndim:
            new_perm = perm[0:axis] + perm[axis+1:start] + \
                perm[axis:axis+1] + perm[start:]
        else:
            new_perm = perm[0:axis] + perm[axis+1:start] + \
                perm[axis:axis+1]

    return F.transpose(x, new_perm)


def swapaxes(x, axis1, axis2):
    """
    Interchanges two axes of a tensor.

    Args:
        x (Tensor): A tensor to be transposed.
        axis1 (int): First axis.
        axis2 (int): Second axis.

    Returns:
        Transposed tensor, has the same data type as the original tensor `x`.

    Raises:
        TypeError: If `axis1` or `axis2` is not an integer, or `x` is not a tensor.
        ValueError: If `axis1` or `axis2` is not in the range of :math:`[-ndim, ndim-1]`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> x = np.ones((2,3,4))
        >>> output = np.swapaxes(x, 0, 2)
        >>> print(output.shape)
        (4, 3, 2)
    """
    _check_input_tensor(x)
    return x.swapaxes(axis1, axis2)


def reshape(x, new_shape):
    """
    Reshapes a tensor without changing its data.

    Args:
        x (Tensor): A tensor to be reshaped.
        new_shape (Union[int, list(int), tuple(int)]): The new shape should be
            compatible with the original shape. If the tuple has only one element,
            the result will be a 1-D tensor of that length. One shape dimension
            can be :math:`-1`. In this case, the value is inferred from the length of
            the tensor and remaining dimensions.

    Returns:
        Reshaped Tensor. Has the same data type as the original tensor `x`.

    Raises:
        TypeError: If `new_shape` is not integer, list or tuple, or `x` is not a tensor.
        ValueError: If `new_shape` is not compatible with the original shape.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> x = np.asarray([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
        >>> output = np.reshape(x, (3, 2))
        >>> print(output)
        [[-0.1  0.3]
         [ 3.6  0.4]
         [ 0.5 -3.2]]
        >>> output = np.reshape(x, (3, -1))
        >>> print(output)
        [[-0.1  0.3]
         [ 3.6  0.4]
         [ 0.5 -3.2]]
        >>> output = np.reshape(x, (6, ))
        >>> print(output)
        [-0.1  0.3  3.6  0.4  0.5 -3.2]
    """
    _check_input_tensor(x)
    return x.reshape(new_shape)


def ravel(x):
    """
    Returns a contiguous flattened tensor.

    A 1-D tensor, containing the elements of the input, is returned.

    Args:
        x (Tensor): A tensor to be flattened.

    Returns:
        Flattened tensor, has the same data type as the original tensor `x`.

    Raises:
        TypeError: If `x` is not a tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> x = np.ones((2,3,4))
        >>> output = np.ravel(x)
        >>> print(output.shape)
        (24,)
    """
    _check_input_tensor(x)
    return x.ravel()


@constexpr
def _move_axes_for_concatenate(arr_shape, axis):
    """
    Moves axis 0 to the designated position, while keeping the other axes'
    relative positions unchanged. Only used if a single tensor is concatenated.
    """
    original_axes = tuple(range(len(arr_shape)))
    new_axes = original_axes[1:axis+1] + (0,) + original_axes[axis+1:]
    new_shape = arr_shape[1:axis+1] + (arr_shape[0] * arr_shape[axis+1],) + \
        arr_shape[axis+2:]
    return new_axes, new_shape
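
# A hand-worked example of the move above (illustrative, not from the test
# suite): concatenating a single tensor of shape (3, 4, 5) along axis 1 gives
#   _move_axes_for_concatenate((3, 4, 5), 1) -> new_axes=(1, 0, 2), new_shape=(4, 15)
# i.e. the tensor is transposed to (4, 3, 5) and reshaped to (4, 15), which
# equals concatenating the three (4, 5) sub-tensors along axis 1.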


def _promote_type_for_concatenate(tuple_of_tensors):
    """
    Checks the dtype for all tensors in the tuple. If the dtypes are not the same,
    promotes them to the `highest` dtype in the tuple, so that they are ready for
    the concat operator.

    Args:
        tuple_of_tensors(tuple(tensor)): A tuple of tensors.

    Returns:
        Tuple of tensors, with each tensor promoted to the same dtype.
    """
    need_cast = False
    final_type = tuple_of_tensors[0].dtype

    for tensor in tuple_of_tensors:
        if not _check_same_type(final_type, tensor.dtype):
            need_cast = True
        final_type = _promote(final_type, tensor.dtype)

    if not need_cast:
        return tuple_of_tensors
    tuple_of_casted_tensors = ()
    for tensor in tuple_of_tensors:
        tuple_of_casted_tensors += (tensor.astype(final_type, copy=False),)
    return tuple_of_casted_tensors
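
# A minimal sketch of the promotion above (dtypes chosen for illustration):
# for a tuple (int32 tensor, float32 tensor), _promote resolves the common
# dtype to float32, so the int32 tensor is cast and the tuple comes back as
# two float32 tensors, ready for P.Concat.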


def concatenate(arrays, axis=0):
    """
    Joins a sequence of tensors along an existing axis.

    Args:
        arrays (Union[Tensor, tuple(Tensor), list(Tensor)]): A tensor or a list
            of tensors to be concatenated.
        axis (Union[None, int], optional): The axis along which the tensors will be
            joined. If `axis` is :class:`None`, tensors are flattened before use.
            Default is 0.

    Returns:
        A tensor concatenated from a tensor or a list of tensors.

    Raises:
        TypeError: If input arguments have types not specified above.
        ValueError: If `axis` is not in the range of :math:`[-ndim, ndim-1]`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> x1 = np.ones((1,2,3))
        >>> x2 = np.ones((1,2,1))
        >>> x = np.concatenate((x1, x2), axis=-1)
        >>> print(x.shape)
        (1, 2, 4)
    """
    if isinstance(arrays, Tensor):
        # if only one tensor is provided, it is treated as a tuple along the
        # first dimension. For example, a tensor of shape (3,4,5) will be treated
        # as: tuple(tensor_1(4,5), tensor_2(4,5), tensor_3(4,5))
        if axis is None or axis >= MAX_NUMPY_DIMS:
            return ravel(arrays)
        arr_shape = F.shape(arrays)
        _check_axes_range((axis,), len(arr_shape))
        # move axis 0 to the designated position, while keeping the other axes'
        # relative positions unchanged
        new_axes, new_shape = _move_axes_for_concatenate(arr_shape, axis)
        arrays = transpose(arrays, new_axes)
        arrays = reshape(arrays, new_shape)
        return arrays

    flattened_arrays = ()
    if axis is None or axis >= MAX_NUMPY_DIMS:
        for arr in arrays:
            flattened_arrays += (ravel(arr),)
        axis = -1
        flattened_arrays = _promote_type_for_concatenate(flattened_arrays)
        return P.Concat(axis)(flattened_arrays)

    # convert a list of tensors to a tuple of tensors
    arrays = _convert_list_tensor_to_tuple_tensor(arrays)
    arr_shape = F.shape(arrays[0])
    _check_axes_range((axis,), len(arr_shape))

    # if only one tensor is in the tuple/list, return the tensor itself
    if len(arrays) == 1:
        return arrays[0]

    arrays = _promote_type_for_concatenate(arrays)
    return P.Concat(axis)(arrays)


def append(arr, values, axis=None):
    """
    Appends values to the end of a tensor.

    Args:
        arr (Tensor): Values are appended to a copy of this tensor.
        values (Tensor): These values are appended to a copy of `arr`. It must be of
            the correct shape (the same shape as `arr`, excluding `axis`). If `axis` is
            not specified, `values` can be any shape and will be flattened before use.
        axis (None, int, optional): The `axis` along which values are appended. If `axis` is not
            given, both `arr` and `values` are flattened before use. Default is :class:`None`.

    Returns:
        Tensor, a copy of tensor with values appended to axis.

    Raises:
        TypeError: If input arguments have types not specified above.
        ValueError: If the specified axis exceeds `arr.ndim`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> a = np.ones((2, 3))
        >>> b = np.ones((2, 1))
        >>> print(np.append(a, b, axis=1).shape)
        (2, 4)
    """
    _check_input_tensor(arr)
    _check_input_tensor(values)
    if axis is None:
        arr = arr.ravel()
        values = values.ravel()
    else:
        _check_axis_in_range(axis, arr.ndim)
    if F.rank(arr) != F.rank(values):
        _raise_value_error("all tensors must have same number of dimensions")
    return concatenate((arr, values), axis)


def column_stack(tup):
    """
    Stacks 1-D tensors as columns into a 2-D tensor. 2-D tensors are stacked as-is,
    like np.hstack.

    Args:
        tup (Union[Tensor, tuple, list]): A sequence of 1-D or 2-D tensors. All
            of them must have the same shape except the axis to be concatenated.

    Returns:
        2-D Tensor, formed by stacking the given tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Raises:
        TypeError: If `tup` is not Tensor, list or tuple.
        ValueError: If `tup` is empty.

    Examples:
        >>> import mindspore.numpy as mnp
        >>> import numpy as onp
        >>> from mindspore import Tensor
        >>> x1 = Tensor(onp.array([1, 2, 3]).astype('int32'))
        >>> x2 = Tensor(onp.array([4, 5, 6]).astype('int32'))
        >>> output = mnp.column_stack((x1, x2))
        >>> print(output)
        [[1 4]
         [2 5]
         [3 6]]
    """
    if isinstance(tup, Tensor):
        return tup
    if not isinstance(tup, (list, tuple)):
        _raise_type_error("Tensor, list or tuple of tensors are required, but got ", tup)

    trans_tup = ()
    for tensor in tup:
        if tensor.ndim < 1:
            tensor = F.expand_dims(tensor, 0)
        if tensor.ndim == 1:
            tensor = F.expand_dims(tensor, 1)
        trans_tup += (tensor,)
    if not trans_tup:
        _raise_value_error("Need at least one tensor to concatenate.")
    return P.Concat(1)(trans_tup)


def vstack(tup):
    """
    Stacks tensors in sequence vertically.

    This is equivalent to concatenation along the first axis. 1-D tensors are first
    reshaped to `(1, N)`, and then concatenated along the first axis.

    Args:
        tup (Union[Tensor, tuple, list]): A sequence of 1-D or 2-D tensors. The tensors must have the same shape
            along all but the first axis. 1-D tensors must have the same shape.

    Returns:
        Stacked Tensor, formed by stacking the given tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Raises:
        TypeError: If `tup` is not Tensor, list or tuple.
        ValueError: If `tup` is empty.

    Examples:
        >>> import mindspore.numpy as mnp
        >>> import numpy as onp
        >>> from mindspore import Tensor
        >>> x1 = Tensor(onp.array([1, 2, 3]).astype('int32'))
        >>> x2 = Tensor(onp.array([4, 5, 6]).astype('int32'))
        >>> output = mnp.vstack((x1, x2))
        >>> print(output)
        [[1 2 3]
         [4 5 6]]
    """
    if isinstance(tup, Tensor):
        return tup
    if not isinstance(tup, (list, tuple)):
        _raise_type_error("Tensor, list or tuple of tensors are required, but got", tup)

    trans_tup = ()
    for tensor in tup:
        if tensor.ndim <= 1:
            tensor = _expand(tensor, 2, 0)
        trans_tup += (tensor,)
    if not trans_tup:
        _raise_value_error("Need at least one tensor to concatenate.")
    return P.Concat(0)(trans_tup)


def hstack(tup):
    """
    Stacks tensors in sequence horizontally.

    This is equivalent to concatenation along the second axis, except for 1-D tensors
    where it concatenates along the first axis.

    Args:
        tup (Union[Tensor, tuple, list]): A sequence of 1-D or 2-D tensors. The
            tensors must have the same shape along all but the second axis, except
            1-D tensors which can be any length.

    Returns:
        Stacked Tensor, formed by stacking the given tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Raises:
        TypeError: If `tup` is not Tensor, list or tuple.
        ValueError: If `tup` is empty.

    Examples:
        >>> import mindspore.numpy as np
        >>> x1 = np.array([1, 2, 3]).astype('float32')
        >>> x2 = np.array([4, 5, 6]).astype('float32')
        >>> output = np.hstack((x1, x2))
        >>> print(output)
        [1. 2. 3. 4. 5. 6.]
    """
    if isinstance(tup, Tensor):
        return tup
    if not isinstance(tup, (list, tuple)):
        _raise_type_error("Tensor, list or tuple of tensors are required, but got", tup)

    tuple_of_tensor = ()
    for tensor in tup:
        if tensor.ndim < 1:
            tensor = F.expand_dims(tensor, 0)
        tuple_of_tensor += (tensor,)
    if not tuple_of_tensor:
        _raise_value_error("Need at least one tensor to concatenate.")
    if tuple_of_tensor[0].ndim <= 1:
        return P.Concat(0)(tuple_of_tensor)
    return P.Concat(1)(tuple_of_tensor)


def dstack(tup):
    """
    Stacks tensors in sequence depth-wise (along the third axis).

    This is equivalent to concatenation along the third axis. 1-D tensors :math:`(N,)` are
    reshaped to :math:`(1,N,1)`, and 2-D tensors :math:`(M,N)` are reshaped to
    :math:`(M,N,1)` before concatenation.

    Args:
        tup (Union[Tensor, tuple, list]): A sequence of tensors. The tensors must have the same shape along all but
            the third axis. 1-D or 2-D tensors must have the same shape.

    Returns:
        Stacked Tensor, formed by stacking the given tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Raises:
        TypeError: If `tup` is not Tensor, list or tuple.
        ValueError: If `tup` is empty.

    Examples:
        >>> import mindspore.numpy as np
        >>> x1 = np.array([1, 2, 3]).astype('float32')
        >>> x2 = np.array([4, 5, 6]).astype('float32')
        >>> output = np.dstack((x1, x2))
        >>> print(output)
        [[[1. 4.]
          [2. 5.]
          [3. 6.]]]
    """
    if isinstance(tup, Tensor):
        return tup
    if not isinstance(tup, (list, tuple)):
        _raise_type_error("Tensor, list or tuple of tensors are required, but got", tup)

    trans_tup = ()
    for tensor in tup:
        if tensor.ndim <= 1:
            tensor = _expand(tensor, 2, 0)
        if tensor.ndim == 2:
            tensor = F.expand_dims(tensor, 2)
        trans_tup += (tensor,)
    if not trans_tup:
        _raise_value_error("Need at least one tensor to concatenate.")
    return P.Concat(2)(trans_tup)


def where(condition, x=None, y=None):
    """
    Returns elements chosen from `x` or `y` depending on `condition`.

    Note:
        As nonzero is not supported, neither `x` nor `y` can be None.

    Args:
        condition (Tensor): Where True, yield `x`, otherwise yield `y`.
        x (Tensor): Values from which to choose.
        y (Tensor): Values from which to choose. `x`, `y` and `condition` need
            to be broadcastable to some shape.

    Returns:
        Tensor or scalar, with elements from `x` where `condition` is True, and
        elements from `y` elsewhere.

    Raises:
        ValueError: If operands cannot be broadcast.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> condition = np.full((1, 1, 2), [False, True])
        >>> x = np.full((1, 3, 2), 5)
        >>> y = np.full((2, 1, 1), 7)
        >>> output = np.where(condition, x, y)
        >>> print(output)
        [[[7 5]
          [7 5]
          [7 5]]
         [[7 5]
          [7 5]
          [7 5]]]
    """
    # type promotes input tensors
    dtype1 = F.dtype(x)
    dtype2 = F.dtype(y)
    dtype = _promote(dtype1, dtype2)
    if not _check_same_type(dtype1, dtype):
        x = F.cast(x, dtype)
    if not _check_same_type(dtype2, dtype):
        y = F.cast(y, dtype)
    is_bool = _check_same_type(dtype1, mstype.bool_) and _check_same_type(
        dtype2, mstype.bool_)
    if is_bool:
        # select does not support bool type for x or y
        x = F.cast(x, mstype.float32)
        y = F.cast(y, mstype.float32)

    # broadcasts input tensors
    shape_out = _infer_out_shape(F.shape(condition),
                                 F.shape(x), F.shape(y))
    if not _check_same_type(F.dtype(condition), mstype.float32):
        # tiling with bool is not supported on GPU
        condition = F.cast(condition, mstype.float32)
    condition = _broadcast_to_shape(condition, shape_out)
    x = _broadcast_to_shape(x, shape_out)
    y = _broadcast_to_shape(y, shape_out)

    if not _check_same_type(F.dtype(condition), mstype.bool_):
        condition = F.cast(condition, mstype.bool_)
    res = F.select(condition, x, y)
    if is_bool:
        res = F.cast(res, mstype.bool_)
    return res


def _atleast_xd(ndim, arys):
    """Returns arys with each array expanded to at least `ndim` dimensions."""
    _check_input_tensor(*arys)
    res = []
    for arr in arys:
        arr = _expand(arr, ndim)
        res.append(arr)
    if len(res) == 1:
        return res[0]
    return res


def atleast_1d(*arys):
    """
    Converts inputs to arrays with at least one dimension.

    Scalar inputs are converted to 1-dimensional arrays, whilst
    higher-dimensional inputs are preserved.

    Note:
        In graph mode, returns a tuple of tensors instead of a list of
        tensors.

    Args:
        *arys (Tensor): One or more input tensors.

    Returns:
        Tensor, or list of tensors, each with ``a.ndim >= 1``.

    Raises:
        TypeError: If the input is not a tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> a = np.ones((2, 3))
        >>> b = np.ones(())
        >>> c = np.ones(5)
        >>> output = np.atleast_1d(a, b, c)
        >>> print(output)
        (Tensor(shape=[2, 3], dtype=Float32, value=
        [[1.00000000e+000, 1.00000000e+000, 1.00000000e+000],
        [1.00000000e+000, 1.00000000e+000, 1.00000000e+000]]),
        Tensor(shape=[1], dtype=Float32, value= [1.00000000e+000]),
        Tensor(shape=[5], dtype=Float32,
        value= [1.00000000e+000, 1.00000000e+000, 1.00000000e+000,
        1.00000000e+000, 1.00000000e+000]))
    """
    return _atleast_xd(1, arys)


def atleast_2d(*arys):
    """
    Views inputs as arrays with at least two dimensions.

    Note:
        In graph mode, returns a tuple of tensors instead of a list of
        tensors.

    Args:
        *arys (Tensor): One or more input tensors.

    Returns:
        Tensor, or list of tensors, each with ``a.ndim >= 2``.

    Raises:
        TypeError: If the input is not a tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> a = np.ones((2, 3))
        >>> b = np.ones(())
        >>> c = np.ones(5)
        >>> output = np.atleast_2d(a, b, c)
        >>> print(output)
        (Tensor(shape=[2, 3], dtype=Float32, value=
        [[1.00000000e+000, 1.00000000e+000, 1.00000000e+000],
        [1.00000000e+000, 1.00000000e+000, 1.00000000e+000]]),
        Tensor(shape=[1, 1], dtype=Float32, value= [[1.00000000e+000]]),
        Tensor(shape=[1, 5], dtype=Float32,
        value= [[1.00000000e+000, 1.00000000e+000, 1.00000000e+000,
        1.00000000e+000, 1.00000000e+000]]))
    """
    return _atleast_xd(2, arys)


def atleast_3d(*arys):
    """
    Views inputs as arrays with at least three dimensions.

    Note:
        In graph mode, returns a tuple of tensors instead of a list of
        tensors.

    Args:
        *arys (Tensor): One or more input tensors.

    Returns:
        Tensor, or list of tensors, each with ``a.ndim >= 3``. For example,
        a 1-D array of shape `(N,)` becomes a view of shape `(1, N, 1)`, and
        a 2-D array of shape `(M, N)` becomes a view of shape `(M, N, 1)`.

    Raises:
        TypeError: If the input is not a tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> a = np.ones((2, 3))
        >>> b = np.ones(())
        >>> c = np.ones(5)
        >>> output = np.atleast_3d(a, b, c)
        >>> print(output)
        (Tensor(shape=[2, 3, 1], dtype=Float32, value=
        [[[1.00000000e+000], [1.00000000e+000], [1.00000000e+000]],
        [[1.00000000e+000], [1.00000000e+000], [1.00000000e+000]]]),
        Tensor(shape=[1, 1, 1], dtype=Float32, value= [[[1.00000000e+000]]]),
        Tensor(shape=[1, 5, 1], dtype=Float32,
        value= [[[1.00000000e+000], [1.00000000e+000], [1.00000000e+000],
        [1.00000000e+000], [1.00000000e+000]]]))
    """
    res = []
    for arr in arys:
        ndim = F.rank(arr)
        if ndim == 0:
            arr = F.reshape(arr, (1, 1, 1))
        elif ndim == 1:
            arr = F.reshape(arr, (1, F.size(arr), 1))
        elif ndim == 2:
            arr = F.reshape(arr, F.shape(arr) + (1,))
        res.append(arr)
    if len(res) == 1:
        return res[0]
    return res


def stack(arrays, axis=0):
    """
    Joins a sequence of arrays along a new axis.

    The `axis` parameter specifies the index of the new axis in the
    dimensions of the result. For example, if ``axis=0`` it will be the
    first dimension and if ``axis=-1`` it will be the last dimension.

    Note:
        Numpy argument `out` is not supported.

    Args:
        arrays (sequence of Tensor): Each array must have the same shape.
        axis (int, optional): The axis in the result array along which the
            input arrays are stacked.

    Returns:
        Tensor, the stacked array has one more dimension than the input
        arrays.

    Raises:
        ValueError: If the input is not Tensor, tuple, or list.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> arrays = [np.ones((3, 4)) for _ in range(10)]
        >>> output = np.stack(arrays, axis=0)
        >>> print(output.shape)
        (10, 3, 4)
        >>> output = np.stack(arrays, axis=1)
        >>> print(output.shape)
        (3, 10, 4)
        >>> output = np.stack(arrays, axis=2)
        >>> print(output.shape)
        (3, 4, 10)
    """
    if isinstance(arrays, Tensor):
        shape = F.shape(arrays)
        ndim = F.rank(arrays)
        axis = axis % ndim
        axes = F.make_range(ndim)
        perm = axes[1:axis+1] + (0,) + axes[axis+1:]
        if _is_shape_empty(shape):
            return _empty(mstype.float32, shape[1:axis+1] + (shape[0],) + shape[axis+1:])
        return transpose(arrays, perm)

    if isinstance(arrays, (list, tuple)):
        shape = (len(arrays),) + F.shape(arrays[0])
        ndim = len(shape)
        axis = axis % ndim
        if _is_shape_empty(shape):
            return _empty(mstype.float32, shape[1:axis+1] + (shape[0],) + shape[axis+1:])
        seq = ()
        for arr in arrays:
            seq += (F.expand_dims(arr, axis),)
        return concatenate(seq, axis)
    return _raise_value_error('input arrays must be Tensor, tuple, or list')


class UniqueNet(Cell):
    """The unique operation wrapped inside a model."""

    def __init__(self):
        super(UniqueNet, self).__init__()
        self.unique = P.Unique()

    def construct(self, x):
        return self.unique(x)


def unique(x, return_inverse=False):
    """
    Finds the unique elements of a tensor. The input tensor will be flattened first
    when it has more than one dimension.

    Note:
        Numpy arguments `axis`, `return_index` and `return_counts` are not supported.
        On CPU, this operator must be executed in graph mode.

    Args:
        x (Tensor): The input tensor to be processed.
        return_inverse (bool): If `True`, also return the indices of the unique tensor.
            Default: `False`.

    Returns:
        Tensor or tuple of Tensors.
        - If `return_inverse` is `False`, just return the unique tensor.
        - If `return_inverse` is `True`, return a tuple of tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Raises:
        TypeError: If `x` is not a tensor.

    Examples:
        >>> import mindspore.numpy as np
        >>> from mindspore import context
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> input_x = np.asarray([1, 2, 2, 2, 3, 4, 5]).astype('int32')
        >>> output_x = np.unique(input_x)
        >>> print(output_x)
        [1 2 3 4 5]
        >>> output_x = np.unique(input_x, return_inverse=True)
        >>> print(output_x)
        (Tensor(shape=[5], dtype=Int32, value= [ 1, 2, 3, 4, 5]), Tensor(shape=[7], dtype=Int32,
        value= [0, 1, 1, 1, 2, 3, 4]))
    """
    _check_input_tensor(x)
    if F.tuple_len(F.shape(x)) > 1:
        x = ravel(x)
    uniq = UniqueNet()
    res = uniq(x)
    if not return_inverse:
        return res[0]
    return res


def roll_along_axis(a, shift, axis):
    """
    Rolls a tensor along a given axis. This is a helper function of np.roll.

    Args:
        a (Tensor): Input tensor.
        shift (int): The number of places the tensor is shifted.
        axis (int): The designated axis for shifting.

    Returns:
        Shifted tensor.
    """
    _check_axis_in_range(axis, a.ndim)
    _check_element_int((shift, axis))
    if axis < 0:
        axis += a.ndim
    shift = -(shift % a.shape[axis])
    # if shift is 0, we do not need to roll at all
    if shift == 0:
        return a
    begin1 = ()
    begin2 = ()
    end1 = ()
    end2 = ()
    stride = _list_comprehensions(a.ndim, 1, True)
    for i in F.make_range(a.ndim):
        if i != axis:
            begin1 += (0,)
            end1 += (a.shape[i],)
            begin2 += (0,)
            end2 += (a.shape[i],)
        else:
            begin1 += (shift,)
            end1 += (a.shape[i],)
            begin2 += (0,)
            end2 += (shift,)
    return append(F.strided_slice(a, begin1, end1, stride),
                  F.strided_slice(a, begin2, end2, stride), axis=axis)
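
# A hand-worked trace of the slicing above (illustrative 1-D case, not from
# the test suite): for a = [0, 1, 2, 3, 4], shift=2, axis=0, the helper
# computes shift = -(2 % 5) = -2, slices a[-2:5] = [3, 4] and a[0:-2] =
# [0, 1, 2], and appends them to give [3, 4, 0, 1, 2], matching numpy's
# roll semantics.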


def roll(a, shift, axis=None):
    """
    Rolls a tensor along given axes.

    Elements that roll beyond the last position are re-introduced at the first.

    Args:
        a (Tensor): Input tensor.
        shift (Union[int, tuple(int)]): The number of places by which elements are
            shifted. If a tuple, then `axis` must be a tuple of the same size, and
            each of the given axes is shifted by the corresponding number. If `shift`
            is an int while `axis` is a tuple of ints, then the same value is used
            for all given axes.
        axis (Union[int, tuple(int)], optional): Axis or axes along which elements
            are shifted. By default, the array is flattened before shifting, after
            which the original shape is restored.

    Returns:
        Tensor, with the same shape as `a`.

    Raises:
        TypeError: If input arguments have types not specified above.
        ValueError: If `axis` exceeds `a.ndim`, or `shift` and `axis` cannot broadcast.

    Examples:
        >>> import mindspore.numpy as np
        >>> a = np.reshape(np.arange(12), (3, 4))
        >>> print(np.roll(a, [2,-3], [0,-1]))
        [[ 7  4  5  6]
         [11  8  9 10]
         [ 3  0  1  2]]
    """
    _check_input_tensor(a)
    original_shape = a.shape
    original_dtype = a.dtype
    restore_shape = False
    # F.strided_slice only supports float on cpu, this will change once more supports
    # are added.
    if not _check_is_float(original_dtype):
        a = a.astype(mstype.float32)
    if axis is None:
        restore_shape = True
        axis = 0
        a = a.ravel()
    # Broadcast shift and axis to the same length
    shift, axis = _broadcast_tuples(shift, axis)
    for shift_each, axis_each in zip(shift, axis):
        a = roll_along_axis(a, shift_each, axis_each)
    if restore_shape:
        a = a.reshape(original_shape)
    if not _check_is_float(original_dtype):
        a = a.astype(original_dtype)
    return a


@constexpr
def _get_moved_perm(ndim, source, destination):
    """
    Helper function for moveaxis, returns the permutation after moving axes
    from source to destination.
    """
    dest_sorted_idx = [i for i, _ in sorted(enumerate(destination),
                                            key=operator.itemgetter(1))]
    axes_orig = [i for i in range(ndim) if i not in source]

    k = 0
    m = 0
    perm = []
    for i in dest_sorted_idx:
        # inserts an axis that has been moved, denoted by n, and axes that remain
        # in their original position, indexed from k to k + n - m, into index m in
        # the list of permuted axes
        n = destination[i]
        j = k + n - m
        perm += axes_orig[k:j]
        perm.append(source[i])
        k += n - m
        m = n + 1
    perm += axes_orig[k:]
    return tuple(perm)
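
# An illustrative evaluation of the permutation logic above (assuming the
# canonicalized tuples produced by _check_axis_valid):
#   _get_moved_perm(3, (0,), (2,)) -> (1, 2, 0)
# axis 0 is inserted at position 2 while the remaining axes keep their
# relative order, so transposing with this perm implements moveaxis(a, 0, -1).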


@constexpr
def _get_moved_shape(shape, perm):
    """
    Helper function for moveaxis, returns the permuted shape after
    applying perm.
    """
    return tuple([shape[i] for i in perm])


def moveaxis(a, source, destination):
    """
    Moves axes of an array to new positions.

    Other axes remain in their original order.

    Args:
        a (Tensor): The array whose axes should be reordered.
        source (int or sequence of ints): Original positions of the
            axes to move. These must be unique.
        destination (int or sequence of ints): Destination positions
            for each of the original axes. These must also be unique.

    Returns:
        Tensor, array with moved axes.

    Raises:
        ValueError: If axes are out of the range of ``[-a.ndim, a.ndim)``, or
            if the axes contain duplicates.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> x = np.zeros((3, 4, 5))
        >>> output = np.moveaxis(x, 0, -1)
        >>> print(output.shape)
        (4, 5, 3)
        >>> output = np.moveaxis(x, -1, 0)
        >>> print(output.shape)
        (5, 3, 4)
        >>> output = np.moveaxis(x, [0, 1, 2], [-1, -2, -3])
        >>> print(output.shape)
        (5, 4, 3)
    """
    ndim = F.rank(a)
    source = _check_axis_valid(source, ndim)
    destination = _check_axis_valid(destination, ndim)
    perm = _get_moved_perm(ndim, source, destination)

    shape = F.shape(a)
    if _is_shape_empty(shape):
        return _empty(F.dtype(a), _get_moved_shape(shape, perm))

    return F.transpose(a, perm)


@constexpr
def _seq_prod(seq1, seq2):
    """Returns the element-wise product of seq1 and seq2."""
    return tuple(map(lambda x, y: x*y, seq1, seq2))


def tile(a, reps):
    """
    Constructs an array by repeating `a` the number of times given by `reps`.

    If `reps` has length `d`, the result will have dimension of ``max(d, a.ndim)``.
    If ``a.ndim < d``, `a` is promoted to be d-dimensional by prepending new axes.
    So a shape (3,) array is promoted to (1, 3) for 2-D replication, or
    shape (1, 1, 3) for 3-D replication. If this is not the desired behavior,
    promote `a` to d-dimensions manually before calling this function.
    If ``a.ndim > d``, `reps` is promoted to ``a.ndim`` by prepending 1s to it. Thus
    for an `a` of shape (2, 3, 4, 5), a `reps` of (2, 2) is treated as (1, 1, 2, 2).

    Args:
        a (Tensor): The input array.
        reps (int or sequence of ints): The number of repetitions of `a` along
            each axis.

    Returns:
        Tensor, the tiled output array.

    Raises:
        TypeError: If the input is not a tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> a = np.array([0, 1, 2])
        >>> output = np.tile(a, 2)
        >>> print(output)
        [0 1 2 0 1 2]
        >>> output = np.tile(a, (2, 2))
        >>> print(output)
        [[0 1 2 0 1 2]
         [0 1 2 0 1 2]]
        >>> output = np.tile(a, (2, 1, 2))
        >>> print(output)
        [[[0 1 2 0 1 2]]
         [[0 1 2 0 1 2]]]
    """
    _check_input_tensor(a)
    ndim = F.rank(a)
    shape = F.shape(a)
    reps = _add_unit_axes(reps, ndim)
    if _is_shape_empty(shape) or _is_shape_empty(reps):
        shape = _add_unit_axes(shape, len(reps))
        return _empty(F.dtype(a), _seq_prod(shape, reps))
    return F.tile(a, reps)


@constexpr
def _check_can_broadcast_to(shape, target_shape):
    """Determines if shape can be broadcast to target_shape."""
    ndim = len(shape)
    ndim_target = len(target_shape)
    if ndim > ndim_target:
        return False
    for i, j in zip(reversed(shape), reversed(target_shape)):
        if i not in (1, j):
            return False
    return True
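
# Quick sanity examples for the check above (plain tuples, hand-worked):
#   _check_can_broadcast_to((3, 1), (2, 3, 4)) -> True   (the 1 stretches to 4)
#   _check_can_broadcast_to((3, 2), (3, 4))    -> False  (2 is neither 1 nor 4)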


def broadcast_to(array, shape):
    """
    Broadcasts an array to a new shape.

    Args:
        array (Tensor): The array to broadcast.
        shape (tuple): The shape of the desired array.

    Returns:
        Tensor, original array broadcast to the given shape.

    Raises:
        ValueError: If array cannot be broadcast to shape.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> x = np.array([1, 2, 3])
        >>> output = np.broadcast_to(x, (3, 3))
        >>> print(output)
        [[1 2 3]
         [1 2 3]
         [1 2 3]]
    """
    shape_a = F.shape(array)
    if not _check_can_broadcast_to(shape_a, shape):
        return _raise_value_error(f'cannot broadcast {shape_a} to {shape}')
    return _broadcast_to_shape(array, shape)


def broadcast_arrays(*args):
    """
    Broadcasts any number of arrays against each other.

    Note:
        Numpy argument `subok` is not supported.
        In graph mode, returns a tuple of Tensor instead of a list
        of Tensor.

    Args:
        *args (Tensor): The arrays to broadcast.

    Returns:
        List of Tensor.

    Raises:
        ValueError: If arrays cannot be broadcast.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> x = np.array([[1,2,3]])
        >>> y = np.array([[4],[5]])
        >>> output = np.broadcast_arrays(x, y)
        >>> print(output)
        [Tensor(shape=[2, 3], dtype=Int32, value=
        [[1, 2, 3],
        [1, 2, 3]]), Tensor(shape=[2, 3], dtype=Int32, value=
        [[4, 4, 4],
        [5, 5, 5]])]
    """
    shapes = map(F.shape, args)
    out_shape = _infer_out_shape(*shapes)
    res = []
    for arr in args:
        res.append(broadcast_to(arr, out_shape))
    return res


def split(x, indices_or_sections, axis=0):
    """
    Splits a tensor into multiple sub-tensors along the given axis.

    Args:
        x (Tensor): A Tensor to be divided.
        indices_or_sections (Union[int, tuple(int), list(int)]):
            If an integer :math:`N`, the tensor will be divided into
            :math:`N` equal tensors along `axis`.
            If a tuple(int) or list(int) of sorted integers,
            the entries indicate where along `axis` the array is split.
            For example, :math:`[2, 3]` would, for :math:`axis=0`, result in
            three sub-tensors :math:`x[:2]`, :math:`x[2:3]` and :math:`x[3:]`.
            If an index exceeds the dimension of the array along `axis`,
            an empty sub-array is returned correspondingly.
        axis (int): The axis along which to split. Default: 0.

    Returns:
        A list of sub-tensors.

    Raises:
        TypeError: If argument `indices_or_sections` is not integer,
            tuple(int) or list(int), or argument `axis` is not integer.
        ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)`.

    Examples:
        >>> import mindspore.numpy as np
        >>> input_x = np.arange(9)
        >>> output = np.split(input_x, 3)
        >>> print(output)
        (Tensor(shape=[3], dtype=Float32,
        value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]),
        Tensor(shape=[3], dtype=Float32,
        value= [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]),
        Tensor(shape=[3], dtype=Float32,
        value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
    """
    _ = _check_axis_type(axis, True, False, False)
    axis = _canonicalize_axis(axis, x.ndim)
    res = None
    if isinstance(indices_or_sections, int):
        _split = P.Split(axis, indices_or_sections)
        res = _split(x)
    elif isinstance(indices_or_sections, (list, tuple)) and _check_element_int(indices_or_sections):
        res = _split_sub_tensors(x, indices_or_sections, axis)
    else:
        _raise_type_error("Argument `indices_or_sections` in `mindspore.numpy.split` "
                          "should be integer, tuple(int) or list(int), but got", indices_or_sections)
    return res


def _split_sub_tensors(x, indices, axis):
    """
    Splits the input tensor `x` into multiple sub-tensors
    along the axis according to the given indices.
    """
    if indices[-1] < x.shape[axis]:
        if isinstance(indices, list):
            indices.append(x.shape[axis])
        elif isinstance(indices, tuple):
            indices += (x.shape[axis],)

    sub_tensors = []
    strides = _list_comprehensions(x.ndim, 1, True)
    begin = _list_comprehensions(x.ndim, 0)
    end = _list_comprehensions(x.shape)
    for i, idx in enumerate(indices):
        begin[axis] = 0 if i == 0 else indices[i-1]
        end[axis] = idx
        sliced_tensor = F.strided_slice(x, _type_convert(tuple, begin), _type_convert(tuple, end), strides)
        sub_tensors.append(sliced_tensor)
    return sub_tensors
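
# A hand-worked example of the index handling above: splitting a tensor of
# shape (9,) with indices (2, 3) first appends the trailing boundary to get
# (2, 3, 9), producing the slices x[0:2], x[2:3] and x[3:9], in line with the
# behaviour documented in `split`.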


def vsplit(x, indices_or_sections):
    """
    Splits a tensor into multiple sub-tensors vertically (row-wise).

    It is equivalent to split with :math:`axis=0` (default); the array is always
    split along the first axis regardless of the array dimension.

    Args:
        x (Tensor): A Tensor to be divided.
        indices_or_sections (Union[int, tuple(int), list(int)]):
            If an integer :math:`N`, the tensor will be divided into
            :math:`N` equal tensors along the axis.
            If a tuple(int) or list(int) of sorted integers,
            the entries indicate where along the axis the array is split.
            For example, :math:`[2, 3]` would, for :math:`axis=0`, result in
            three sub-tensors :math:`x[:2]`, :math:`x[2:3]` and :math:`x[3:]`.
            If an index exceeds the dimension of the array along the axis,
            an empty sub-array is returned correspondingly.

    Returns:
        A list of sub-tensors.

    Raises:
        TypeError: If argument `indices_or_sections` is not integer,
            tuple(int) or list(int).

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> input_x = np.arange(9).reshape((3, 3))
        >>> output = np.vsplit(input_x, 3)
        >>> print(output)
        (Tensor(shape=[1, 3], dtype=Float32,
        value=[[ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]]),
        Tensor(shape=[1, 3], dtype=Float32,
        value=[[ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]]),
        Tensor(shape=[1, 3], dtype=Float32,
        value=[[ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]]))
    """
    return split(x, indices_or_sections, 0)


def hsplit(x, indices_or_sections):
    """
    Splits a tensor into multiple sub-tensors horizontally (column-wise).

    It is equivalent to split with :math:`axis=1`; the array is always
    split along the second axis regardless of the array dimension.

    Args:
        x (Tensor): A Tensor to be divided.
        indices_or_sections (Union[int, tuple(int), list(int)]):
            If an integer :math:`N`, the tensor will be divided into
            :math:`N` equal tensors along the axis.
            If a tuple(int) or list(int) of sorted integers,
            the entries indicate where along the axis the array is split.
            For example, :math:`[2, 3]` would, for :math:`axis=0`, result in
            three sub-tensors :math:`x[:2]`, :math:`x[2:3]` and :math:`x[3:]`.
            If an index exceeds the dimension of the array along the axis,
            an empty sub-array is returned correspondingly.

    Returns:
        A list of sub-tensors.

    Raises:
        TypeError: If argument `indices_or_sections` is not integer,
            tuple(int) or list(int).

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> input_x = np.arange(6).reshape((2, 3))
        >>> output = np.hsplit(input_x, 3)
        >>> print(output)
        (Tensor(shape=[2, 1], dtype=Float32,
        value=[[ 0.00000000e+00],
        [ 3.00000000e+00]]),
        Tensor(shape=[2, 1], dtype=Float32,
        value=[[ 1.00000000e+00],
        [ 4.00000000e+00]]),
        Tensor(shape=[2, 1], dtype=Float32,
        value=[[ 2.00000000e+00],
        [ 5.00000000e+00]]))
    """
    return split(x, indices_or_sections, 1)


def dsplit(x, indices_or_sections):
    """
    Splits a tensor into multiple sub-tensors along the 3rd axis (depth).

    It is equivalent to split with :math:`axis=2`; the array is always
    split along the third axis regardless of the array dimension.

    Args:
        x (Tensor): A Tensor to be divided.
        indices_or_sections (Union[int, tuple(int), list(int)]):
            If an integer :math:`N`, the tensor will be divided into
            :math:`N` equal tensors along the axis.
            If a tuple(int) or list(int) of sorted integers,
            the entries indicate where along the axis the array is split.
            For example, :math:`[2, 3]` would, for :math:`axis=0`, result in
            three sub-tensors :math:`x[:2]`, :math:`x[2:3]` and :math:`x[3:]`.
            If an index exceeds the dimension of the array along the axis,
            an empty sub-array is returned correspondingly.

    Returns:
        A list of sub-tensors.

    Raises:
        TypeError: If argument `indices_or_sections` is not integer,
            tuple(int) or list(int).

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> input_x = np.arange(6).reshape((1, 2, 3))
        >>> output = np.dsplit(input_x, 3)
        >>> print(output)
        (Tensor(shape=[1, 2, 1], dtype=Float32,
        value=[[[ 0.00000000e+00],
        [ 3.00000000e+00]]]),
        Tensor(shape=[1, 2, 1], dtype=Float32,
        value=[[[ 1.00000000e+00],
        [ 4.00000000e+00]]]),
        Tensor(shape=[1, 2, 1], dtype=Float32,
        value=[[[ 2.00000000e+00],
        [ 5.00000000e+00]]]))
    """
    return split(x, indices_or_sections, 2)
  1286. @constexpr
  1287. def _get_flip_start(ndim, shape, axes):
  1288. return tuple([shape[i] - 1 if i in axes else 0 for i in range(ndim)])
  1289. @constexpr
  1290. def _get_flip_end(ndim, shape, axes):
  1291. return tuple([-shape[i] - 1 if i in axes else shape[i] + 1 for i in range(ndim)])
  1292. @constexpr
  1293. def _get_flip_strides(ndim, axes):
  1294. return tuple([-1 if i in axes else 1 for i in range(ndim)])
  1295. def flip(m, axis=None):
  1296. """
  1297. Reverses the order of elements in an array along the given axis.
  1298. The shape of the array is preserved, but the elements are reordered.
  1299. Note:
  1300. On CPU, the supported dtypes are np.float16, np.float32, and np.float64.
  1301. Args:
  1302. m (Tensor): Input array.
  1303. axis (None or int or tuple of ints, optional): Axis or axes along which
  1304. to flip over. The default, ``axis=None``, will flip over all of the axes
  1305. of the input array. If `axis` is negative it counts from the last to
  1306. the first axis. If `axis` is a tuple of ints, flipping is performed on
  1307. all of the axes specified in the tuple.
  1308. Returns:
  1309. Tensor, a view of `m` with the entries of `axis` reversed.
  1310. Raises:
  1311. TypeError: if the input is not a tensor.
  1312. Supported Platforms:
  1313. ``Ascend`` ``GPU`` ``CPU``
  1314. Example:
  1315. >>> A = np.arange(8.0).reshape((2,2,2))
  1316. >>> output = np.flip(A)
  1317. >>> print(output)
  1318. [[[7, 6],
  1319. [5, 4]],
  1320. [[3, 2],
  1321. [1, 0]]]
  1322. >>> output = np.flip(A, (0, 2))
  1323. >>> print(output)
  1324. [[[5, 4],
  1325. [7, 6]],
  1326. [[1, 0],
  1327. [3, 2]]]
  1328. """
  1329. _check_input_tensor(m)
  1330. ndim = F.rank(m)
  1331. axes = _check_axis_valid(axis, ndim)
  1332. shape = F.shape(m)
  1333. dtype = F.dtype(m)
  1334. if _is_shape_empty(shape):
  1335. return m
  1336. if not _check_is_float(dtype):
  1337. m = m.astype(mstype.float32)
  1338. start = _get_flip_start(ndim, shape, axes)
  1339. end = _get_flip_end(ndim, shape, axes)
  1340. strides = _get_flip_strides(ndim, axes)
  1341. res = F.strided_slice(m, start, end, strides)
  1342. if not _check_same_type(F.dtype(res), dtype):
  1343. res = F.cast(res, dtype)
  1344. return res
  1345. def flipud(m):
  1346. """
  1347. Flips the entries in each column in the up/down direction.
  1348. Rows are preserved, but appear in a different order than before.
  1349. Note:
  1350. On CPU, the supported dtypes are np.float16, np.float32, and np.float64.
  1351. Args:
  1352. m (Tensor): Input array.
  1353. Returns:
  1354. Tensor.
  1355. Raises:
  1356. TypeError: if the input is not a tensor.
  1357. Supported Platforms:
  1358. ``Ascend`` ``GPU`` ``CPU``
  1359. Example:
  1360. >>> A = np.arange(8.0).reshape((2,2,2))
  1361. >>> output = np.flipud(A)
  1362. >>> print(output)
  1363. [[[4., 5.],
  1364. [6., 7.]],
  1365. [[0., 1.],
  1366. [2., 3.]]]
  1367. """
  1368. return flip(m, 0)
  1369. def fliplr(m):
  1370. """
  1371. Flip the entries in each row in the left/right direction.
  1372. Columns are preserved, but appear in a different order than before.
  1373. Note:
  1374. On CPU, the supported dtypes are np.float16, np.float32, and np.float64.
  1375. Args:
  1376. m (Tensor): Input array.
  1377. Returns:
  1378. Tensor.
  1379. Raises:
  1380. TypeError: if the input is not a tensor.
  1381. Supported Platforms:
  1382. ``Ascend`` ``GPU`` ``CPU``
  1383. Example:
  1384. >>> A = np.arange(8.0).reshape((2,2,2))
  1385. >>> output = np.fliplr(A)
  1386. >>> print(output)
  1387. [[[2., 3.],
  1388. [0., 1.]],
  1389. [[6., 7.],
  1390. [4., 5.]]]
  1391. """
  1392. return flip(m, 1)
  1393. def take_along_axis(arr, indices, axis):
  1394. """
  1395. Takes values from the input array by matching 1d index and data slices.
  1396. This iterates over matching 1d slices oriented along the specified axis in the
  1397. index and data arrays, and uses the former to look up values in the latter.
  1398. These slices can be different lengths.
  1399. Args:
  1400. arr (Tensor): Source array with shape `(Ni…, M, Nk…)`.
  1401. indices (Tensor): Indices with shape `(Ni…, J, Nk…)` to take along each 1d
  1402. slice of `arr`. This must match the dimension of `arr`, but dimensions `Ni`
  1403. and `Nj` only need to broadcast against `arr`.
  1404. axis (int): The axis to take 1d slices along. If `axis` is None, the input
  1405. array is treated as if it had first been flattened to 1d.
  1406. Returns:
  1407. Tensor, the indexed result, with shape `(Ni…, J, Nk…)`.
  1408. Raises:
  1409. ValueError: if input array and indices have different number of dimensions.
  1410. TypeError: if the input is not a Tensor.
  1411. Supported Platforms:
  1412. ``Ascend`` ``GPU`` ``CPU``
  1413. Example:
  1414. >>> x = np.arange(12).reshape(3, 4)
  1415. >>> indices = np.arange(3).reshape(1, 3)
  1416. >>> output = np.take_along_axis(x, indices, 1)
  1417. >>> print(output)
  1418. [[ 0 1 2]
  1419. [ 4 5 6]
  1420. [ 8 9 10]]
  1421. """
  1422. _check_input_tensor(arr, indices)
  1423. if axis is None:
  1424. arr = ravel(arr)
  1425. axis = 0
  1426. ndim = F.rank(arr)
  1427. if ndim != F.rank(indices):
  1428. _raise_value_error('`indices` and `arr` must have the same number of dimensions')
  1429. _check_axis_in_range(axis, ndim)
  1430. axis = axis + ndim if axis < 0 else axis
  1431. shape_arr = F.shape(arr)
  1432. shape_indices = F.shape(indices)
  1433. # broadcasts indices against the shape of arr except at axis
  1434. indices = _broadcast_to(indices, _tuple_getitem(shape_indices, axis, False),
  1435. _tuple_getitem(shape_arr, axis, False), ndim)
  1436. indices = _broadcast_to(indices, _tuple_getitem(shape_arr, axis + 1, False) +
  1437. _tuple_getitem(shape_indices, axis + 1), shape_arr, ndim)
  1438. return F.gather_d(arr, axis, indices)
  1439. def _mod(x, y):
  1440. """Computes x mod y."""
  1441. quotient = F.tensor_floordiv(x, y)
  1442. prod = F.tensor_mul(y, quotient)
  1443. return F.tensor_sub(x, prod)
  1444. def _check_indices(size, indices, mode):
  1445. """Checks whether indices are out of bounds."""
  1446. shape = F.shape(indices)
  1447. dtype = F.dtype(indices)
  1448. lowerbounds = F.fill(dtype, shape, -size)
  1449. upperbounds = F.fill(dtype, shape, size - 1)
  1450. out_of_lowerbounds = F.tensor_lt(indices, lowerbounds)
  1451. out_of_upperbounds = F.tensor_gt(indices, upperbounds)
  1452. if mode == 'raise':
  1453. # For mode raise, index-out-of-bounds checking is performed at backend since
  1454. # evaluation of a boolean scalar Tensor always returns true in graph mode
  1455. # regardless of the truth value contained
  1456. return indices
  1457. if mode == 'wrap':
  1458. return _mod(indices, F.fill(dtype, shape, size))
  1459. zeros = F.fill(dtype, shape, 0)
  1460. clipped = F.select(out_of_lowerbounds, zeros, indices)
  1461. clipped = F.select(out_of_upperbounds, upperbounds, clipped)
  1462. return clipped
  1463. def take(a, indices, axis=None, mode='raise'):
  1464. """
  1465. Takes elements from an array along an axis.
  1466. When axis is not None, this function does the same thing as “fancy” indexing
  1467. (indexing arrays using arrays); however, it can be easier to use if you need
  1468. elements along a given axis. A call such as ``np.take(arr, indices, axis=3)`` is
  1469. equivalent to ``arr[:,:,:,indices,...]``.
  1470. Note:
  1471. Numpy argument out is not supported.
  1472. Args:
  1473. a (Tensor): Source array with shape `(Ni…, M, Nk…)`.
  1474. indices (Tensor): The indices with shape `(Nj...)` of the values to extract.
  1475. axis (int, optional): The axis over which to select values. By default,
  1476. the flattened input array is used.
  1477. mode (‘raise’, ‘wrap’, ‘clip’, optional): Specifies how out-of-bounds
  1478. indices will behave.
  1479. ‘raise’ – raise an error (default);
  1480. ‘wrap’ – wrap around;
  1481. ‘clip’ – clip to the range. ‘clip’ mode means that all indices that are
  1482. too large are replaced by the index that addresses the last element
  1483. along that axis. Note that this disables indexing with negative numbers.
  1484. Returns:
  1485. Tensor, the indexed result.
  1486. Raises:
  1487. ValueError: if axis is out of range.
  1488. TypeError: if the input is not a Tensor.
  1489. Supported Platforms:
  1490. ``Ascend`` ``GPU`` ``CPU``
  1491. Examples:
  1492. >>> a = np.array([4, 3, 5, 7, 6, 8])
  1493. >>> indices = np.array([0, 1, 4])
  1494. >>> output = np.take(a, indices)
  1495. >>> print(output)
  1496. [4 3 6]
  1497. >>> indices = np.array([[0, 1], [2, 3]])
  1498. >>> output = np.take(a, indices)
  1499. >>> print(output)
  1500. [[4 3]
  1501. [5 7]]
  1502. """
  1503. _check_input_tensor(a, indices)
  1504. if axis is None:
  1505. a = ravel(a)
  1506. axis = 0
  1507. ndim = F.rank(a)
  1508. _check_axis_in_range(axis, ndim)
  1509. axis = axis + ndim if axis < 0 else axis
  1510. shape_a = F.shape(a)
  1511. shape_indices = F.shape(indices)
  1512. size_indices = indices.size
  1513. indices = _check_indices(shape_a[axis], indices, mode)
  1514. # reshapes indices to shape (Ni..., Nj..., Nk)
  1515. shape_ni = _tuple_getitem(shape_a, axis, False)
  1516. shape_nk = _tuple_getitem(shape_a, axis + 1)
  1517. shape_out = shape_ni + shape_indices + shape_nk
  1518. shape_indices = _expanded_shape(ndim, size_indices, axis)
  1519. indices = F.reshape(indices, shape_indices)
  1520. shape_indices = shape_ni + (indices.size,) + shape_nk
  1521. indices = _broadcast_to_shape(indices, shape_indices)
  1522. res = F.gather_d(a, axis, indices)
  1523. return F.reshape(res, shape_out)
  1524. def repeat(a, repeats, axis=None):
  1525. """
  1526. Repeats elements of an array.
  1527. Args:
  1528. a (Tensor): Input array.
  1529. repeats (int or sequence of ints): The number of repetitions for each element.
  1530. `repeats` is broadcasted to fit the shape of the given axis.
  1531. axis (int, optional): The axis along which to repeat values. By default,
  1532. use the flattened input array, and return a flat output array.
  1533. Returns:
  1534. Tensor, output array which has the same shape as `a`, except along the given
  1535. axis.
  1536. Raises:
  1537. ValueError: if axis is out of range.
  1538. TypeError: if input `a` is not a Tensor.
  1539. Supported Platforms:
  1540. ``Ascend`` ``GPU`` ``CPU``
  1541. Examples:
  1542. >>> output = np.repeat(np.array(3), 4)
  1543. >>> print(output)
  1544. [3 3 3 3]
  1545. >>> x = np.array([[1,2],[3,4]])
  1546. >>> output = np.repeat(x, 2)
  1547. >>> print(output)
  1548. [1 1 2 2 3 3 4 4]
  1549. >>> output = np.repeat(x, 3, axis=1)
  1550. >>> print(output)
  1551. [[1 1 1 2 2 2]
  1552. [3 3 3 4 4 4]]
  1553. >>> output = np.repeat(x, [1, 2], axis=0)
  1554. >>> print(output)
  1555. [[1 2]
  1556. [3 4]
  1557. [3 4]]
  1558. """
  1559. _check_input_tensor(a)
  1560. if axis is None:
  1561. a = ravel(a)
  1562. axis = 0
  1563. ndim = F.rank(a)
  1564. _check_axis_in_range(axis, ndim)
  1565. axis = axis + ndim if axis < 0 else axis
  1566. if isinstance(repeats, (tuple, list)) and len(repeats) == 1:
  1567. repeats = repeats[0]
  1568. if isinstance(repeats, int):
  1569. if repeats == 0:
  1570. return _empty(F.dtype(a), (0,))
  1571. return C.repeat_elements(a, repeats, axis)
  1572. shape = F.shape(a)
  1573. size = shape[axis]
  1574. if len(repeats) != size:
  1575. _raise_value_error('operands could not be broadcast together')
  1576. subs = split(a, size, axis)
  1577. repeated_subs = []
  1578. for sub, rep in zip(subs, repeats):
  1579. if rep != 0:
  1580. repeated_subs.append(C.repeat_elements(sub, rep, axis))
  1581. return concatenate(repeated_subs, axis)