
tensor.py
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Tensor implementation."""
  16. import numbers
  17. import numpy as np
  18. from mindspore import log as logger
  19. from mindspore.communication.management import get_rank, get_group_size
  20. from . import dtype as mstype
  21. from ._register_for_tensor import tensor_operator_registry
  22. from .._c_expression import Tensor as Tensor_
  23. from .._checkparam import Validator as validator
  24. __all__ = ['Tensor', 'RowTensor', 'SparseTensor']
  25. np_types = (np.int8, np.int16, np.int32, np.int64,
  26. np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
  27. np.float32, np.float64, np.bool_)
class Tensor(Tensor_):
    """
    Tensor is used for data storage.

    Tensor inherits the tensor object in C++.
    Some functions are implemented in C++ and some functions are implemented in Python.

    Args:
        input_data (Tensor, float, int, bool, tuple, list, numpy.ndarray): Input data of the tensor.
        dtype (:class:`mindspore.dtype`): Should be None, bool or a numeric type defined in `mindspore.dtype`.
            The argument is used to define the data type of the output tensor. If it is None, the data type of
            the output tensor will be the same as the `input_data`. Default: None.
        shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
            the output. Default: None.
        init (Initializer): The information of init data.
            'init' is used for delayed initialization in parallel mode. Usually, it is not recommended to
            use 'init' to initialize parameters in other conditions. If 'init' is used to initialize
            parameters, the `init_data` API needs to be called to convert the `Tensor` to the actual data.

    Outputs:
        Tensor, with the same shape as `input_data`.

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import Tensor
        >>> # initialize a tensor with input data
        >>> t1 = Tensor(np.zeros([1, 2, 3]), ms.float32)
        >>> assert isinstance(t1, Tensor)
        >>> assert t1.shape == (1, 2, 3)
        >>> assert t1.dtype == ms.float32
        ...
        >>> # initialize a tensor with a float scalar
        >>> t2 = Tensor(0.1)
        >>> assert isinstance(t2, Tensor)
        >>> assert t2.dtype == ms.float64
    """
    def __init__(self, input_data=None, dtype=None, shape=None, init=None):
        # If input data is a numpy number, convert it to a numpy array
        if isinstance(input_data, np_types):
            input_data = np.array(input_data)

        if input_data is not None and shape is not None and input_data.shape != shape:
            raise ValueError("input_data.shape and shape should be the same.")

        if init is not None and (shape is None or dtype is None):
            raise ValueError("init, dtype and shape must have values at the same time.")

        if ((input_data is not None and init is None) or (input_data is None and init is not None)) is False:
            raise TypeError("input_data and init can not be None at the same time.")

        if isinstance(shape, numbers.Number):
            shape = (shape,)

        # If input_data is tuple/list/numpy.ndarray, it is supported by the check_value_type method.
        if init is None:
            validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
                                       'Tensor')
            valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
                            np.float16, np.float32, np.float64, np.bool_)
            if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes:
                raise TypeError(f"For Tensor, the input_data is a numpy array, "
                                f"but its data type is not in the supported list: "
                                f"{list(i.__name__ for i in valid_dtypes)}.")
            if isinstance(input_data, (tuple, list)):
                if np.array(input_data).dtype not in valid_dtypes:
                    raise TypeError(f"For Tensor, the input_data is {input_data} that contains an unsupported element.")
            if dtype is not None:
                validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_,), "Tensor")

            if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']):
                input_data = np.ascontiguousarray(input_data)
            if dtype is None:
                Tensor_.__init__(self, input_data)
            else:
                Tensor_.__init__(self, input_data, dtype)
        else:
            Tensor_.__init__(self, dtype, shape)
        self._virtual_flag = False
        self.init = init
    def __deepcopy__(self, memodict):
        new_obj = Tensor(self)
        new_obj.init = self.init
        new_obj._virtual_flag = self._virtual_flag  # pylint:disable=W0212
        return new_obj

    def __repr__(self):
        Tensor_.data_sync(self, False)
        return Tensor_.__repr__(self)

    def __add__(self, other):
        out = tensor_operator_registry.get('__add__')(self, other)
        return out

    def __eq__(self, other):
        if not isinstance(other, (int, float, Tensor)):
            return False
        # bool type is not supported for `Equal` operator in backend.
        if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
            if isinstance(other, Tensor):
                return Tensor(np.array(self.asnumpy() == other.asnumpy()))
            return Tensor(np.array(self.asnumpy() == other))
        return tensor_operator_registry.get('__eq__')(self, other)
    def __ne__(self, other):
        if not isinstance(other, (int, float, Tensor)):
            return True
        # bool type is not supported for `NotEqual` operator in backend.
        if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
            if isinstance(other, Tensor):
                return Tensor(np.array(self.asnumpy() != other.asnumpy()))
            return Tensor(np.array(self.asnumpy() != other))
        return tensor_operator_registry.get('__ne__')(self, other)
    def __hash__(self):
        return hash(id(self))

    def __mul__(self, other):
        out = tensor_operator_registry.get('__mul__')(self, other)
        return out

    def __neg__(self):
        out = tensor_operator_registry.get('__neg__')(self)
        return out

    def __bool__(self):
        data = self.asnumpy()
        if data.shape == ():
            return bool(data)
        if data.shape == (1,):
            return bool(data[0])
        raise ValueError("The truth value of an array with several elements is ambiguous.")

    def __index__(self):
        data = self.asnumpy()
        if not (data.dtype == "int8"
                or data.dtype == "int16"
                or data.dtype == "int32"
                or data.dtype == "int64"
                or data.dtype == "bool"):
            raise ValueError("Only integer tensors of a single element can be converted to an index.")
        if data.shape == ():
            return int(data)
        if data.shape == (1,):
            return int(data[0])
        raise ValueError("Only integer tensors of a single element can be converted to an index.")

    def __pos__(self):
        return self

    def __iadd__(self, other):
        return self.__add__(other)

    def __radd__(self, other):
        out = tensor_operator_registry.get('__add__')(self, other)
        return out

    def __imul__(self, other):
        return self.__mul__(other)

    def __rmul__(self, other):
        out = tensor_operator_registry.get('__mul__')(self, other)
        return out

    def __truediv__(self, other):
        out = tensor_operator_registry.get('__truediv__')(self, other)
        return out

    def __rtruediv__(self, other):
        out = tensor_operator_registry.get('__truediv__')(other, self)
        return out

    def __sub__(self, other):
        out = tensor_operator_registry.get('__sub__')(self, other)
        return out

    def __isub__(self, other):
        return self.__sub__(other)

    def __rsub__(self, other):
        out = tensor_operator_registry.get('__sub__')(other, self)
        return out

    def __lt__(self, other):
        out = tensor_operator_registry.get('__lt__')(self, other)
        return out

    def __le__(self, other):
        out = tensor_operator_registry.get('__le__')(self, other)
        return out

    def __getitem__(self, index):
        if isinstance(index, int) and not isinstance(index, bool) and self.shape and index >= self.shape[0]:
            raise IndexError("index {} is out of bounds for axis 0 with size {}".format(index, self.shape[0]))
        out = tensor_operator_registry.get('__getitem__')(self, index)
        return out

    def __setitem__(self, index, value):
        out = tensor_operator_registry.get('__setitem__')(self, index, value)
        self.assign_value(out)
        return self

    def __gt__(self, other):
        out = tensor_operator_registry.get('__gt__')(self, other)
        return out

    def __ge__(self, other):
        out = tensor_operator_registry.get('__ge__')(self, other)
        return out
    def __len__(self):
        out = tensor_operator_registry.get('shape')(self)
        if out:
            return out[0]
        raise TypeError("len() is not supported for a 0-D tensor.")
    def __mod__(self, other):
        return tensor_operator_registry.get('__mod__')(self, other)

    def __imod__(self, other):
        return self.__mod__(other)

    def __rmod__(self, other):
        return tensor_operator_registry.get('__mod__')(other, self)

    def __pow__(self, other):
        return tensor_operator_registry.get('__pow__')(self, other)

    def __floordiv__(self, other):
        return tensor_operator_registry.get('__floordiv__')(self, other)

    def __ifloordiv__(self, other):
        return self.__floordiv__(other)

    def __rfloordiv__(self, other):
        return tensor_operator_registry.get('__floordiv__')(other, self)

    def __str__(self):
        if self.dtype == mstype.type_none:
            return "Unknown Tensor type!"
        return str(self.asnumpy())
    @property
    def shape(self):
        """Returns the shape of the tensor as a tuple."""
        return self._shape

    @property
    def dtype(self):
        """Returns the dtype of the tensor (:class:`mindspore.dtype`)."""
        return self._dtype

    @property
    def size(self):
        """Returns the total number of elements in the tensor."""
        return self._size

    @property
    def ndim(self):
        """Returns the number of tensor dimensions."""
        return len(self._shape)

    @property
    def has_init(self):
        """Whether the tensor was created with delayed initialization (`init`)."""
        return self.init is not None

    @property
    def itemsize(self):
        """Returns the length of one tensor element in bytes."""
        return self._itemsize

    @property
    def strides(self):
        """Returns the tuple of bytes to step in each dimension when traversing a tensor."""
        return self._strides

    @property
    def nbytes(self):
        """Returns the total number of bytes taken by the tensor."""
        return self._nbytes

    @property
    def T(self):
        """Returns the transposed tensor."""
        return self.transpose()

    @property
    def virtual_flag(self):
        """Whether the tensor is marked as virtual."""
        return self._virtual_flag

    @virtual_flag.setter
    def virtual_flag(self, value):
        """The setter of virtual_flag."""
        if not isinstance(value, bool):
            raise TypeError("virtual_flag must be bool.")
        self._virtual_flag = value
    @staticmethod
    def from_numpy(array):
        """Convert a numpy array to a Tensor without copying its data."""
        return Tensor(Tensor_.from_numpy(array))

    def asnumpy(self):
        """Convert the tensor to a numpy array."""
        self.init_check()
        return Tensor_.asnumpy(self)

    def _flush_from_cache(self):
        """Flush cache data to host if the tensor is cache enabled."""
        self.init_check()
        Tensor_._flush_from_cache(self)
    def all(self, axis=(), keep_dims=False):
        """
        Check whether all array elements along a given axis evaluate to True.

        Args:
            axis (Union[None, int, tuple(int)]): Dimensions of reduction.
                When axis is None or an empty tuple, reduce all dimensions.
                Default: (), reduce all dimensions.
            keep_dims (bool): Whether to keep the reduced dimensions.
                Default: False, don't keep these reduced dimensions.

        Returns:
            Tensor, has the same data type as the original tensor.
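
        Examples:
            >>> # Illustrative sketch: reduce a boolean tensor over all dimensions;
            >>> # numpy is assumed only for constructing the input data.
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> x = Tensor(np.array([True, True, False]))
            >>> assert not x.all()
            >>> assert x.all().shape == ()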
  291. """
  292. self.init_check()
  293. if axis is None:
  294. axis = ()
  295. return tensor_operator_registry.get('all')(keep_dims)(self, axis)
    def any(self, axis=(), keep_dims=False):
        """
        Check whether any array element along a given axis evaluates to True.

        Args:
            axis (Union[None, int, tuple(int)]): Dimensions of reduction.
                When axis is None or an empty tuple, reduce all dimensions.
                Default: (), reduce all dimensions.
            keep_dims (bool): Whether to keep the reduced dimensions.
                Default: False, don't keep these reduced dimensions.

        Returns:
            Tensor, has the same data type as the original tensor.
        """
        self.init_check()
        if axis is None:
            axis = ()
        return tensor_operator_registry.get('any')(keep_dims)(self, axis)
    def view(self, *shape):
        r"""
        Reshape the tensor according to the input shape.

        Args:
            shape (Union[tuple(int), \*int]): Dimension of the output tensor.

        Returns:
            Tensor, has the same dimension as the input shape.
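
        Examples:
            >>> # Illustrative sketch: view a 1-D tensor of six elements as a 2 x 3 tensor;
            >>> # numpy is assumed only for constructing the input data.
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> x = Tensor(np.arange(6).astype(np.float32))
            >>> y = x.view(2, 3)
            >>> assert y.shape == (2, 3)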
  319. """
  320. self.init_check()
  321. if not shape:
  322. raise ValueError("The shape variable should not be empty")
  323. if isinstance(shape[0], tuple):
  324. if len(shape) != 1:
  325. raise ValueError(f"Only one tuple is needed, but got {shape}")
  326. shape = shape[0]
  327. return tensor_operator_registry.get('reshape')()(self, shape)
    def expand_as(self, x):
        """
        Expand the dimension of the target tensor to the dimension of the input tensor.

        Args:
            x (Tensor): The input tensor. The shape of the input tensor must obey
                the broadcasting rule.

        Returns:
            Tensor, has the same dimension as the input tensor.
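
        Examples:
            >>> # Illustrative sketch: broadcast a (3,) tensor to the shape of a (2, 3) tensor;
            >>> # numpy is assumed only for constructing the input data.
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> x = Tensor(np.array([1, 2, 3], dtype=np.float32))
            >>> y = Tensor(np.ones((2, 3), dtype=np.float32))
            >>> assert x.expand_as(y).shape == (2, 3)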
  336. """
  337. self.init_check()
  338. return tensor_operator_registry.get('broadcast_to')(x.shape)(self)
    def abs(self):
        """
        Return the absolute value element-wise.

        Returns:
            Tensor, has the same data type as the original tensor.
        """
        self.init_check()
        return tensor_operator_registry.get('abs')()(self)
    def mean(self, axis=(), keep_dims=False):
        """
        Reduce a dimension of a tensor by averaging all elements in the dimension.

        Args:
            axis (Union[None, int, tuple(int), list(int)]): Dimensions of reduction.
                When axis is None or an empty tuple, reduce all dimensions.
                Default: (), reduce all dimensions.
            keep_dims (bool): Whether to keep the reduced dimensions.
                Default: False, don't keep these reduced dimensions.

        Returns:
            Tensor, has the same data type as the original tensor.
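
        Examples:
            >>> # Illustrative sketch: average over all elements and over axis 0;
            >>> # numpy is assumed only for constructing the input data.
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))
            >>> print(x.mean())
            3.5
            >>> assert x.mean(axis=0).shape == (3,)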
  358. """
  359. self.init_check()
  360. if axis is None:
  361. axis = ()
  362. return tensor_operator_registry.get('mean')(keep_dims)(self, axis)
    def transpose(self, *axes):
        r"""
        Return a view of the tensor with axes transposed.

        For a 1-D tensor this has no effect, as a transposed vector is simply the
        same vector. For a 2-D tensor, this is a standard matrix transpose. For an
        n-D tensor, if axes are given, their order indicates how the axes are permuted.
        If axes are not provided and tensor.shape = (i[0], i[1], ... i[n-2], i[n-1]),
        then tensor.transpose().shape = (i[n-1], i[n-2], ... i[1], i[0]).

        Args:
            axes (Union[None, tuple(int), list(int), \*int], optional): If axes is None or
                blank, tensor.transpose() will reverse the order of the axes. If axes is tuple(int)
                or list(int), tensor.transpose() will transpose the tensor to the new axes order.
                If axes is \*int, this form is simply intended as a convenience alternative to the
                tuple/list form.

        Returns:
            Tensor, has the same dimension as the input tensor, with axes suitably permuted.
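
        Examples:
            >>> # Illustrative sketch: transpose a (2, 3) tensor to (3, 2);
            >>> # numpy is assumed only for constructing the input data.
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> x = Tensor(np.ones((2, 3), dtype=np.float32))
            >>> assert x.transpose().shape == (3, 2)
            >>> assert x.T.shape == (3, 2)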
  379. """
  380. self.init_check()
  381. perm = validator.check_transpose_axis(axes, self.ndim)
  382. return tensor_operator_registry.get('transpose')()(self, perm)
    def reshape(self, *shape):
        """
        Give a new shape to a tensor without changing its data.

        Args:
            shape (Union[int, tuple(int), list(int)]): The new shape should be compatible
                with the original shape. If an integer, then the result will be a 1-D
                tensor of that length. One shape dimension can be -1. In this case, the
                value is inferred from the length of the tensor and the remaining dimensions.

        Returns:
            Tensor, with the new specified shape.
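
        Examples:
            >>> # Illustrative sketch: reshape a (2, 3) tensor, and use -1 to infer a dimension;
            >>> # numpy is assumed only for constructing the input data.
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> x = Tensor(np.ones((2, 3), dtype=np.float32))
            >>> assert x.reshape(3, 2).shape == (3, 2)
            >>> assert x.reshape(-1).shape == (6,)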
  393. """
  394. self.init_check()
  395. new_shape = validator.check_reshape_shp(shape)
  396. return tensor_operator_registry.get('reshape')()(self, new_shape)
    def ravel(self):
        """
        Return a contiguous flattened tensor.

        Returns:
            Tensor, a 1-D tensor, containing the same elements as the input.
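
        Examples:
            >>> # Illustrative sketch: flatten a (2, 3) tensor into a (6,) tensor;
            >>> # numpy is assumed only for constructing the input data.
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> x = Tensor(np.ones((2, 3), dtype=np.float32))
            >>> assert x.ravel().shape == (6,)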
  402. """
  403. self.init_check()
  404. reshape_op = tensor_operator_registry.get('reshape')()
  405. return reshape_op(self, (-1,))
    def flatten(self, order='C'):
        """
        Return a copy of the tensor collapsed into one dimension.

        Args:
            order (str, optional): Can choose between 'C' and 'F'. 'C' means to
                flatten in row-major (C-style) order. 'F' means to flatten in column-major
                (Fortran-style) order. Only 'C' and 'F' are supported.

        Returns:
            Tensor, has the same data type as the input.
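
        Examples:
            >>> # Illustrative sketch: row-major versus column-major flattening;
            >>> # numpy is assumed only for constructing the input data.
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> x = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
            >>> print(x.flatten())
            [1. 2. 3. 4.]
            >>> print(x.flatten(order='F'))
            [1. 3. 2. 4.]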
  415. """
  416. self.init_check()
  417. reshape_op = tensor_operator_registry.get('reshape')()
  418. trans_op = tensor_operator_registry.get('transpose')()
  419. order = validator.check_flatten_order(order)
  420. if order == 'C':
  421. return reshape_op(self, (-1,))
  422. perm = tuple(range(self.ndim-1, -1, -1))
  423. return reshape_op(trans_op(self, perm), (-1,))
    def swapaxes(self, axis1, axis2):
        """
        Interchange two axes of a tensor.

        Args:
            axis1 (int): First axis.
            axis2 (int): Second axis.

        Returns:
            Transposed tensor, has the same data type as the input.
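
        Examples:
            >>> # Illustrative sketch: swap the first and last axes of a (2, 3, 4) tensor;
            >>> # numpy is assumed only for constructing the input data.
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> x = Tensor(np.ones((2, 3, 4), dtype=np.float32))
            >>> assert x.swapaxes(0, 2).shape == (4, 3, 2)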
  432. """
  433. self.init_check()
  434. axis1, axis2 = validator.check_swapaxes_axis((axis1, axis2), self.ndim)
  435. if axis1 == axis2:
  436. return self
  437. if axis1 > axis2:
  438. axis1, axis2 = axis2, axis1
  439. perm = tuple(range(0, self.ndim))
  440. new_perm = None
  441. if axis2 + 1 < self.ndim:
  442. new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
  443. perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:]
  444. else:
  445. new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
  446. perm[axis1+1:axis2] + perm[axis1:axis1+1]
  447. return tensor_operator_registry.get('transpose')()(self, new_perm)
    def squeeze(self, axis=None):
        """
        Remove single-dimensional entries from the shape of a tensor.

        Args:
            axis (Union[None, int, list(int), tuple(int)], optional): Axes to squeeze. Default: None.

        Returns:
            Tensor, with all or a subset of the dimensions of length 1 removed.
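
        Examples:
            >>> # Illustrative sketch: drop the size-1 dimensions of a (1, 2, 1) tensor;
            >>> # numpy is assumed only for constructing the input data.
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> x = Tensor(np.ones((1, 2, 1), dtype=np.float32))
            >>> assert x.squeeze().shape == (2,)
            >>> assert x.squeeze(axis=0).shape == (2, 1)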
  455. """
  456. self.init_check()
  457. if axis is None:
  458. return tensor_operator_registry.get('squeeze')(self)
  459. new_shape = validator.prepare_shape_for_squeeze(self.shape, axis)
  460. return tensor_operator_registry.get('reshape')()(self, new_shape)
    def astype(self, dtype, copy=True):
        """
        Return a copy of the tensor, cast to a specified type.

        Args:
            dtype (Union[:class:`mindspore.dtype`, str]): Designated tensor dtype, can be in format
                of :class:`mindspore.dtype.float32` or 'float32'.
            copy (bool, optional): By default, astype always returns a newly allocated
                tensor. If this is set to False, the input tensor is returned instead
                of a copy if possible. Default: True.

        Returns:
            Tensor, with the designated dtype.
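
        Examples:
            >>> # Illustrative sketch: cast a float32 tensor to int32, by dtype object or by name;
            >>> # numpy is assumed only for constructing the input data.
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import Tensor
            >>> x = Tensor(np.ones((2, 2), dtype=np.float32))
            >>> assert x.astype(ms.int32).dtype == ms.int32
            >>> assert x.astype('int32').dtype == ms.int32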
  472. """
  473. self.init_check()
  474. dtype = validator.check_astype_dtype(dtype)
  475. if not copy and dtype == self.dtype:
  476. return self
  477. return tensor_operator_registry.get('cast')(self, dtype)
    def init_check(self):
        if self.has_init:
            self.init_data()
        return self
    def init_data(self, slice_index=None, shape=None, opt_shard_group=None):
        """
        Get the tensor format data of this Tensor.

        The init_data function can be called once for the same tensor.

        Args:
            slice_index (int): Slice index of a parameter's slices.
                It is used when initializing a slice of a parameter; it guarantees that devices
                using the same slice can generate the same tensor.
            shape (list[int]): Shape of the slice; it is used when initializing a slice of the parameter.
            opt_shard_group (str): Optimizer shard group which is used in auto or semi-auto parallel mode
                to get one shard of a parameter's slice.
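
        Examples:
            >>> # Illustrative sketch of delayed initialization; the One initializer from
            >>> # mindspore.common.initializer is assumed to be available here.
            >>> import mindspore as ms
            >>> from mindspore import Tensor
            >>> from mindspore.common.initializer import One
            >>> t = Tensor(shape=(2, 2), dtype=ms.float32, init=One())
            >>> data = t.init_data()
            >>> assert data.shape == (2, 2)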
  493. """
  494. if self.init is None:
  495. raise TypeError("init_data must be set Tensor.init, init can't be None")
  496. if shape is None:
  497. shape = self.shape
  498. try:
  499. arr = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype))
  500. except ValueError:
  501. msg = "Error shape={}".format(shape)
  502. logger.error(msg)
  503. raise ValueError(msg)
  504. class seed_context:
  505. '''set and restore seed'''
  506. def __init__(self, init):
  507. self.init = init
  508. from .seed import get_seed
  509. global_seed = get_seed()
  510. self._np_seed = np.random.get_state()[1][0]
  511. self.need_set_seed = ((slice_index is not None) and (global_seed is None))
  512. def __enter__(self):
  513. if self.need_set_seed:
  514. self.seed = self.init.seed
  515. np.random.seed(slice_index)
  516. self.init.seed = slice_index
  517. def __exit__(self, ptype, value, trace):
  518. if self.need_set_seed:
  519. np.random.seed(self._np_seed)
  520. self.init.seed, _ = self.seed
  521. with seed_context(self.init):
  522. self.init(arr)
  523. data = np.array(arr)
  524. if opt_shard_group:
  525. rank = get_rank(opt_shard_group)
  526. size = get_group_size(opt_shard_group)
  527. data = np.split(data, size)[rank]
  528. self.init = None
  529. self.assign_value(Tensor(data, dtype=self.dtype))
  530. return self
    def to_tensor(self, slice_index=None, shape=None, opt_shard_group=None):
        """Return init_data(). Deprecated; please use init_data instead."""
        logger.warning("WARN_DEPRECATED: The usage of to_tensor is deprecated."
                       " Please use init_data")
        return self.init_data(slice_index, shape, opt_shard_group)
class RowTensor:
    """
    A sparse representation of a set of tensor slices at given indices.

    A RowTensor is typically used to represent a subset of a larger
    dense tensor of shape [L0, D1, .., DN] where L0 >> D0.

    The values in `indices` are the indices in the first dimension of the slices
    that have been extracted from the larger tensor.

    The dense tensor represented by a RowTensor `slices` has
    `dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`.

    RowTensor can only be used in the `Cell`'s construct method.
    It is not supported in pynative mode at the moment.

    Args:
        indices (Tensor): A 1-D integer Tensor of shape [D0].
        values (Tensor): A Tensor of any dtype of shape [D0, D1, ..., Dn].
        dense_shape (tuple): An integer tuple which contains the shape
            of the corresponding dense tensor.

    Returns:
        RowTensor, composed of `indices`, `values`, and `dense_shape`.

    Examples:
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, RowTensor
        >>> class Net(nn.Cell):
        ...     def __init__(self, dense_shape):
        ...         super(Net, self).__init__()
        ...         self.dense_shape = dense_shape
        ...     def construct(self, indices, values):
        ...         x = RowTensor(indices, values, self.dense_shape)
        ...         return x.values, x.indices, x.dense_shape
        >>>
        >>> indices = Tensor([0])
        >>> values = Tensor([[1, 2]], dtype=ms.float32)
        >>> out = Net((3, 2))(indices, values)
        >>> print(out[0])
        [[1. 2.]]
        >>> print(out[1])
        [0]
        >>> print(out[2])
        (3, 2)
    """

    def __init__(self, indices, values, dense_shape):
        """Init RowTensor."""
        self.__indices = indices
        self.__values = values
        self.__dense_shape = dense_shape

    @property
    def indices(self):
        return self.__indices

    @property
    def values(self):
        return self.__values

    @property
    def dense_shape(self):
        return self.__dense_shape
class SparseTensor:
    """
    A sparse representation of a set of nonzero elements from a tensor at given indices.

    SparseTensor can only be used in the `Cell`'s construct method.
    It is not supported in pynative mode at the moment.

    For a tensor `dense`, its SparseTensor(indices, values, dense_shape) has
    `dense[indices[i]] = values[i]`.

    Args:
        indices (Tensor): A 2-D integer Tensor of shape `[N, ndims]`,
            where N and ndims are the number of values and number of dimensions in
            the SparseTensor, respectively.
        values (Tensor): A 1-D tensor of any type and shape `[N]`, which
            supplies the values for each element in indices.
        dense_shape (tuple): An integer tuple of size `ndims`,
            which specifies the dense_shape of the sparse tensor.

    Returns:
        SparseTensor, composed of `indices`, `values`, and `dense_shape`.

    Examples:
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, SparseTensor
        >>> class Net(nn.Cell):
        ...     def __init__(self, dense_shape):
        ...         super(Net, self).__init__()
        ...         self.dense_shape = dense_shape
        ...     def construct(self, indices, values):
        ...         x = SparseTensor(indices, values, self.dense_shape)
        ...         return x.values, x.indices, x.dense_shape
        >>>
        >>> indices = Tensor([[0, 1], [1, 2]])
        >>> values = Tensor([1, 2], dtype=ms.float32)
        >>> out = Net((3, 4))(indices, values)
        >>> print(out[0])
        [1. 2.]
        >>> print(out[1])
        [[0 1]
         [1 2]]
        >>> print(out[2])
        (3, 4)
    """

    def __init__(self, indices, values, dense_shape):
        """Init SparseTensor."""
        self.__indices = indices
        self.__values = values
        self.__dense_shape = dense_shape

    @property
    def indices(self):
        return self.__indices

    @property
    def values(self):
        return self.__values

    @property
    def dense_shape(self):
        return self.__dense_shape
def _vm_compare(*args):
    """Implement `vm_compare` for tensor."""
    obj_str = args[-1]
    if obj_str == "shape":
        fn = getattr(args[0].asnumpy(), obj_str)
        return fn
    if len(args) == 2:
        fn = getattr(args[0].asnumpy(), obj_str)
        return Tensor(fn())
    if isinstance(args[0], Tensor):
        fn = getattr(args[0].asnumpy(), obj_str)
        y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
    else:
        obj_str = "__r" + obj_str[2:]
        fn = getattr(args[1].asnumpy(), obj_str)
        y = args[0]
    return Tensor(np.array(fn(y)))


tensor_operator_registry.register('vm_compare', _vm_compare)