diff --git a/mindspore/nn/layer/timedistributed.py b/mindspore/nn/layer/timedistributed.py
index e7ee93b05b..0b1541ef85 100644
--- a/mindspore/nn/layer/timedistributed.py
+++ b/mindspore/nn/layer/timedistributed.py
@@ -17,6 +17,7 @@
 from mindspore.ops.primitive import constexpr, Primitive
 from mindspore.ops import Reshape, Transpose, Pack, Unpack
 from mindspore.common.dtype import tensor
+from mindspore.common import Tensor
 from ..cell import Cell
 
 __all__ = ['TimeDistributed']
@@ -104,7 +105,9 @@ class TimeDistributed(Cell):
         self.reshape = Reshape()
 
     def construct(self, inputs):
-        _check_data(isinstance(inputs, tensor))
+        is_capital_tensor = isinstance(inputs, Tensor)
+        is_tensor = True if is_capital_tensor else isinstance(inputs, tensor)
+        _check_data(is_tensor)
         _check_inputs_dim(inputs.shape)
         time_axis = self.time_axis % len(inputs.shape)
         if self.reshape_with_axis is not None:
@@ -119,7 +122,9 @@ class TimeDistributed(Cell):
             inputs_shape_new = inputs.shape
             inputs = self.reshape(inputs, inputs_shape_new[: reshape_pos] + (-1,) + inputs_shape_new[reshape_pos + 2:])
             outputs = self.layer(inputs)
-            _check_data(isinstance(outputs, tensor))
+            is_capital_tensor = isinstance(outputs, Tensor)
+            is_tensor = True if is_capital_tensor else isinstance(outputs, tensor)
+            _check_data(is_tensor)
             _check_reshape_pos(reshape_pos, inputs.shape, outputs.shape)
             outputs_shape_new = outputs.shape[:reshape_pos] + inputs_shape_new[reshape_pos: reshape_pos + 2]
             if reshape_pos + 1 < len(outputs.shape):
@@ -131,7 +136,9 @@ class TimeDistributed(Cell):
         y = ()
         for item in inputs:
             outputs = self.layer(item)
-            _check_data(isinstance(outputs, tensor))
+            is_capital_tensor = isinstance(outputs, Tensor)
+            is_tensor = True if is_capital_tensor else isinstance(outputs, tensor)
+            _check_data(is_tensor)
             _check_expand_dims_axis(time_axis, outputs.ndim)
             y += (outputs,)
         y = Pack(time_axis)(y)
diff --git a/tests/st/ops/cpu/test_time_distributed_op.py b/tests/st/ops/cpu/test_time_distributed_op.py
index f6d0a3641b..000ef572a8 100644
--- a/tests/st/ops/cpu/test_time_distributed_op.py
+++ b/tests/st/ops/cpu/test_time_distributed_op.py
@@ -78,6 +78,22 @@ def test_time_distributed_dense():
     print("Dense layer wrapped successful")
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_time_distributed_dense_pynative():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
+    inputs = np.random.randint(0, 10, [32, 10])
+    dense = nn.Dense(10, 6)
+    output_expect = dense(Tensor(inputs, mindspore.float32)).asnumpy()
+    inputs = inputs.reshape([32, 1, 10]).repeat(6, axis=1)
+    time_distributed = TestTimeDistributed(dense, time_axis=1, reshape_with_axis=0)
+    output = time_distributed(Tensor(inputs, mindspore.float32)).asnumpy()
+    for i in range(output.shape[1]):
+        assert np.all(output[:, i, :] == output_expect)
+    print("Dense layer with pynative mode wrapped successful")
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
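
Note on the change: the new test runs under context.PYNATIVE_MODE, where construct() receives concrete mindspore.common.Tensor objects, while the original check only tested against the `tensor` element type from mindspore.common.dtype, which the graph-mode frontend handles. The patch therefore checks Tensor first and falls back to the old dtype check. A minimal sketch of that combined check, outside the patched file; the helper name check_is_ms_tensor is hypothetical and only mirrors the logic added in construct():

# Sketch only, assuming mindspore is installed; mirrors the patched check.
from mindspore.common import Tensor
from mindspore.common.dtype import tensor


def check_is_ms_tensor(x):
    # PyNative mode: construct() gets a concrete mindspore Tensor, so this
    # check succeeds and short-circuits before the dtype-based check below.
    if isinstance(x, Tensor):
        return True
    # Graph mode: the original single check against the `tensor` element type,
    # which the MindSpore frontend resolves during compilation.
    return isinstance(x, tensor)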