Browse Source

unified Tensor type of graph mode and pynative mode

tags/v1.2.0-rc1
dinglongwei 5 years ago
parent
commit
d9d53a7710
1 changed files with 3 additions and 10 deletions
  1. +3
    -10
      mindspore/nn/layer/timedistributed.py

+ 3
- 10
mindspore/nn/layer/timedistributed.py View File

@@ -16,7 +16,6 @@

from mindspore.ops.primitive import constexpr, Primitive
from mindspore.ops import Reshape, Transpose, Pack, Unpack
from mindspore.common.dtype import tensor
from mindspore.common import Tensor
from ..cell import Cell

@@ -105,9 +104,7 @@ class TimeDistributed(Cell):
self.reshape = Reshape()

def construct(self, inputs):
is_capital_tensor = isinstance(inputs, Tensor)
is_tensor = True if is_capital_tensor else isinstance(inputs, tensor)
_check_data(is_tensor)
_check_data(isinstance(inputs, Tensor))
_check_inputs_dim(inputs.shape)
time_axis = self.time_axis % len(inputs.shape)
if self.reshape_with_axis is not None:
@@ -122,9 +119,7 @@ class TimeDistributed(Cell):
inputs_shape_new = inputs.shape
inputs = self.reshape(inputs, inputs_shape_new[: reshape_pos] + (-1,) + inputs_shape_new[reshape_pos + 2:])
outputs = self.layer(inputs)
is_capital_tensor = isinstance(outputs, Tensor)
is_tensor = True if is_capital_tensor else isinstance(outputs, tensor)
_check_data(is_tensor)
_check_data(isinstance(outputs, Tensor))
_check_reshape_pos(reshape_pos, inputs.shape, outputs.shape)
outputs_shape_new = outputs.shape[:reshape_pos] + inputs_shape_new[reshape_pos: reshape_pos + 2]
if reshape_pos + 1 < len(outputs.shape):
@@ -136,9 +131,7 @@ class TimeDistributed(Cell):
y = ()
for item in inputs:
outputs = self.layer(item)
is_capital_tensor = isinstance(outputs, Tensor)
is_tensor = True if is_capital_tensor else isinstance(outputs, tensor)
_check_data(is_tensor)
_check_data(isinstance(outputs, Tensor))
_check_expand_dims_axis(time_axis, outputs.ndim)
y += (outputs,)
y = Pack(time_axis)(y)


Loading…
Cancel
Save