|
|
|
@@ -17,6 +17,7 @@ |
|
|
|
from mindspore.ops.primitive import constexpr, Primitive |
|
|
|
from mindspore.ops import Reshape, Transpose, Pack, Unpack |
|
|
|
from mindspore.common.dtype import tensor |
|
|
|
from mindspore.common import Tensor |
|
|
|
from ..cell import Cell |
|
|
|
|
|
|
|
__all__ = ['TimeDistributed'] |
|
|
|
@@ -104,7 +105,9 @@ class TimeDistributed(Cell): |
|
|
|
self.reshape = Reshape() |
|
|
|
|
|
|
|
def construct(self, inputs): |
|
|
|
_check_data(isinstance(inputs, tensor)) |
|
|
|
is_capital_tensor = isinstance(inputs, Tensor) |
|
|
|
is_tensor = True if is_capital_tensor else isinstance(inputs, tensor) |
|
|
|
_check_data(is_tensor) |
|
|
|
_check_inputs_dim(inputs.shape) |
|
|
|
time_axis = self.time_axis % len(inputs.shape) |
|
|
|
if self.reshape_with_axis is not None: |
|
|
|
@@ -119,7 +122,9 @@ class TimeDistributed(Cell): |
|
|
|
inputs_shape_new = inputs.shape |
|
|
|
inputs = self.reshape(inputs, inputs_shape_new[: reshape_pos] + (-1,) + inputs_shape_new[reshape_pos + 2:]) |
|
|
|
outputs = self.layer(inputs) |
|
|
|
_check_data(isinstance(outputs, tensor)) |
|
|
|
is_capital_tensor = isinstance(outputs, Tensor) |
|
|
|
is_tensor = True if is_capital_tensor else isinstance(outputs, tensor) |
|
|
|
_check_data(is_tensor) |
|
|
|
_check_reshape_pos(reshape_pos, inputs.shape, outputs.shape) |
|
|
|
outputs_shape_new = outputs.shape[:reshape_pos] + inputs_shape_new[reshape_pos: reshape_pos + 2] |
|
|
|
if reshape_pos + 1 < len(outputs.shape): |
|
|
|
@@ -131,7 +136,9 @@ class TimeDistributed(Cell): |
|
|
|
y = () |
|
|
|
for item in inputs: |
|
|
|
outputs = self.layer(item) |
|
|
|
_check_data(isinstance(outputs, tensor)) |
|
|
|
is_capital_tensor = isinstance(outputs, Tensor) |
|
|
|
is_tensor = True if is_capital_tensor else isinstance(outputs, tensor) |
|
|
|
_check_data(is_tensor) |
|
|
|
_check_expand_dims_axis(time_axis, outputs.ndim) |
|
|
|
y += (outputs,) |
|
|
|
y = Pack(time_axis)(y) |
|
|
|
|