You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """utils for operator"""
  16. from mindspore.common.tensor import Tensor
  17. from ..._checkparam import Validator as validator
  18. from ..._checkparam import Rel
  19. from ...common import dtype as mstype
  20. from ..primitive import constexpr
  21. def get_broadcast_shape(x_shape, y_shape, prim_name, shape_type="", arg_name1="x", arg_name2="y"):
  22. """
  23. Doing broadcast between tensor x and tensor y.
  24. Args:
  25. x_shape (list): The shape of tensor x.
  26. y_shape (list): The shape of tensor y.
  27. prim_name (str): Primitive name.
  28. shape_type (str): The type of shape, optional values are "", "min_shape" and "max_shape".
  29. arg_name1 (str): The arg name of x_shape.
  30. arg_name2 (str): The arg name of y_shape.
  31. Returns:
  32. List, the shape that broadcast between tensor x and tensor y.
  33. Raises:
  34. ValueError: If tensor x and tensor y are not equal and couldn't broadcast.
  35. Examples:
  36. >>> x_shape = [1, 2, 3]
  37. >>> y_shape = [1, 2]
  38. >>> broadcast_shape = get_broadcast_shape(x_shape, y_shape)
  39. """
  40. if x_shape == y_shape:
  41. return x_shape
  42. x_len = len(x_shape)
  43. y_len = len(y_shape)
  44. length = x_len if x_len < y_len else y_len
  45. broadcast_shape_back = []
  46. for i in range(-length, 0):
  47. if x_shape[i] == 1:
  48. broadcast_shape_back.append(y_shape[i])
  49. elif y_shape[i] == 1:
  50. broadcast_shape_back.append(x_shape[i])
  51. elif x_shape[i] == y_shape[i]:
  52. broadcast_shape_back.append(x_shape[i])
  53. elif x_shape[i] == -1 or y_shape[i] == -1:
  54. broadcast_shape_back.append(-1)
  55. else:
  56. if shape_type == "min_shape":
  57. broadcast_shape_back.append(max(x_shape[i], y_shape[i]))
  58. elif shape_type == "max_shape":
  59. broadcast_shape_back.append(min(x_shape[i], y_shape[i]))
  60. else:
  61. raise ValueError(f"For '{prim_name}', {arg_name1}.shape and {arg_name2}.shape are supposed "
  62. f"to broadcast, where broadcast means that {arg_name1}.shape[i] = 1 or -1 "
  63. f"or {arg_name2}.shape[i] = 1 or -1 "
  64. f"or {arg_name1}.shape[i] = {arg_name2}.shape[i], "
  65. f"but now {arg_name1}.shape and {arg_name2}.shape can not broadcast, "
  66. f"got i: {i}, {arg_name1}.shape: {x_shape}, {arg_name2}.shape: {y_shape}.")
  67. broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
  68. broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back
  69. return broadcast_shape
  70. def get_concat_offset(x_shp, x_type, axis, prim_name):
  71. """for concat and concatoffset check args and compute offset"""
  72. validator.check_value_type("shape", x_shp, [tuple, list], prim_name)
  73. validator.check_positive_int(len(x_shp), "input_x rank", prim_name)
  74. validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
  75. validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name)
  76. rank_base = len(x_shp[0])
  77. for i in range(1, len(x_shp)):
  78. validator.check('len of x_shp[%d]' % i, len(x_shp[i]), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)
  79. validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name)
  80. validator.check_int_range(axis, -rank_base, rank_base - 1, Rel.INC_BOTH, 'axis', prim_name)
  81. if axis < 0:
  82. axis = axis + rank_base
  83. all_shp = x_shp[0][axis]
  84. offset = [0]
  85. for i in range(1, len(x_shp)):
  86. v = x_shp[i]
  87. for j in range(rank_base):
  88. if j != axis and v[j] != x_shp[0][j]:
  89. raise ValueError(f"The shape of the two input elements of the Concat operator do not match:"
  90. f"shape[0] = {x_shp[0]} and shape[1] = {x_shp[1]}.")
  91. offset.append(all_shp)
  92. if all_shp == -1 or v[axis] == -1:
  93. all_shp = -1
  94. else:
  95. all_shp += v[axis]
  96. return offset, all_shp, axis
  97. @constexpr
  98. def range_op(start, limit, delta, dtype):
  99. """helper function to get tensor in specified range."""
  100. output_tensor = Tensor(list(range(start, limit, delta)), dtype)
  101. return output_tensor
  102. @constexpr
  103. def get_1d_shape(in_shape):
  104. """helper function to get 1d shape."""
  105. out_shape = 1
  106. for i in in_shape:
  107. out_shape *= i
  108. return (out_shape,)
  109. @constexpr
  110. def generate_shape_index(out_shape, indices_shape, axis):
  111. out_rank = len(out_shape)
  112. ind_rank = len(indices_shape)
  113. if axis < 0:
  114. axis += out_rank - ind_rank + 1
  115. perm_part1 = tuple(range(axis, axis + ind_rank))
  116. index = tuple(range(out_rank))
  117. perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
  118. return perm
  119. @constexpr
  120. def is_shape_unknown(shape):
  121. for i in shape:
  122. if i < 0:
  123. return True
  124. return False
  125. @constexpr
  126. def is_dim_unknown(shape):
  127. for i in shape:
  128. if i == -2:
  129. return True
  130. return False