diff --git a/mindspore/nn/layer/timedistributed.py b/mindspore/nn/layer/timedistributed.py index c62a3d11fc..b4846c78aa 100644 --- a/mindspore/nn/layer/timedistributed.py +++ b/mindspore/nn/layer/timedistributed.py @@ -17,6 +17,7 @@ from mindspore.ops.primitive import constexpr, Primitive from mindspore.ops import Reshape, Transpose, Pack, Unpack from mindspore.common import Tensor +from mindspore._checkparam import Validator from ..cell import Cell __all__ = ['TimeDistributed'] @@ -69,13 +70,13 @@ class TimeDistributed(Cell): Args: layer(Union[Cell, Primitive]): The Cell or Primitive which will be wrapped. time_axis(int): The axis of time_step. - reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: 'None'. + reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: None. Inputs: - **input** (Tensor) - Tensor of shape :math:`(N, T, *)`. Outputs: - Tensor of shape: math:'(N, T, *)' + Tensor of shape :math:`(N, T, *)` Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -97,6 +98,9 @@ class TimeDistributed(Cell): raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or " "mindspore.ops.Primitive instance. You passed: {input}".format(input=layer)) super(TimeDistributed, self).__init__() + Validator.check_is_int(time_axis) + if reshape_with_axis is not None: + Validator.check_is_int(reshape_with_axis) self.layer = layer self.time_axis = time_axis self.reshape_with_axis = reshape_with_axis