You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tensor.py 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Tensor implementation."""
  16. import numpy as np
  17. from .._c_expression import Tensor as Tensor_
  18. from .._c_expression import MetaTensor
  19. from .._checkparam import check_type, check_typename
  20. from . import dtype as mstype
  21. from ._register_for_tensor import tensor_operator_registry
  22. __all__ = ['Tensor', 'MetaTensor', 'IndexedSlices']
  23. np_types = (np.int8, np.int16, np.int32, np.int64,
  24. np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
  25. np.float32, np.float64, np.bool_)
  26. class Tensor(Tensor_):
  27. """
  28. Tensor for data storage.
  29. Tensor inherits tensor object in C++ side, some functions are implemented
  30. in C++ side and some functions are implemented in Python layer.
  31. Args:
  32. input_data (Tensor, float, int, bool, tuple, list, numpy.ndarray): Input data of the tensor.
  33. dtype (:class:`mindspore.dtype`): Should be None, bool or numeric type defined in `mindspore.dtype`.
  34. The argument is used to define the data type of the output tensor. If it is None, the data type of the
  35. output tensor will be as same as the `input_data`. Default: None.
  36. Outputs:
  37. Tensor, with the same shape as `input_data`.
  38. Examples:
  39. >>> # init a tensor with input data
  40. >>> t1 = Tensor(np.zeros([1, 2, 3]), mindspore.float32)
  41. >>> assert isinstance(t1, Tensor)
  42. >>> assert t1.shape == (1, 2, 3)
  43. >>> assert t1.dtype == mindspore.float32
  44. >>>
  45. >>> # init a tensor with a float scalar
  46. >>> t2 = Tensor(0.1)
  47. >>> assert isinstance(t2, Tensor)
  48. >>> assert t2.dtype == mindspore.float64
  49. """
  50. def __init__(self, input_data, dtype=None):
  51. # If input data is numpy number, convert it to np array
  52. if isinstance(input_data, np_types):
  53. input_data = np.array(input_data)
  54. # If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
  55. check_type('tensor input_data', input_data, (Tensor_, float, int))
  56. if dtype is not None:
  57. check_typename('dtype', dtype, mstype.number_type + (mstype.bool_,))
  58. if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']):
  59. input_data = np.ascontiguousarray(input_data)
  60. if dtype is None:
  61. Tensor_.__init__(self, input_data)
  62. else:
  63. Tensor_.__init__(self, input_data, dtype)
  64. self._virtual_flag = False
  65. self._init_flag = False
  66. def __repr__(self):
  67. return str(self.__str__())
  68. def __add__(self, other):
  69. out = tensor_operator_registry.get('__add__')(self, other)
  70. return out
  71. def __eq__(self, other):
  72. if not isinstance(other, (int, float, Tensor)):
  73. return False
  74. # bool type is not supported for `Equal` operator in backend.
  75. if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
  76. return Tensor(np.array(self.asnumpy() == other.asnumpy()))
  77. return tensor_operator_registry.get('__eq__')(self, other)
  78. def __ne__(self, other):
  79. if not isinstance(other, (int, float, Tensor)):
  80. return True
  81. # bool type is not supported for `NotEqual` operator in backend.
  82. if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
  83. return Tensor(np.array(self.asnumpy() != other.asnumpy()))
  84. return tensor_operator_registry.get('__ne__')(self, other)
  85. def __hash__(self):
  86. return hash(id(self))
  87. def __mul__(self, other):
  88. out = tensor_operator_registry.get('__mul__')(self, other)
  89. return out
  90. def __neg__(self):
  91. out = tensor_operator_registry.get('__neg__')(self)
  92. return out
  93. def __pos__(self):
  94. return self
  95. def __iadd__(self, other):
  96. return self.__add__(other)
  97. def __radd__(self, other):
  98. out = tensor_operator_registry.get('__add__')(self, other)
  99. return out
  100. def __imul__(self, other):
  101. return self.__mul__(other)
  102. def __rmul__(self, other):
  103. out = tensor_operator_registry.get('__mul__')(self, other)
  104. return out
  105. def __truediv__(self, other):
  106. out = tensor_operator_registry.get('__truediv__')(self, other)
  107. return out
  108. def __rtruediv__(self, other):
  109. out = tensor_operator_registry.get('__truediv__')(other, self)
  110. return out
  111. def __sub__(self, other):
  112. out = tensor_operator_registry.get('__sub__')(self, other)
  113. return out
  114. def __isub__(self, other):
  115. return self.__sub__(other)
  116. def __rsub__(self, other):
  117. out = tensor_operator_registry.get('__sub__')(other, self)
  118. return out
  119. def __lt__(self, other):
  120. out = tensor_operator_registry.get('__lt__')(self, other)
  121. return out
  122. def __le__(self, other):
  123. out = tensor_operator_registry.get('__le__')(self, other)
  124. return out
  125. def __getitem__(self, index):
  126. out = tensor_operator_registry.get('__getitem__')(self, index)
  127. return out
  128. def __setitem__(self, index, value):
  129. out = tensor_operator_registry.get('__setitem__')(self, index, value)
  130. self.assign_value(out)
  131. return self
  132. def __gt__(self, other):
  133. out = tensor_operator_registry.get('__gt__')(self, other)
  134. return out
  135. def __ge__(self, other):
  136. out = tensor_operator_registry.get('__ge__')(self, other)
  137. return out
  138. def __len__(self):
  139. out = tensor_operator_registry.get('shape')(self)
  140. if not out:
  141. return 1
  142. return out[0]
  143. def __mod__(self, other):
  144. return tensor_operator_registry.get('__mod__')(self, other)
  145. def __imod__(self, other):
  146. return self.__mod__(other)
  147. def __floordiv__(self, other):
  148. return tensor_operator_registry.get('__floordiv__')(self, other)
  149. def __ifloordiv__(self, other):
  150. return self.__floordiv__(other)
  151. def __str__(self):
  152. if self.dtype == mstype.type_none:
  153. return "Unknown Tensor type!"
  154. return str(self.asnumpy())
  155. @property
  156. def virtual_flag(self):
  157. """Mark tensor is virtual."""
  158. return self._virtual_flag
  159. @virtual_flag.setter
  160. def virtual_flag(self, value):
  161. """The setter of virtual_flag."""
  162. if not isinstance(value, bool):
  163. raise TypeError("virtual_flag must be bool.")
  164. self._virtual_flag = value
  165. @property
  166. def init_flag(self):
  167. """whether the tensor is init."""
  168. return self._init_flag
  169. @init_flag.setter
  170. def init_flag(self, value):
  171. """Set the tensor is init_flag."""
  172. if not isinstance(value, bool):
  173. raise TypeError("init_flag must be bool.")
  174. self.set_init_flag(value)
  175. self._init_flag = value
  176. class IndexedSlices:
  177. def __init__(self, indices, values, dense_shape):
  178. raise NotImplementedError