You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 3.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """utils for operator"""
  16. from ..._checkparam import Validator as validator
  17. from ..._checkparam import Rel
  18. from ...common import dtype as mstype
  19. def _get_broadcast_shape(x_shape, y_shape, prim_name):
  20. """
  21. Doing broadcast between tensor x and tensor y.
  22. Args:
  23. x_shape (list): The shape of tensor x.
  24. y_shape (list): The shape of tensor y.
  25. prim_name (str): Primitive name.
  26. Returns:
  27. List, the shape that broadcast between tensor x and tensor y.
  28. Raises:
  29. ValueError: If tensor x and tensor y are not equal and could't broadcast.
  30. Examples:
  31. >>> x_shape = [1, 2, 3]
  32. >>> y_shape = [1, 2]
  33. >>> broadcast_shape = _get_broadcast_shape(x_shape, y_shape)
  34. """
  35. if x_shape == y_shape:
  36. return x_shape
  37. x_len = len(x_shape)
  38. y_len = len(y_shape)
  39. length = x_len if x_len < y_len else y_len
  40. broadcast_shape_back = []
  41. for i in range(-length, 0):
  42. if x_shape[i] == 1:
  43. broadcast_shape_back.append(y_shape[i])
  44. elif y_shape[i] == 1:
  45. broadcast_shape_back.append(x_shape[i])
  46. elif x_shape[i] == y_shape[i]:
  47. broadcast_shape_back.append(x_shape[i])
  48. else:
  49. raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format(
  50. prim_name, x_shape, y_shape))
  51. broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
  52. broadcast_shape = broadcast_shape_front + broadcast_shape_back
  53. return broadcast_shape
  54. def _get_concat_offset(x_shp, x_type, axis, prim_name):
  55. """for concat and concatoffset check args and compute offset"""
  56. validator.check_value_type("shape", x_shp, [tuple], prim_name)
  57. validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name)
  58. validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
  59. validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name)
  60. rank_base = len(x_shp[0])
  61. validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
  62. if axis < 0:
  63. axis = axis + rank_base
  64. all_shp = x_shp[0][axis]
  65. offset = [0,]
  66. for i in range(1, len(x_shp)):
  67. v = x_shp[i]
  68. validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)
  69. validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name)
  70. for j in range(rank_base):
  71. if j != axis and v[j] != x_shp[0][j]:
  72. raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element")
  73. offset.append(all_shp)
  74. all_shp += v[axis]
  75. return offset, all_shp, axis