Browse Source

add timedistributed input type check

tags/v1.2.0-rc1
dinglongwei 4 years ago
parent
commit
465eee4865
1 changed files with 6 additions and 2 deletions
  1. +6
    -2
      mindspore/nn/layer/timedistributed.py

+ 6
- 2
mindspore/nn/layer/timedistributed.py View File

@@ -17,6 +17,7 @@
from mindspore.ops.primitive import constexpr, Primitive from mindspore.ops.primitive import constexpr, Primitive
from mindspore.ops import Reshape, Transpose, Pack, Unpack from mindspore.ops import Reshape, Transpose, Pack, Unpack
from mindspore.common import Tensor from mindspore.common import Tensor
from mindspore._checkparam import Validator
from ..cell import Cell from ..cell import Cell


__all__ = ['TimeDistributed'] __all__ = ['TimeDistributed']
@@ -69,13 +70,13 @@ class TimeDistributed(Cell):
Args: Args:
layer(Union[Cell, Primitive]): The Cell or Primitive which will be wrapped. layer(Union[Cell, Primitive]): The Cell or Primitive which will be wrapped.
time_axis(int): The axis of time_step. time_axis(int): The axis of time_step.
reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: 'None'.
reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: None.


Inputs: Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, T, *)`. - **input** (Tensor) - Tensor of shape :math:`(N, T, *)`.


Outputs: Outputs:
Tensor of shape: math:'(N, T, *)'
Tensor of shape :math:`(N, T, *)`


Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
@@ -97,6 +98,9 @@ class TimeDistributed(Cell):
raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or " raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or "
"mindspore.ops.Primitive instance. You passed: {input}".format(input=layer)) "mindspore.ops.Primitive instance. You passed: {input}".format(input=layer))
super(TimeDistributed, self).__init__() super(TimeDistributed, self).__init__()
Validator.check_is_int(time_axis)
if reshape_with_axis is not None:
Validator.check_is_int(reshape_with_axis)
self.layer = layer self.layer = layer
self.time_axis = time_axis self.time_axis = time_axis
self.reshape_with_axis = reshape_with_axis self.reshape_with_axis = reshape_with_axis


Loading…
Cancel
Save