You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

timedistributed.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Time Distributed."""
  16. from mindspore.ops.primitive import constexpr, Primitive
  17. from mindspore.ops import Reshape, Transpose, Stack, Unstack
  18. from mindspore.common import Tensor
  19. from mindspore._checkparam import Validator
  20. from ..cell import Cell
  21. __all__ = ['TimeDistributed']
  @constexpr
  def _check_reshape_pos(reshape_pos, inputs_shape, outputs_shape):
      """Check that the wrapped layer kept the size at index `reshape_pos`.

      Args:
          reshape_pos (int): Index compared between the two shapes.
          inputs_shape (tuple): Shape of the tensor passed to the wrapped layer.
          outputs_shape (tuple): Shape of the tensor returned by the wrapped layer.

      Raises:
          ValueError: If `reshape_pos` is not smaller than the length of `outputs_shape`, or if
              the two shapes hold different sizes at `reshape_pos`.
      """
      # The length test comes first so that `outputs_shape` is never indexed out of range.
      if reshape_pos < len(outputs_shape) and inputs_shape[reshape_pos] == outputs_shape[reshape_pos]:
          return
      raise ValueError("The parameter reshape_with_axis is invalid in the input and output of TimeDistributed. "
                       "You may try pass parameters without reshape_with_axis.")
  @constexpr
  def _check_expand_dims_axis(time_axis, ndim):
      """Check that `time_axis` does not exceed `ndim`.

      Args:
          time_axis (int): Axis index to validate.
          ndim (int): Upper bound (inclusive) accepted for `time_axis`.

      Raises:
          ValueError: If `time_axis` is greater than `ndim`.
      """
      if time_axis <= ndim:
          return
      lower_bound = -ndim - 1
      raise ValueError("The parameter time_axis is invalid in the input. "
                       "The value of time_axis should be in range of [{}, {}].".format(lower_bound, ndim))
  @constexpr
  def _generate_perm(axis_a, axis_b, length):
      """Build the permutation that moves axis `axis_b` to position `axis_a`.

      All other axes keep their relative order. The move is directional: `axis_b` is the
      source axis and `axis_a` is the position it occupies after the transpose.

      Args:
          axis_a (int): Destination position of the moved axis (non-negative).
          axis_b (int): Current position of the axis to move (non-negative).
          length (int): Number of dimensions of the tensor being transposed.

      Returns:
          tuple, a permutation of ``range(length)`` suitable for a transpose.
      """
      perm = tuple(range(length))
      if axis_b >= axis_a:
          # Moving left (or not at all): the axes in [axis_a, axis_b) shift one place right.
          return perm[:axis_a] + (perm[axis_b],) + perm[axis_a: axis_b] + perm[axis_b + 1:]
      # Moving right: the axes in (axis_b, axis_a] shift one place left, so the moved axis
      # lands exactly on `axis_a` instead of one rotation step away from it.
      return perm[:axis_b] + perm[axis_b + 1: axis_a + 1] + (perm[axis_b],) + perm[axis_a + 1:]
  @constexpr
  def _check_data(flag):
      """Raise a TypeError when `flag` is falsy.

      Args:
          flag (bool): Result of a type test performed by the caller.

      Raises:
          TypeError: If `flag` evaluates to False.
      """
      if not flag:
          raise TypeError("The inputs and outputs should be a Tensor.")
  @constexpr
  def _check_inputs_dim(shape):
      """Check that `shape` describes a tensor with at least three dimensions.

      Args:
          shape (tuple): Shape to validate.

      Raises:
          ValueError: If `shape` has fewer than three entries.
      """
      rank = len(shape)
      if rank < 3:
          raise ValueError("The inputs should be at least 3D.")
  class TimeDistributed(Cell):
      r"""
      The time distributed layer.

      A wrapper that applies a layer to every temporal slice of an input. The input should be at
      least 3D.

      Two strategies are implemented:

      - When `reshape_with_axis` is provided, the time axis is transposed next to that axis, both
        are merged by a reshape, the layer is called once and the merged axis is split again.
        This is the more efficient method.
      - Otherwise the input is unstacked along the time axis, the layer is called on every slice
        and the results are stacked along the same axis. This is the more general method.

      For example, `reshape_with_axis` could not be provided when dealing with Batch Normalization.

      Args:
          layer(Union[Cell, Primitive]): The Cell or Primitive which will be wrapped.
          time_axis(int): The axis of time_step.
          reshape_with_axis(int): The axis which will be reshaped with time_axis. Default: None.

      Inputs:
          - **input** (Tensor) - Tensor of shape :math:`(N, T, *)`.

      Outputs:
          Tensor of shape :math:`(N, T, *)`

      Supported Platforms:
          ``Ascend`` ``GPU`` ``CPU``

      Raises:
          TypeError: If layer is not a Cell or Primitive.

      Examples:
          >>> input = Tensor(np.random.random([32, 10, 3]), mindspore.float32)
          >>> dense = nn.Dense(3, 6)
          >>> net = nn.TimeDistributed(dense, time_axis=1, reshape_with_axis=0)
          >>> output = net(input)
          >>> print(output.shape)
          (32, 10, 6)
      """

      def __init__(self, layer, time_axis, reshape_with_axis=None):
          """Validate the arguments and create the operators used by `construct`."""
          if not isinstance(layer, (Cell, Primitive)):
              raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or "
                              "mindspore.ops.Primitive instance. You passed: {input}".format(input=layer))
          super().__init__()
          Validator.check_is_int(time_axis)
          if reshape_with_axis is not None:
              Validator.check_is_int(reshape_with_axis)
          self.layer = layer
          self.time_axis = time_axis
          self.reshape_with_axis = reshape_with_axis
          self.transpose = Transpose()
          self.reshape = Reshape()

      def construct(self, inputs):
          # Apply the wrapped layer to every slice of `inputs` along the time axis.
          _check_data(isinstance(inputs, Tensor))
          _check_inputs_dim(inputs.shape)
          rank = len(inputs.shape)
          # The modulo maps a negative axis onto its non-negative equivalent.
          time_axis = self.time_axis % rank

          if self.reshape_with_axis is None:
              # General path: split along the time axis, run the layer per slice, stack again.
              slices = Unstack(time_axis)(inputs)
              collected = ()
              for piece in slices:
                  outputs = self.layer(piece)
                  _check_data(isinstance(outputs, Tensor))
                  _check_expand_dims_axis(time_axis, outputs.ndim)
                  collected += (outputs,)
              return Stack(time_axis)(collected)

          # Reshape path: bring the time axis next to `reshape_with_axis` and merge the two.
          merge_axis = self.reshape_with_axis % rank
          if merge_axis == rank - 1:
              target_axis = rank - 2
          elif time_axis > merge_axis:
              target_axis = merge_axis + 1
          else:
              target_axis = merge_axis - 1

          # Index of the first of the two adjacent axes that get merged.
          if target_axis < merge_axis:
              reshape_pos = target_axis
          else:
              reshape_pos = merge_axis

          perm = _generate_perm(target_axis, time_axis, rank)
          transposed = self.transpose(inputs, perm)
          transposed_shape = transposed.shape
          merged_shape = transposed_shape[: reshape_pos] + (-1,) + transposed_shape[reshape_pos + 2:]
          merged = self.reshape(transposed, merged_shape)

          outputs = self.layer(merged)
          _check_data(isinstance(outputs, Tensor))
          _check_reshape_pos(reshape_pos, merged.shape, outputs.shape)

          # Split the merged axis back into the two sizes it had after the transpose.
          # NOTE(review): the result is not transposed back, so the time axis stays where the
          # transpose above put it.
          restored_shape = outputs.shape[:reshape_pos] + transposed_shape[reshape_pos: reshape_pos + 2]
          if reshape_pos + 1 < len(outputs.shape):
              restored_shape += outputs.shape[reshape_pos + 1:]
          return self.reshape(outputs, restored_shape)