You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Tensor implementation."""
  16. import numpy as np
  17. from .._c_expression import Tensor as Tensor_
  18. from .._c_expression import MetaTensor
  19. from .._checkparam import check_type, check_typename
  20. from . import dtype as mstype
  21. from ._register_for_tensor import tensor_operator_registry
  22. __all__ = ['Tensor', 'MetaTensor']
  23. class Tensor(Tensor_):
  24. """
  25. Tensor for data storage.
  26. Tensor inherits tensor object in C++ side, some functions are implemented
  27. in C++ side and some functions are implemented in Python layer.
  28. Args:
  29. input_data (Tensor, float, int, bool, tuple, list, numpy.ndarray): Input data of the tensor.
  30. dtype (:class:`mindspore.dtype`): Should be None, bool or numeric type defined in `mindspore.dtype`.
  31. The argument is used to define the data type of the output tensor. If it is None, the data type of the
  32. output tensor will be as same as the `input_data`. Default: None.
  33. Outputs:
  34. Tensor, with the same shape as `input_data`.
  35. Examples:
  36. >>> # init a tensor with input data
  37. >>> t1 = Tensor(np.zeros([1, 2, 3]), mindspore.float32)
  38. >>> assert isinstance(t1, Tensor)
  39. >>> assert t1.shape() == (1, 2, 3)
  40. >>> assert t1.dtype() == mindspore.float32
  41. >>>
  42. >>> # init a tensor with a float scalar
  43. >>> t2 = Tensor(0.1)
  44. >>> assert isinstance(t2, Tensor)
  45. >>> assert t2.dtype() == mindspore.float64
  46. """
  47. def __init__(self, input_data, dtype=None):
  48. # If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
  49. check_type('tensor input_data', input_data, (Tensor_, float, int))
  50. if dtype is not None:
  51. check_typename('dtype', dtype, mstype.number_type + (mstype.bool_,))
  52. if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']):
  53. input_data = np.ascontiguousarray(input_data)
  54. if dtype is None:
  55. super(Tensor, self).__init__(input_data)
  56. else:
  57. super(Tensor, self).__init__(input_data, dtype)
  58. self._virtual_flag = False
  59. def __repr__(self):
  60. return str(self.__str__())
  61. def __add__(self, other):
  62. check_type('tensor input_data', other, (Tensor, float, int))
  63. out = tensor_operator_registry.get('__add__')(self, other)
  64. return out
  65. def __eq__(self, other):
  66. if not isinstance(other, Tensor):
  67. return False
  68. x = self.asnumpy()
  69. y = other.asnumpy()
  70. out = np.equal(x, y)
  71. return Tensor(np.array(out))
  72. def __hash__(self):
  73. return hash(id(self))
  74. def __mul__(self, other):
  75. check_type('tensor input_data', other, (Tensor, float, int))
  76. out = tensor_operator_registry.get('__mul__')(self, other)
  77. return out
  78. def __neg__(self):
  79. return Tensor(-self.asnumpy())
  80. def __iadd__(self, other):
  81. out = self.__add__(other)
  82. return out
  83. def __radd__(self, other):
  84. check_type('tensor operation input', other, (Tensor, float, int))
  85. out = tensor_operator_registry.get('__add__')(other, self)
  86. return out
  87. def __imul__(self, other):
  88. out = self.__mul__(other)
  89. return out
  90. def __rmul__(self, other):
  91. check_type('tensor operation input', other, (Tensor, float, int))
  92. out = tensor_operator_registry.get('__mul__')(other, self)
  93. return out
  94. def __truediv__(self, other):
  95. check_type('tensor operation input', other, (Tensor, float, int))
  96. out = tensor_operator_registry.get('__div__')(self, other)
  97. return out
  98. def __rtruediv__(self, other):
  99. check_type('tensor operation input', other, (Tensor, float, int))
  100. out = tensor_operator_registry.get('__div__')(other, self)
  101. return out
  102. def __sub__(self, other):
  103. check_type('tensor operation input', other, (Tensor, float, int))
  104. out = self.__add__(-other)
  105. return out
  106. def __isub__(self, other):
  107. out = self.__sub__(other)
  108. return out
  109. def __rsub__(self, other):
  110. check_type('tensor operation input', other, (Tensor, float, int))
  111. out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy()))
  112. return out
  113. def __str__(self):
  114. if self.dtype() == mstype.type_none:
  115. return "Unknown Tensor type!"
  116. return str(self.asnumpy())
  117. @property
  118. def virtual_flag(self):
  119. """Mark tensor is virtual."""
  120. return self._virtual_flag
  121. @virtual_flag.setter
  122. def virtual_flag(self, value):
  123. """The setter of virtual_flag."""
  124. if not isinstance(value, bool):
  125. raise TypeError("virtual_flag must be bool.")
  126. self._virtual_flag = value