You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

array_ops.py 7.9 kB

5 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """array Operations."""
  16. from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
  17. from mindspore.common import dtype as mstype
  18. from mindspore.common._register_for_tensor import tensor_operator_registry
  19. from mindspore._checkparam import Validator as validator
  20. from mindspore._checkparam import Rel
  21. from mindspore.ops.primitive import constexpr
  22. from mindspore.ops import functional as F
  23. from .. import operations as P
  24. @constexpr
  25. def _check_is_int(arg_value, arg_name, op_name):
  26. arg_value = validator.check_is_int(arg_value, arg_name, op_name)
  27. return arg_value
  28. @constexpr
  29. def _check_positive_int(arg_value, arg_name, op_name):
  30. arg_value = validator.check_positive_int(arg_value, arg_name, op_name)
  31. return arg_value
  32. @constexpr
  33. def _check_axis_range(arg_value, limit, arg_name, op_name):
  34. arg_value = validator.check_int_range(arg_value, -limit, limit, Rel.INC_LEFT, arg_name, op_name)
  35. return arg_value
  36. @constexpr
  37. def _cal_repeat_dims(x_rank, rep, expand_axis):
  38. rep_dims = [1] * (x_rank + 1)
  39. rep_dims[expand_axis] = rep
  40. return tuple(rep_dims)
  41. @constexpr
  42. def _cal_reshape(x_shape, rep, axis):
  43. x_reshape = list(x_shape)
  44. x_reshape[axis] *= rep
  45. return tuple(x_reshape)
  46. def repeat_elements(x, rep, axis=0):
  47. """
  48. Repeat elements of a tensor along an axis, like np.repeat.
  49. Args:
  50. x (Tensor): The tensor to repeat values for. Must be of type: float16,
  51. float32, int8, uint8, int16, int32, or int64.
  52. rep (int): The number of times to repeat, must be positive, required.
  53. axis (int): The axis along which to repeat, default 0.
  54. Outputs:
  55. One tensor with values repeated along the specified axis. If x has shape
  56. (s1, s2, ..., sn) and axis is i, the output will have shape (s1, s2, ...,
  57. si * rep, ..., sn). The output type will be the same as the type of `x`.
  58. Supported Platforms:
  59. ``Ascend`` ``GPU`` ``CPU``
  60. Examples:
  61. >>> # case 1 : repeat on axis 0
  62. >>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32)
  63. >>> output = ops.repeat_elements(x, rep = 2, axis = 0)
  64. >>> print(output)
  65. [[0 1 2]
  66. [0 1 2]
  67. [3 4 5]
  68. [3 4 5]]
  69. >>> # case 2 : repeat on axis 1
  70. >>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32)
  71. >>> output = ops.repeat_elements(x, rep = 2, axis = 1)
  72. >>> print(output)
  73. [[0 0 1 1 2 2]
  74. [3 3 4 4 5 5]]
  75. """
  76. const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
  77. rep = _check_positive_int(rep, "rep", "repeat_elements")
  78. axis = _check_is_int(axis, "axis", "repeat_elements")
  79. shape_op = P.Shape()
  80. rank_op = P.Rank()
  81. tile_op = P.Tile()
  82. expand_dims_op = P.ExpandDims()
  83. reshape_op = P.Reshape()
  84. x_rank = rank_op(x)
  85. axis = _check_axis_range(axis, x_rank, "axis", "repeat_elements")
  86. expand_axis = axis + 1
  87. x_expand = expand_dims_op(x, expand_axis)
  88. rep_dims = _cal_repeat_dims(x_rank, rep, expand_axis)
  89. x_expand = tile_op(x_expand, rep_dims)
  90. x_shape = shape_op(x)
  91. x_reshape = _cal_reshape(x_shape, rep, axis)
  92. x_rep = reshape_op(x_expand, x_reshape)
  93. return x_rep
  94. tensor_operator_registry.register('repeat_elements', repeat_elements)
  95. @constexpr
  96. def _check_sequence_mask_input_len(input_shape, prim_name=None):
  97. msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
  98. if not input_shape:
  99. raise ValueError(f"{msg_prefix} input_shape should be greater than 0, but got {input_shape}.")
  100. # broadcast only supports 7d shape
  101. shape_size = len(input_shape)
  102. if shape_size >= 7:
  103. raise ValueError(f"{msg_prefix} dimension of input_shape should be less than 7, but got {shape_size}d.")
  104. def sequence_mask(lengths, maxlen=None, prim_name='sequence_mask'):
  105. """
  106. Returns a mask tensor representing the first N positions of each cell.
  107. If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type and shape
  108. [d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])
  109. Args:
  110. lengths (Tensor): Tensor to calculate the mask for. All values in this tensor should be
  111. less than or equal to `maxlen`. Values greater than `maxlen` will be treated as `maxlen`.
  112. maxlen (int): size of the last dimension of returned tensor. Must be positive and same
  113. type as elements in `lengths`. Default is None.
  114. prim_name (str): The name of primitive. Default: 'sequence_mask'.
  115. Inputs:
  116. - **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor should be
  117. less than or equal to `maxlen`. Values greater than `maxlen` will be treated as `maxlen`.
  118. Must be type int32 or int64.
  119. - **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same
  120. type as elements in `lengths`. Default is None.
  121. Outputs:
  122. One mask tensor of shape lengths.shape + (maxlen,).
  123. Raises:
  124. TypeError: If `lengths` is not a Tensor.
  125. TypeError: If `maxlen` is not an int.
  126. TypeError: If dtype of `lengths` is neither int32 nor int64.
  127. Supported Platforms:
  128. ``GPU``
  129. Examples:
  130. >>> # case 1: When maxlen is assigned
  131. >>> x = Tensor(np.array([1, 2, 3, 4]))
  132. >>> output = ops.sequence_mask(x, 5)
  133. >>> print(output)
  134. [[ True False False False False]
  135. [ True True False False False]
  136. [ True True True False False]
  137. [ True True True True False]]
  138. >>> # case 2: When there is 0 in x
  139. >>> x = Tensor(np.array([[1, 3], [2, 0]]))
  140. >>> output = ops.sequence_mask(x, 5)
  141. >>> print(output)
  142. [[[ True False False False False]
  143. [ True True True False False]]
  144. [[ True True False False False]
  145. [False False False False False]]]
  146. >>> # case 3: when the maxlen is not assigned
  147. >>> x = Tensor(np.array([[1, 3], [2, 4]]))
  148. >>> output = ops.sequence_mask(x)
  149. >>> print(output)
  150. [[[ True False False False ]
  151. [ True True True False ]]
  152. [[ True True False False ]
  153. [ True True True True ]]]
  154. """
  155. argmax_op = P.ArgMaxWithValue()
  156. reshape_op = P.Reshape()
  157. range_op = P.Range()
  158. expand_op = P.ExpandDims()
  159. cast_op = P.Cast()
  160. shape_op = P.Shape()
  161. to_tensor_op = P.ScalarToArray()
  162. const_utils.check_type_valid(F.dtype(lengths), [mstype.int64, mstype.int32], 'lengths')
  163. _check_sequence_mask_input_len(shape_op(lengths), prim_name)
  164. if maxlen is None:
  165. flatten_data = reshape_op(lengths, (-1,))
  166. flatten_data = cast_op(flatten_data, mstype.float32)
  167. _, value = argmax_op(flatten_data)
  168. maxlen = cast_op(value, mstype.int32)
  169. else:
  170. maxlen = _check_positive_int(maxlen, "maxlen", "sequence_mask")
  171. maxlen = to_tensor_op(maxlen)
  172. range_vector = range_op(to_tensor_op(0), maxlen
  173. , to_tensor_op(1))
  174. mask = expand_op(lengths, -1)
  175. result = range_vector < mask
  176. return result