You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_tensor_array.py 9.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Operators for TensorArray."""
  16. import mindspore as ms
  17. from ..._checkparam import Validator as validator
  18. from ..._checkparam import Rel
  19. from ...common import dtype as mstype
  20. from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
  21. class TensorArray(PrimitiveWithInfer):
  22. r"""
  23. TensorArrayCreate used to create a TensorArray and return an unique handle.
  24. Args:
  25. dtype (mindspore.dtype): the data type in the TensorArray.
  26. element_shape (List[int]): the shape of each tensor in a TensorArray.
  27. dynamic_size (bool): If true the TensorArray can increase the size. Default: True.
  28. size (int): The size of the TensorArray if dynamic_size = False.
  29. name (string): the name of this TensorArray. Default: "TA".
  30. Inputs:
  31. None.
  32. Outputs:
  33. - **output** (Tensor[mindspore.int64]) - an unique handle binded to the TensorArray.
  34. Supported Platforms:
  35. ``GPU``
  36. Examples:
  37. >>> import mindspore
  38. >>> import mindspore.ops as ops
  39. >>> create_op = ops.TensorArray(mindspore.int32, ())
  40. >>> handle = create_op()
  41. >>> print(handle)
  42. 0
  43. """
  44. @prim_attr_register
  45. def __init__(self, dtype, element_shape, dynamic_size=True, size=0, name="TA"):
  46. validator.check_type_name("dtype", dtype, mstype.number_type, self.name)
  47. validator.check_int(size, 0, Rel.GE, "size", self.name)
  48. self.add_prim_attr('dtype', dtype)
  49. self.add_prim_attr('element_shape', element_shape)
  50. self.add_prim_attr('dynamic_size', dynamic_size)
  51. self.add_prim_attr('size', size)
  52. self.add_prim_attr('side_effect_mem', True)
  53. self.add_prim_attr('name', name)
  54. def infer_shape(self):
  55. return ()
  56. def infer_dtype(self):
  57. return mstype.int64
  58. class TensorArrayWrite(PrimitiveWithInfer):
  59. r"""
  60. TensorArrayWrite used to write tensor into a created TensorArray.
  61. Inputs:
  62. - **index** (Tensor[int64]) - The position to write.
  63. - **value** (Tensor) - The value to add into the TensorArray.
  64. - **handle** (Tensor[int64]) - The handle pointed to the TensorArray.
  65. Outputs:
  66. None.
  67. Supported Platforms:
  68. ``GPU``
  69. Examples:
  70. >>> import mindspore
  71. >>> import mindspore.ops as ops
  72. >>> create_op = ops.TensorArray(mindspore.int32, ())
  73. >>> handle = create_op()
  74. >>> write_op = ops.TensorArrayWrite()
  75. >>> write_op.write(handle, 0, 1)
  76. """
  77. @prim_attr_register
  78. def __init__(self):
  79. self.add_prim_attr('side_effect_mem', True)
  80. def infer_shape(self, handle_shape, index_shape, value_shape):
  81. return ()
  82. def infer_dtype(self, handle_type, index_type, value_type):
  83. validator.check_type_name("handle", handle_type, (ms.int64), self.name)
  84. validator.check_type_name("index", index_type, (int, ms.int64), self.name)
  85. validator.check_type_name("value", value_type, mstype.number_type, self.name)
  86. return mstype.int64
  87. class TensorArrayRead(PrimitiveWithInfer):
  88. r"""
  89. TensorArrayRead used to read tensor from a created TensorArray by the given index.
  90. Args:
  91. dtype (mindspore.dtype): the data type in the TensorArray.
  92. element_shape (List[int]): the shape of each tensor in a TensorArray.
  93. Inputs:
  94. - **index** (Tensor[int64]) - The position to read.
  95. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
  96. Outputs:
  97. - **output** (Tensor) - the value in position index.
  98. Supported Platforms:
  99. ``GPU``
  100. Examples:
  101. >>> import mindspore
  102. >>> import mindspore.ops as ops
  103. >>> create_op = ops.TensorArray(mindspore.int32, ())
  104. >>> handle = create_op()
  105. >>> write_op = ops.TensorArrayWrite()
  106. >>> write_op.write(handle, 0, 1)
  107. >>> read_op = ops.TensorArrayRead(mindspore.int32, ())
  108. >>> ans = read_op(handle, 0)
  109. >>> print(ans)
  110. 1
  111. """
  112. @prim_attr_register
  113. def __init__(self, dtype, element_shape):
  114. validator.check_type_name("dtype", dtype, mstype.number_type, self.name)
  115. self.add_prim_attr('dtype', dtype)
  116. self.add_prim_attr('element_shape', element_shape)
  117. self.add_prim_attr('side_effect_mem', True)
  118. self.dtype = dtype
  119. self.shape = element_shape
  120. def infer_shape(self, handle_shape, index_shape):
  121. return self.shape
  122. def infer_dtype(self, handle_type, index_type):
  123. validator.check_type_name("handle", handle_type, (ms.int64), self.name)
  124. validator.check_type_name("index", index_type, (int, ms.int64), self.name)
  125. return self.dtype
  126. class TensorArrayClose(PrimitiveWithInfer):
  127. r"""
  128. TensorArrayClose used to close the created TensorArray. The resources in TensorArray will be deleted.
  129. Inputs:
  130. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
  131. Outputs:
  132. None.
  133. Supported Platforms:
  134. ``GPU``
  135. Examples:
  136. >>> import mindspore
  137. >>> import mindspore.ops as ops
  138. >>> create_op = ops.TensorArray(mindspore.int32, ())
  139. >>> handle = create_op()
  140. >>> close_op = ops.TensorArrayClose()
  141. >>> close_op(handle)
  142. """
  143. @prim_attr_register
  144. def __init__(self):
  145. self.add_prim_attr('side_effect_mem', True)
  146. def infer_shape(self, handle_shape):
  147. return ()
  148. def infer_dtype(self, handle_type):
  149. validator.check_type_name("handle", handle_type, (ms.int64), self.name)
  150. return mstype.int64
  151. class TensorArrayClear(PrimitiveWithInfer):
  152. r"""
  153. TensorArrayClear used to reset the created TensorArray. The instance of TensorArray is still aviliable.
  154. Inputs:
  155. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
  156. Outputs:
  157. None.
  158. Supported Platforms:
  159. ``GPU``
  160. Examples:
  161. >>> import mindspore
  162. >>> import mindspore.ops as ops
  163. >>> create_op = ops.TensorArray(mindspore.int32, ())
  164. >>> handle = create_op()
  165. >>> clear_op = ops.TensorArrayClear()
  166. >>> clear_op(handle)
  167. """
  168. @prim_attr_register
  169. def __init__(self):
  170. self.add_prim_attr('side_effect_mem', True)
  171. def infer_shape(self, handle_shape):
  172. return ()
  173. def infer_dtype(self, handle_type):
  174. validator.check_type_name("handle", handle_type, (ms.int64), self.name)
  175. return mstype.int64
  176. class TensorArrayStack(Primitive):
  177. r"""
  178. TensorArrayStack used to stack the tensors in a created TensorArray into one tensor.
  179. Args:
  180. dtype (mindspore.dtype): the data type in the TensorArray.
  181. element_shape (List[int]): the shape of each tensor in a TensorArray.
  182. Inputs:
  183. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
  184. Outputs:
  185. - **output** (Tensor) - the stacked value from the TensorArray.
  186. Supported Platforms:
  187. ``GPU``
  188. Examples:
  189. >>> import mindspore
  190. >>> import mindspore.ops as ops
  191. >>> create_op = ops.TensorArray(mindspore.int32, ())
  192. >>> handle = create_op()
  193. >>> write_op = ops.TensorArrayWrite()
  194. >>> write_op.write(handle, 0, 1)
  195. >>> write_op.write(handle, 1, 2)
  196. >>> stack_op = ops.TensorArrayStack(mindspore.int32, ())
  197. >>> ans = stack_op(handle)
  198. >>> print(ans)
  199. [1 2]
  200. """
  201. @prim_attr_register
  202. def __init__(self, dtype, element_shape):
  203. """Initialize TensorArrayStack"""
  204. self.init_prim_io_names(inputs=[''], outputs=['output'])
  205. self.add_prim_attr('dtype', dtype)
  206. self.add_prim_attr('element_shape', element_shape)
  207. self.add_prim_attr('is_dynamic_shape', True)
  208. self.add_prim_attr('side_effect_mem', True)
  209. class TensorArraySize(PrimitiveWithInfer):
  210. r"""
  211. TensorArraySize used to get the logical size of the created TensorArray.
  212. Inputs:
  213. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
  214. Outputs:
  215. - **output** (Tensor[mindspore.int64]) - the logical size of the TensorArray.
  216. Supported Platforms:
  217. ``GPU``
  218. Examples:
  219. >>> import mindspore
  220. >>> import mindspore.ops as ops
  221. >>> create_op = ops.TensorArray(mindspore.int32, ())
  222. >>> handle = create_op()
  223. >>> size_op = ops.TensorArraySize()
  224. >>> size = size_op(handle)
  225. """
  226. @prim_attr_register
  227. def __init__(self):
  228. self.add_prim_attr('side_effect_mem', True)
  229. def infer_shape(self, handle_shape):
  230. return ()
  231. def infer_dtype(self, handle_type):
  232. validator.check_type_name("handle", handle_type, (ms.int64), self.name)
  233. return mstype.int64