
!10895 add timedistributed pynative mode

From: @dinglongwei
Reviewed-by: @wuxuejian,@c_34
Signed-off-by: @wuxuejian,@c_34
tags/v1.2.0-rc1
Committed by mindspore-ci-bot, 4 years ago · commit efd22c96ad
2 changed files with 26 additions and 3 deletions
  1. mindspore/nn/layer/timedistributed.py (+10, -3)
  2. tests/st/ops/cpu/test_time_distributed_op.py (+16, -0)

mindspore/nn/layer/timedistributed.py (+10, -3)

@@ -17,6 +17,7 @@
 from mindspore.ops.primitive import constexpr, Primitive
 from mindspore.ops import Reshape, Transpose, Pack, Unpack
 from mindspore.common.dtype import tensor
+from mindspore.common import Tensor
 from ..cell import Cell


 __all__ = ['TimeDistributed']
@@ -104,7 +105,9 @@ class TimeDistributed(Cell):
         self.reshape = Reshape()

     def construct(self, inputs):
-        _check_data(isinstance(inputs, tensor))
+        is_capital_tensor = isinstance(inputs, Tensor)
+        is_tensor = True if is_capital_tensor else isinstance(inputs, tensor)
+        _check_data(is_tensor)
         _check_inputs_dim(inputs.shape)
         time_axis = self.time_axis % len(inputs.shape)
         if self.reshape_with_axis is not None:
@@ -119,7 +122,9 @@ class TimeDistributed(Cell):
             inputs_shape_new = inputs.shape
             inputs = self.reshape(inputs, inputs_shape_new[: reshape_pos] + (-1,) + inputs_shape_new[reshape_pos + 2:])
             outputs = self.layer(inputs)
-            _check_data(isinstance(outputs, tensor))
+            is_capital_tensor = isinstance(outputs, Tensor)
+            is_tensor = True if is_capital_tensor else isinstance(outputs, tensor)
+            _check_data(is_tensor)
             _check_reshape_pos(reshape_pos, inputs.shape, outputs.shape)
             outputs_shape_new = outputs.shape[:reshape_pos] + inputs_shape_new[reshape_pos: reshape_pos + 2]
             if reshape_pos + 1 < len(outputs.shape):
@@ -131,7 +136,9 @@ class TimeDistributed(Cell):
         y = ()
         for item in inputs:
             outputs = self.layer(item)
-            _check_data(isinstance(outputs, tensor))
+            is_capital_tensor = isinstance(outputs, Tensor)
+            is_tensor = True if is_capital_tensor else isinstance(outputs, tensor)
+            _check_data(is_tensor)
             _check_expand_dims_axis(time_axis, outputs.ndim)
             y += (outputs,)
         y = Pack(time_axis)(y)


tests/st/ops/cpu/test_time_distributed_op.py (+16, -0)

@@ -78,6 +78,22 @@ def test_time_distributed_dense():
     print("Dense layer wrapped successful")


+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_time_distributed_dense_pynative():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
+    inputs = np.random.randint(0, 10, [32, 10])
+    dense = nn.Dense(10, 6)
+    output_expect = dense(Tensor(inputs, mindspore.float32)).asnumpy()
+    inputs = inputs.reshape([32, 1, 10]).repeat(6, axis=1)
+    time_distributed = TestTimeDistributed(dense, time_axis=1, reshape_with_axis=0)
+    output = time_distributed(Tensor(inputs, mindspore.float32)).asnumpy()
+    for i in range(output.shape[1]):
+        assert np.all(output[:, i, :] == output_expect)
+    print("Dense layer with pynative mode wrapped successful")
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
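
For reference, a standalone sketch of the scenario the new test covers is below. It assumes TimeDistributed is exported from mindspore.nn (as the layer file's __all__ suggests) and takes the same layer/time_axis/reshape_with_axis arguments the test passes to its TestTimeDistributed wrapper cell; treat it as an illustration rather than the test itself.

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')

dense = nn.Dense(10, 6)
inputs = np.random.randint(0, 10, [32, 10])
# Expected output of the bare Dense layer on a single time step.
output_expect = dense(Tensor(inputs, mindspore.float32)).asnumpy()

# Repeat each sample 6 times along a new time axis (axis 1).
time_inputs = inputs.reshape([32, 1, 10]).repeat(6, axis=1)

# Apply the same Dense layer independently at every time step.
time_distributed = nn.TimeDistributed(dense, time_axis=1, reshape_with_axis=0)
output = time_distributed(Tensor(time_inputs, mindspore.float32)).asnumpy()

# Every time step should reproduce the bare-layer result.
for i in range(output.shape[1]):
    assert np.all(output[:, i, :] == output_expect)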

