Browse Source

add timedistributed input type check

tags/v1.2.0-rc1
dinglongwei 4 years ago
parent
commit
465eee4865
1 changed files with 6 additions and 2 deletions
  1. +6
    -2
      mindspore/nn/layer/timedistributed.py

+ 6
- 2
mindspore/nn/layer/timedistributed.py View File

@@ -17,6 +17,7 @@
from mindspore.ops.primitive import constexpr, Primitive
from mindspore.ops import Reshape, Transpose, Pack, Unpack
from mindspore.common import Tensor
from mindspore._checkparam import Validator
from ..cell import Cell

__all__ = ['TimeDistributed']
@@ -69,13 +70,13 @@ class TimeDistributed(Cell):
Args:
layer(Union[Cell, Primitive]): The Cell or Primitive which will be wrapped.
time_axis(int): The axis of time_step.
reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: 'None'.
reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: None.

Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, T, *)`.

Outputs:
Tensor of shape: math:'(N, T, *)'
Tensor of shape :math:`(N, T, *)`

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@@ -97,6 +98,9 @@ class TimeDistributed(Cell):
raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or "
"mindspore.ops.Primitive instance. You passed: {input}".format(input=layer))
super(TimeDistributed, self).__init__()
Validator.check_is_int(time_axis)
if reshape_with_axis is not None:
Validator.check_is_int(reshape_with_axis)
self.layer = layer
self.time_axis = time_axis
self.reshape_with_axis = reshape_with_axis


Loading…
Cancel
Save