You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_array.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. TensorArray
  17. """
  18. from mindspore.nn.cell import Cell
  19. from mindspore.ops.operations import _tensor_array as ta
  20. from mindspore._checkparam import Rel, Validator
  21. from mindspore.common import dtype as mstype
  22. class TensorArray(Cell):
  23. r"""TensorArray: a dynamic array to store tensors.
  24. .. warning::
  25. This is an experiential prototype that is subject to change and/or deletion.
  26. Args:
  27. dtype (mindspore.dtype): the data type in the TensorArray.
  28. element_shape (List[int]): the shape of each tensor in a TensorArray.
  29. dynamic_size (bool): if true, the size of TensorArray can be increased. Default: True.
  30. size (int): if dynamic_size=False, `size` means the max_size of the TensorArray.
  31. name (string): the name of this TensorArray. Default: "TA".
  32. Supported Platforms:
  33. ``GPU``
  34. Examples:
  35. >>> import mindspore
  36. >>> import mindspore.nn as nn
  37. >>> ta = nn.TensorArray(mindspore.int64, ())
  38. >>> ta.write(0, 1)
  39. >>> ta.write(1, 2)
  40. >>> ans = ta.read(1)
  41. >>> print(ans)
  42. 2
  43. >>> s = ta.stack()
  44. >>> print(s)
  45. [1 2]
  46. >>> ta.clear()
  47. >>> ta.write(0, 3)
  48. >>> ans = ta.read(0)
  49. >>> print(ans)
  50. 3
  51. >>> ta.close()
  52. """
  53. def __init__(self, dtype, element_shape, dynamic_size=True, size=0, name="TA"):
  54. """Initialize TensorArray"""
  55. super(TensorArray, self).__init__()
  56. Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
  57. Validator.check_int(size, 0, Rel.GE, "size", self.cls_name)
  58. self.handle_ = ta.TensorArray(dtype, element_shape, dynamic_size, size, name)()
  59. self.tensor_array_write = ta.TensorArrayWrite()
  60. self.tensor_array_read = ta.TensorArrayRead(dtype, element_shape)
  61. self.tensor_array_close = ta.TensorArrayClose()
  62. self.tensor_array_clear = ta.TensorArrayClear()
  63. self.tensor_array_stack = ta.TensorArrayStack(dtype, element_shape)
  64. self.tensor_array_size = ta.TensorArraySize()
  65. def write(self, index, value):
  66. """
  67. Write value(Tensor) to TensorArray in position index.
  68. Args:
  69. index ([int, mindspore.int64]): The position to write.
  70. value (Tensor): The value to add into the TensorArray.
  71. Returns:
  72. Bool, true.
  73. """
  74. self.tensor_array_write(self.handle_, index, value)
  75. return True
  76. def read(self, index):
  77. """
  78. Read tensor form the TensorArray by the given position index.
  79. Args:
  80. index ([int, mindspore.int64]): The given index to get the tensor.
  81. Returns:
  82. Tensor, the value in position index.
  83. """
  84. value = self.tensor_array_read(self.handle_, index)
  85. return value
  86. def close(self):
  87. """
  88. Close the created TensorArray.
  89. .. warning::
  90. Once close the TensorArray, every functions belong to this TensorArray will be disaviliable.
  91. Every resources created in TensorArray will be removed. If this TensorArray will be used in next step
  92. or somewhere, eg: next loop, please use `clear` instead.
  93. Returns:
  94. Bool, true.
  95. """
  96. self.tensor_array_close(self.handle_)
  97. return True
  98. def clear(self):
  99. """
  100. Clear the created TensorArray. Only reset the TensorArray, clear the data and reset the size
  101. in TensorArray and keep the instance of this TensorArray.
  102. Returns:
  103. Bool, true.
  104. """
  105. self.tensor_array_clear(self.handle_)
  106. return True
  107. def stack(self):
  108. """
  109. Stack the values in TensorArray into a stacked Tensor.
  110. Returns:
  111. Tensor, all the values will be stacked into one tensor.
  112. """
  113. ans = self.tensor_array_stack(self.handle_)
  114. return ans
  115. def size(self):
  116. """
  117. The logical size of TensorArray.
  118. Returns:
  119. Tensor, the size of TensorArray.
  120. """
  121. size = self.tensor_array_size(self.handle_)
  122. return size