| @@ -17,6 +17,7 @@ | |||||
| from mindspore.ops.primitive import constexpr, Primitive | from mindspore.ops.primitive import constexpr, Primitive | ||||
| from mindspore.ops import Reshape, Transpose, Pack, Unpack | from mindspore.ops import Reshape, Transpose, Pack, Unpack | ||||
| from mindspore.common import Tensor | from mindspore.common import Tensor | ||||
| from mindspore._checkparam import Validator | |||||
| from ..cell import Cell | from ..cell import Cell | ||||
| __all__ = ['TimeDistributed'] | __all__ = ['TimeDistributed'] | ||||
| @@ -69,13 +70,13 @@ class TimeDistributed(Cell): | |||||
| Args: | Args: | ||||
| layer(Union[Cell, Primitive]): The Cell or Primitive which will be wrapped. | layer(Union[Cell, Primitive]): The Cell or Primitive which will be wrapped. | ||||
| time_axis(int): The axis of time_step. | time_axis(int): The axis of time_step. | ||||
| reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: 'None'. | |||||
| reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: None. | |||||
| Inputs: | Inputs: | ||||
| - **input** (Tensor) - Tensor of shape :math:`(N, T, *)`. | - **input** (Tensor) - Tensor of shape :math:`(N, T, *)`. | ||||
| Outputs: | Outputs: | ||||
| Tensor of shape: math:'(N, T, *)' | |||||
| Tensor of shape :math:`(N, T, *)` | |||||
| Supported Platforms: | Supported Platforms: | ||||
| ``Ascend`` ``GPU`` ``CPU`` | ``Ascend`` ``GPU`` ``CPU`` | ||||
| @@ -97,6 +98,9 @@ class TimeDistributed(Cell): | |||||
| raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or " | raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or " | ||||
| "mindspore.ops.Primitive instance. You passed: {input}".format(input=layer)) | "mindspore.ops.Primitive instance. You passed: {input}".format(input=layer)) | ||||
| super(TimeDistributed, self).__init__() | super(TimeDistributed, self).__init__() | ||||
| Validator.check_is_int(time_axis) | |||||
| if reshape_with_axis is not None: | |||||
| Validator.check_is_int(reshape_with_axis) | |||||
| self.layer = layer | self.layer = layer | ||||
| self.time_axis = time_axis | self.time_axis = time_axis | ||||
| self.reshape_with_axis = reshape_with_axis | self.reshape_with_axis = reshape_with_axis | ||||