|
|
|
@@ -17,6 +17,7 @@ |
|
|
|
from mindspore.ops.primitive import constexpr, Primitive |
|
|
|
from mindspore.ops import Reshape, Transpose, Pack, Unpack |
|
|
|
from mindspore.common import Tensor |
|
|
|
from mindspore._checkparam import Validator |
|
|
|
from ..cell import Cell |
|
|
|
|
|
|
|
__all__ = ['TimeDistributed'] |
|
|
|
@@ -69,13 +70,13 @@ class TimeDistributed(Cell): |
|
|
|
Args: |
|
|
|
layer(Union[Cell, Primitive]): The Cell or Primitive which will be wrapped. |
|
|
|
time_axis(int): The axis of time_step. |
|
|
|
reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: 'None'. |
|
|
|
reshape_with_axis(int): The axis which time_axis will be reshaped with. Default: None. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input** (Tensor) - Tensor of shape :math:`(N, T, *)`. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor of shape: math:'(N, T, *)' |
|
|
|
Tensor of shape :math:`(N, T, *)` |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` ``GPU`` ``CPU`` |
|
|
|
@@ -97,6 +98,9 @@ class TimeDistributed(Cell): |
|
|
|
raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or " |
|
|
|
"mindspore.ops.Primitive instance. You passed: {input}".format(input=layer)) |
|
|
|
super(TimeDistributed, self).__init__() |
|
|
|
Validator.check_is_int(time_axis) |
|
|
|
if reshape_with_axis is not None: |
|
|
|
Validator.check_is_int(reshape_with_axis) |
|
|
|
self.layer = layer |
|
|
|
self.time_axis = time_axis |
|
|
|
self.reshape_with_axis = reshape_with_axis |
|
|
|
|