|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285 |
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """Operators for TensorArray."""
-
- import mindspore as ms
- from ..._checkparam import Validator as validator
- from ..._checkparam import Rel
- from ...common import dtype as mstype
- from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
-
class TensorArray(PrimitiveWithInfer):
    r"""
    Create a TensorArray and return a unique handle that refers to it.

    Args:
        dtype (mindspore.dtype): The data type of the elements stored in the TensorArray.
            Must be a number type.
        element_shape (List[int]): The shape of each tensor stored in the TensorArray.
        dynamic_size (bool): If True, the TensorArray is allowed to grow. Default: True.
        size (int): The size of the TensorArray when `dynamic_size` is False. Must be >= 0. Default: 0.
        name (str): The name of this TensorArray. Default: "TA".

    Inputs:
        None.

    Outputs:
        - **output** (Tensor[mindspore.int64]) - A scalar handle bound to the created TensorArray.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.ops as ops
        >>> create_op = ops.TensorArray(mindspore.int32, ())
        >>> handle = create_op()
        >>> print(handle)
        0
    """
    @prim_attr_register
    def __init__(self, dtype, element_shape, dynamic_size=True, size=0, name="TA"):
        # Validate before registering anything, so no attributes are added for bad arguments.
        validator.check_type_name("dtype", dtype, mstype.number_type, self.name)
        validator.check_int(size, 0, Rel.GE, "size", self.name)
        # Registered in this exact order. 'side_effect_mem' presumably marks the op as
        # touching memory state, so the compiler keeps it in order (TODO confirm).
        registered_attrs = (
            ('dtype', dtype),
            ('element_shape', element_shape),
            ('dynamic_size', dynamic_size),
            ('size', size),
            ('side_effect_mem', True),
            ('name', name),
        )
        for attr_key, attr_value in registered_attrs:
            self.add_prim_attr(attr_key, attr_value)

    def infer_shape(self):
        # The op takes no inputs and produces a scalar handle.
        return ()

    def infer_dtype(self):
        # The handle is always int64, whatever the element dtype is.
        return mstype.int64
-
class TensorArrayWrite(PrimitiveWithInfer):
    r"""
    Write a tensor into a created TensorArray at the given index.

    Inputs:
        - **handle** (Tensor[mindspore.int64]) - The handle pointing to the TensorArray.
        - **index** (Tensor[int64]) - The position to write.
        - **value** (Tensor) - The value to write into the TensorArray. Must have a number dtype.

    Outputs:
        - **output** (Tensor[mindspore.int64]) - A scalar int64 tensor (inferred shape ``()``).
          The write itself is the purpose of this op; the output is not meant to carry data.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.ops as ops
        >>> create_op = ops.TensorArray(mindspore.int32, ())
        >>> handle = create_op()
        >>> write_op = ops.TensorArrayWrite()
        >>> _ = write_op(handle, 0, 1)
    """
    @prim_attr_register
    def __init__(self):
        # Presumably marks the op as having memory side effects so it is not pruned/reordered (TODO confirm).
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, handle_shape, index_shape, value_shape):
        # The output is a scalar regardless of the input shapes.
        return ()

    def infer_dtype(self, handle_type, index_type, value_type):
        """Check the input dtypes and return the output dtype (int64)."""
        # Valid types are given as tuples; note the trailing comma for the one-element case.
        validator.check_type_name("handle", handle_type, (ms.int64,), self.name)
        validator.check_type_name("index", index_type, (int, ms.int64), self.name)
        validator.check_type_name("value", value_type, mstype.number_type, self.name)
        return mstype.int64
-
class TensorArrayRead(PrimitiveWithInfer):
    r"""
    Read the tensor stored at the given index of a created TensorArray.

    Args:
        dtype (mindspore.dtype): The data type of the elements in the TensorArray. Must be a number type.
        element_shape (List[int]): The shape of each tensor in the TensorArray; used as the output shape.

    Inputs:
        - **handle** (Tensor[mindspore.int64]) - The handle pointing to the TensorArray.
        - **index** (Tensor[int64]) - The position to read.

    Outputs:
        - **output** (Tensor) - The value at position `index`, with dtype `dtype` and shape `element_shape`.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.ops as ops
        >>> create_op = ops.TensorArray(mindspore.int32, ())
        >>> handle = create_op()
        >>> write_op = ops.TensorArrayWrite()
        >>> _ = write_op(handle, 0, 1)
        >>> read_op = ops.TensorArrayRead(mindspore.int32, ())
        >>> ans = read_op(handle, 0)
        >>> print(ans)
        1
    """
    @prim_attr_register
    def __init__(self, dtype, element_shape):
        validator.check_type_name("dtype", dtype, mstype.number_type, self.name)
        self.add_prim_attr('dtype', dtype)
        self.add_prim_attr('element_shape', element_shape)
        # Presumably marks the op as having memory side effects (TODO confirm).
        self.add_prim_attr('side_effect_mem', True)
        # Kept on the instance so that infer_shape/infer_dtype can report them.
        self.dtype = dtype
        self.shape = element_shape

    def infer_shape(self, handle_shape, index_shape):
        # The output shape is fixed by the constructor argument, not by the inputs.
        return self.shape

    def infer_dtype(self, handle_type, index_type):
        """Check the input dtypes and return the element dtype given at construction."""
        # Valid types are given as tuples; note the trailing comma for the one-element case.
        validator.check_type_name("handle", handle_type, (ms.int64,), self.name)
        validator.check_type_name("index", index_type, (int, ms.int64), self.name)
        return self.dtype
-
class TensorArrayClose(PrimitiveWithInfer):
    r"""
    Close a created TensorArray. The resources held by the TensorArray will be deleted.

    Inputs:
        - **handle** (Tensor[mindspore.int64]) - The handle pointing to the TensorArray.

    Outputs:
        - **output** (Tensor[mindspore.int64]) - A scalar int64 tensor (inferred shape ``()``).
          Closing is the purpose of this op; the output is not meant to carry data.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.ops as ops
        >>> create_op = ops.TensorArray(mindspore.int32, ())
        >>> handle = create_op()
        >>> close_op = ops.TensorArrayClose()
        >>> _ = close_op(handle)
    """
    @prim_attr_register
    def __init__(self):
        # Presumably marks the op as having memory side effects (TODO confirm).
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, handle_shape):
        # The output is a scalar.
        return ()

    def infer_dtype(self, handle_type):
        """Check that the handle is int64 and return the output dtype (int64)."""
        # Valid types are given as a tuple; note the trailing comma.
        validator.check_type_name("handle", handle_type, (ms.int64,), self.name)
        return mstype.int64
-
class TensorArrayClear(PrimitiveWithInfer):
    r"""
    Reset a created TensorArray. The TensorArray instance remains available after clearing.

    Inputs:
        - **handle** (Tensor[mindspore.int64]) - The handle pointing to the TensorArray.

    Outputs:
        - **output** (Tensor[mindspore.int64]) - A scalar int64 tensor (inferred shape ``()``).
          Clearing is the purpose of this op; the output is not meant to carry data.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.ops as ops
        >>> create_op = ops.TensorArray(mindspore.int32, ())
        >>> handle = create_op()
        >>> clear_op = ops.TensorArrayClear()
        >>> _ = clear_op(handle)
    """
    @prim_attr_register
    def __init__(self):
        # Presumably marks the op as having memory side effects (TODO confirm).
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, handle_shape):
        # The output is a scalar.
        return ()

    def infer_dtype(self, handle_type):
        """Check that the handle is int64 and return the output dtype (int64)."""
        # Valid types are given as a tuple; note the trailing comma.
        validator.check_type_name("handle", handle_type, (ms.int64,), self.name)
        return mstype.int64
-
class TensorArrayStack(Primitive):
    r"""
    Stack the tensors in a created TensorArray into one tensor.

    Args:
        dtype (mindspore.dtype): The data type of the elements in the TensorArray. Must be a number type.
        element_shape (List[int]): The shape of each tensor in the TensorArray.

    Inputs:
        - **handle** (Tensor[mindspore.int64]) - The handle pointing to the TensorArray.

    Outputs:
        - **output** (Tensor) - The stacked value from the TensorArray.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.ops as ops
        >>> create_op = ops.TensorArray(mindspore.int32, ())
        >>> handle = create_op()
        >>> write_op = ops.TensorArrayWrite()
        >>> _ = write_op(handle, 0, 1)
        >>> _ = write_op(handle, 1, 2)
        >>> stack_op = ops.TensorArrayStack(mindspore.int32, ())
        >>> ans = stack_op(handle)
        >>> print(ans)
        [1 2]
    """
    @prim_attr_register
    def __init__(self, dtype, element_shape):
        """Initialize TensorArrayStack"""
        # Same dtype constraint as TensorArray / TensorArrayRead; a TensorArray can only
        # hold number types, so any dtype accepted here is one a TensorArray can contain.
        validator.check_type_name("dtype", dtype, mstype.number_type, self.name)
        # NOTE(review): the input name is empty although the documented input is `handle`;
        # confirm nothing downstream relies on '' before renaming it.
        self.init_prim_io_names(inputs=[''], outputs=['output'])
        self.add_prim_attr('dtype', dtype)
        self.add_prim_attr('element_shape', element_shape)
        # Presumably flagged because the stacked length is known only at run time (TODO confirm).
        self.add_prim_attr('is_dynamic_shape', True)
        # Presumably marks the op as having memory side effects (TODO confirm).
        self.add_prim_attr('side_effect_mem', True)
-
class TensorArraySize(PrimitiveWithInfer):
    r"""
    Get the logical size of a created TensorArray.

    Inputs:
        - **handle** (Tensor[mindspore.int64]) - The handle pointing to the TensorArray.

    Outputs:
        - **output** (Tensor[mindspore.int64]) - The logical size of the TensorArray, as a scalar.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.ops as ops
        >>> create_op = ops.TensorArray(mindspore.int32, ())
        >>> handle = create_op()
        >>> size_op = ops.TensorArraySize()
        >>> size = size_op(handle)
    """
    @prim_attr_register
    def __init__(self):
        # Presumably marks the op as having memory side effects (TODO confirm).
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, handle_shape):
        # The size is reported as a scalar.
        return ()

    def infer_dtype(self, handle_type):
        """Check that the handle is int64 and return the output dtype (int64)."""
        # Valid types are given as a tuple; note the trailing comma.
        validator.check_type_name("handle", handle_type, (ms.int64,), self.name)
        return mstype.int64
|