Browse Source

!11845 check wrapper layer timedistributed input type

From: @dinglongwei
Reviewed-by: @c_34,@liangchenghui
Signed-off-by: @c_34
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
35d0634291
1 changed files with 6 additions and 2 deletions
  1. +6
    -2
      mindspore/nn/layer/timedistributed.py

+ 6
- 2
mindspore/nn/layer/timedistributed.py View File

@@ -17,6 +17,7 @@
from mindspore.ops.primitive import constexpr, Primitive from mindspore.ops.primitive import constexpr, Primitive
from mindspore.ops import Reshape, Transpose, Pack, Unpack from mindspore.ops import Reshape, Transpose, Pack, Unpack
from mindspore.common import Tensor from mindspore.common import Tensor
from mindspore._checkparam import Validator
from ..cell import Cell from ..cell import Cell


__all__ = ['TimeDistributed'] __all__ = ['TimeDistributed']
@@ -69,13 +70,13 @@ class TimeDistributed(Cell):
Args: Args:
layer(Union[Cell, Primitive]): The Cell or Primitive which will be wrapped. layer(Union[Cell, Primitive]): The Cell or Primitive which will be wrapped.
time_axis(int): The axis of time_step. time_axis(int): The axis of time_step.
reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: 'None'.
reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: None.


Inputs: Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, T, *)`. - **input** (Tensor) - Tensor of shape :math:`(N, T, *)`.


Outputs: Outputs:
Tensor of shape: math:'(N, T, *)'
Tensor of shape :math:`(N, T, *)`


Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
@@ -97,6 +98,9 @@ class TimeDistributed(Cell):
raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or " raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or "
"mindspore.ops.Primitive instance. You passed: {input}".format(input=layer)) "mindspore.ops.Primitive instance. You passed: {input}".format(input=layer))
super(TimeDistributed, self).__init__() super(TimeDistributed, self).__init__()
Validator.check_is_int(time_axis)
if reshape_with_axis is not None:
Validator.check_is_int(reshape_with_axis)
self.layer = layer self.layer = layer
self.time_axis = time_axis self.time_axis = time_axis
self.reshape_with_axis = reshape_with_axis self.reshape_with_axis = reshape_with_axis


Loading…
Cancel
Save