You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.py 16 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Tensor implementation."""
  16. import numpy as np
  17. from mindspore import log as logger
  18. from .._c_expression import Tensor as Tensor_
  19. from .._c_expression import MetaTensor as MetaTensor_
  20. from .._checkparam import check_type, check_typename
  21. from . import dtype as mstype
  22. from ._register_for_tensor import tensor_operator_registry
  23. __all__ = ['Tensor', 'MetaTensor', 'RowTensor', 'SparseTensor']
  24. np_types = (np.int8, np.int16, np.int32, np.int64,
  25. np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
  26. np.float32, np.float64, np.bool_)
  27. class Tensor(Tensor_):
  28. """
  29. Tensor is used for data storage.
  30. Tensor inherits tensor object in C++.
  31. Some functions are implemented in C++ and some functions are implemented in Python.
  32. Args:
  33. input_data (Tensor, float, int, bool, tuple, list, numpy.ndarray): Input data of the tensor.
  34. dtype (:class:`mindspore.dtype`): Input data should be None, bool or numeric type defined in `mindspore.dtype`.
  35. The argument is used to define the data type of the output tensor. If it is None, the data type of the
  36. output tensor will be as same as the `input_data`. Default: None.
  37. Outputs:
  38. Tensor, with the same shape as `input_data`.
  39. Examples:
  40. >>> # initialize a tensor with input data
  41. >>> t1 = Tensor(np.zeros([1, 2, 3]), mindspore.float32)
  42. >>> assert isinstance(t1, Tensor)
  43. >>> assert t1.shape == (1, 2, 3)
  44. >>> assert t1.dtype == mindspore.float32
  45. >>>
  46. >>> # initialize a tensor with a float scalar
  47. >>> t2 = Tensor(0.1)
  48. >>> assert isinstance(t2, Tensor)
  49. >>> assert t2.dtype == mindspore.float64
  50. """
  51. def __init__(self, input_data, dtype=None):
  52. # If input data is numpy number, convert it to np array
  53. if isinstance(input_data, np_types):
  54. input_data = np.array(input_data)
  55. # If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
  56. check_type('tensor input_data', input_data, (Tensor_, float, int))
  57. if dtype is not None:
  58. check_typename('dtype', dtype, mstype.number_type + (mstype.bool_,))
  59. if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']):
  60. input_data = np.ascontiguousarray(input_data)
  61. if dtype is None:
  62. Tensor_.__init__(self, input_data)
  63. else:
  64. Tensor_.__init__(self, input_data, dtype)
  65. self._virtual_flag = False
  66. def __repr__(self):
  67. return Tensor_.__repr__(self)
  68. def __add__(self, other):
  69. out = tensor_operator_registry.get('__add__')(self, other)
  70. return out
  71. def __eq__(self, other):
  72. if not isinstance(other, (int, float, Tensor)):
  73. return False
  74. # bool type is not supported for `Equal` operator in backend.
  75. if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
  76. if isinstance(other, Tensor):
  77. return Tensor(np.array(self.asnumpy() == other.asnumpy()))
  78. return Tensor(np.array(self.asnumpy() == other))
  79. return tensor_operator_registry.get('__eq__')(self, other)
  80. def __ne__(self, other):
  81. if not isinstance(other, (int, float, Tensor)):
  82. return True
  83. # bool type is not supported for `NotEqual` operator in backend.
  84. if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
  85. return Tensor(np.array(self.asnumpy() != other.asnumpy()))
  86. return tensor_operator_registry.get('__ne__')(self, other)
  87. def __hash__(self):
  88. return hash(id(self))
  89. def __mul__(self, other):
  90. out = tensor_operator_registry.get('__mul__')(self, other)
  91. return out
  92. def __neg__(self):
  93. out = tensor_operator_registry.get('__neg__')(self)
  94. return out
  95. def __bool__(self):
  96. data = self.asnumpy()
  97. if data.shape == ():
  98. return bool(data)
  99. if data.shape == (1,):
  100. return bool(data[0])
  101. raise ValueError("The truth value of an array with several elements is ambiguous.")
  102. def __pos__(self):
  103. return self
  104. def __iadd__(self, other):
  105. return self.__add__(other)
  106. def __radd__(self, other):
  107. out = tensor_operator_registry.get('__add__')(self, other)
  108. return out
  109. def __imul__(self, other):
  110. return self.__mul__(other)
  111. def __rmul__(self, other):
  112. out = tensor_operator_registry.get('__mul__')(self, other)
  113. return out
  114. def __truediv__(self, other):
  115. out = tensor_operator_registry.get('__truediv__')(self, other)
  116. return out
  117. def __rtruediv__(self, other):
  118. out = tensor_operator_registry.get('__truediv__')(other, self)
  119. return out
  120. def __sub__(self, other):
  121. out = tensor_operator_registry.get('__sub__')(self, other)
  122. return out
  123. def __isub__(self, other):
  124. return self.__sub__(other)
  125. def __rsub__(self, other):
  126. out = tensor_operator_registry.get('__sub__')(other, self)
  127. return out
  128. def __lt__(self, other):
  129. out = tensor_operator_registry.get('__lt__')(self, other)
  130. return out
  131. def __le__(self, other):
  132. out = tensor_operator_registry.get('__le__')(self, other)
  133. return out
  134. def __getitem__(self, index):
  135. out = tensor_operator_registry.get('__getitem__')(self, index)
  136. return out
  137. def __setitem__(self, index, value):
  138. out = tensor_operator_registry.get('__setitem__')(self, index, value)
  139. self.assign_value(out)
  140. return self
  141. def __gt__(self, other):
  142. out = tensor_operator_registry.get('__gt__')(self, other)
  143. return out
  144. def __ge__(self, other):
  145. out = tensor_operator_registry.get('__ge__')(self, other)
  146. return out
  147. def __len__(self):
  148. out = tensor_operator_registry.get('shape')(self)
  149. if not out:
  150. return 1
  151. return out[0]
  152. def __mod__(self, other):
  153. return tensor_operator_registry.get('__mod__')(self, other)
  154. def __imod__(self, other):
  155. return self.__mod__(other)
  156. def __rmod__(self, other):
  157. return tensor_operator_registry.get('__mod__')(other, self)
  158. def __pow__(self, other):
  159. return tensor_operator_registry.get('__pow__')(self, other)
  160. def __floordiv__(self, other):
  161. return tensor_operator_registry.get('__floordiv__')(self, other)
  162. def __ifloordiv__(self, other):
  163. return self.__floordiv__(other)
  164. def __rfloordiv__(self, other):
  165. return tensor_operator_registry.get('__floordiv__')(other, self)
  166. def __str__(self):
  167. if self.dtype == mstype.type_none:
  168. return "Unknown Tensor type!"
  169. return str(self.asnumpy())
  170. @property
  171. def shape(self):
  172. """The shape of tensor is a tuple."""
  173. return self._shape
  174. @property
  175. def dtype(self):
  176. """The dtype of tensor is a mindspore type."""
  177. return self._dtype
  178. @property
  179. def virtual_flag(self):
  180. """Mark tensor is virtual."""
  181. return self._virtual_flag
  182. @virtual_flag.setter
  183. def virtual_flag(self, value):
  184. """The setter of virtual_flag."""
  185. if not isinstance(value, bool):
  186. raise TypeError("virtual_flag must be bool.")
  187. self._virtual_flag = value
  188. @staticmethod
  189. def from_numpy(array):
  190. """Convert numpy array to Tensor without copy data."""
  191. return Tensor(Tensor_.from_numpy(array))
  192. def asnumpy(self):
  193. """Convert tensor to numpy array."""
  194. return Tensor_.asnumpy(self)
  195. def all(self, axis=(), keep_dims=False):
  196. """
  197. Check all array elements along a given axis evaluate to True.
  198. Args:
  199. axis (Union[None, int, tuple(int)): Dimensions of reduction,
  200. when axis is None or empty tuple, reduce all dimensions.
  201. Default: (), reduce all dimensions.
  202. keep_dims (bool): Whether to keep the reduced dimensions.
  203. Default : False, don't keep these reduced dimensions.
  204. Returns:
  205. Tensor, has the same data type as x.
  206. """
  207. if axis is None:
  208. axis = ()
  209. return tensor_operator_registry.get('all')(keep_dims)(self, axis)
  210. def any(self, axis=(), keep_dims=False):
  211. """
  212. Check any array element along a given axis evaluate to True.
  213. Args:
  214. axis (Union[None, int, tuple(int)): Dimensions of reduction,
  215. when axis is None or empty tuple, reduce all dimensions.
  216. Default: (), reduce all dimensions.
  217. keep_dims (bool): Whether to keep the reduced dimensions.
  218. Default : False, don't keep these reduced dimensions.
  219. Returns:
  220. Tensor, has the same data type as x.
  221. """
  222. if axis is None:
  223. axis = ()
  224. return tensor_operator_registry.get('any')(keep_dims)(self, axis)
  225. class RowTensor:
  226. """
  227. A sparse representation of a set of tensor slices at given indices.
  228. An RowTensor is typically used to represent a subset of a larger
  229. tensor dense of shape [L0, D1, .. , DN] where L0 >> D0.
  230. The values in indices are the indices in the first dimension of the slices
  231. that have been extracted from the larger tensor.
  232. The dense tensor dense represented by an RowTensor slices has
  233. `dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`.
  234. RowTensor can only be used in the `Cell`'s construct method.
  235. It is not supported in pynative mode at the moment.
  236. Args:
  237. indices (Tensor): A 1-D integer Tensor of shape [D0].
  238. values (Tensor): A Tensor of any dtype of shape [D0, D1, ..., Dn].
  239. dense_shape (tuple): An integer tuple which contains the shape
  240. of the corresponding dense tensor.
  241. Returns:
  242. RowTensor, composed of `indices`, `values`, and `dense_shape`.
  243. Examples:
  244. >>> class Net(nn.Cell):
  245. >>> def __init__(self, dense_shape):
  246. >>> super(Net, self).__init__()
  247. >>> self.dense_shape = dense_shape
  248. >>> def construct(self, indices, values):
  249. >>> x = RowTensor(indices, values, self.dense_shape)
  250. >>> return x.values, x.indices, x.dense_shape
  251. >>>
  252. >>> indices = Tensor([0])
  253. >>> values = Tensor([[1, 2]], dtype=ms.float32)
  254. >>> Net((3, 2))(indices, values)
  255. """
  256. def __init__(self, indices, values, dense_shape):
  257. "Init RowTensor"
  258. self.__indices = indices
  259. self.__values = values
  260. self.__dense_shape = dense_shape
  261. @property
  262. def indices(self):
  263. return self.__indices
  264. @property
  265. def values(self):
  266. return self.__values
  267. @property
  268. def dense_shape(self):
  269. return self.__dense_shape
  270. class SparseTensor:
  271. """
  272. A sparse representation of a set of nonzero elememts from a tensor at given indices.
  273. SparseTensor can only be used in the `Cell`'s construct method.
  274. Pynative mode not supported at the moment.
  275. For a tensor dense, its SparseTensor(indices, values, dense_shape) has
  276. `dense[indices[i]] = values[i]`.
  277. Args:
  278. indices (Tensor): A 2-D integer Tensor of shape `[N, ndims]`,
  279. where N and ndims are the number of values and number of dimensions in
  280. the SparseTensor, respectively.
  281. values (Tensor): A 1-D tensor of any type and shape `[N]`, which
  282. supplies the values for each element in indices.
  283. dense_shape (tuple): A integer tuple of size `ndims`,
  284. which specifies the dense_shape of the sparse tensor.
  285. Returns:
  286. SparseTensor, composed of `indices`, `values`, and `dense_shape`.
  287. Examples:
  288. >>> class Net(nn.Cell):
  289. >>> def __init__(self, dense_shape):
  290. >>> super(Net, self).__init__()
  291. >>> self.dense_shape = dense_shape
  292. >>> def construct(self, indices, values):
  293. >>> x = SparseTensor(indices, values, self.dense_shape)
  294. >>> return x.values, x.indices, x.dense_shape
  295. >>>
  296. >>> indices = Tensor([[0, 1], [1, 2]])
  297. >>> values = Tensor([1, 2], dtype=ms.float32)
  298. >>> Net((3, 4))(indices, values)
  299. """
  300. def __init__(self, indices, values, dense_shape):
  301. "Init SparseTensor"
  302. self.__indices = indices
  303. self.__values = values
  304. self.__dense_shape = dense_shape
  305. @property
  306. def indices(self):
  307. return self.__indices
  308. @property
  309. def values(self):
  310. return self.__values
  311. @property
  312. def dense_shape(self):
  313. return self.__dense_shape
  314. class MetaTensor(MetaTensor_):
  315. """
  316. The base class of the MetaTensor.
  317. Initialization of tensor basic attributes and model weight values.
  318. Returns:
  319. Array, an array after being initialized.
  320. """
  321. def __init__(self, dtype, shape, init=None):
  322. #check param
  323. self.init = init
  324. MetaTensor_.__init__(self, dtype, shape)
  325. def to_tensor(self, slice_index=None, shape=None):
  326. """
  327. Get the tensor format data of this MetaTensor.
  328. Args:
  329. slice_index (int): Slice index of a parameter's slices.
  330. It is used when initialize a slice of a parameter, it guarantees that devices
  331. using the same slice can generate the same tensor.
  332. shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter.
  333. """
  334. if self.init is None:
  335. raise TypeError("to_dense must be set MetaTensor.init, init can't be None")
  336. if shape is None:
  337. shape = self.shape
  338. try:
  339. arr = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype))
  340. except ValueError:
  341. msg = "Error shape={}".format(shape)
  342. logger.error(msg)
  343. raise ValueError(msg)
  344. class seed_context:
  345. '''set and restore seed'''
  346. def __init__(self, init):
  347. self.init = init
  348. from .seed import get_seed
  349. global_seed = get_seed()
  350. self._np_seed = np.random.get_state()[1][0]
  351. self.need_set_seed = ((slice_index is not None) and (global_seed is None))
  352. self.seed = self.init.seed
  353. def __enter__(self):
  354. if self.need_set_seed:
  355. np.random.seed(slice_index)
  356. self.init.seed = slice_index
  357. def __exit__(self, ptype, value, trace):
  358. if self.need_set_seed:
  359. np.random.seed(self._np_seed)
  360. self.init.seed = self.seed
  361. with seed_context(self.init):
  362. self.init(arr)
  363. return Tensor(arr, dtype=self.dtype)
  364. def _vm_compare(*args):
  365. """Implement `vm_compare` for tensor."""
  366. obj_str = args[-1]
  367. if obj_str == "shape":
  368. fn = getattr(args[0].asnumpy(), obj_str)
  369. return fn
  370. if len(args) == 2:
  371. fn = getattr(args[0].asnumpy(), obj_str)
  372. return Tensor(fn())
  373. if isinstance(args[0], Tensor):
  374. fn = getattr(args[0].asnumpy(), obj_str)
  375. y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
  376. else:
  377. obj_str = "__r" + obj_str[2:]
  378. fn = getattr(args[1].asnumpy(), obj_str)
  379. y = args[0]
  380. return Tensor(np.array(fn(y)))
  381. tensor_operator_registry.register('vm_compare', _vm_compare)