You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """internal utility functions"""
  16. import types
  17. from ..common import Tensor
  18. from ..ops import functional as F
  19. from ..common import dtype as mstype
  20. from .utils_const import _tile_size, _add_unit_axes, _raise_type_error, _type_convert, \
  21. _tuple_setitem, _callable_const
  22. def _deep_list(array_like):
  23. """convert nested tuple/list mixtures to pure nested list"""
  24. if isinstance(array_like, (list, tuple)):
  25. return list(map(_deep_list, array_like))
  26. return array_like
  27. def _deep_tensor_to_nparray(array_like):
  28. """
  29. convert a nested list of tensor to nested list of np_array.
  30. Args:
  31. array_like(list(tensor)): In any format of nested lists that may contain
  32. tensors.
  33. Returns:
  34. array_like(list(np_array)): Formatted array that can be directly processed
  35. by numpy.array(), with all tensor elements converted to numpy_array.
  36. """
  37. # Recursively check whether each element is a tensor or not, if is tensor,
  38. # convert it to a numpy array in place
  39. if isinstance(array_like, Tensor):
  40. return array_like.asnumpy()
  41. if isinstance(array_like, list):
  42. for idx, value in enumerate(array_like):
  43. array_like[idx] = _deep_tensor_to_nparray(value)
  44. return array_like
  45. def _check_input_for_asarray(array_like):
  46. """check whether array_like argument is a valid type for np.asarray conversion"""
  47. if not isinstance(array_like, (Tensor, list, tuple, int, float, bool)):
  48. _raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`, but got ", array_like)
  49. def _is_scalar(shape):
  50. """check whether input shape is a scalar"""
  51. return F.shape_mul(shape) == 1
  52. def _convert_list_tensor_to_tuple_tensor(list_of_tensor):
  53. """Convert a list of tensor to a tuple of tensor"""
  54. if isinstance(list_of_tensor, list):
  55. tuple_of_tensor = ()
  56. for tensor in list_of_tensor:
  57. tuple_of_tensor += (tensor,)
  58. return tuple_of_tensor
  59. return list_of_tensor
  60. def _expand(x, ndim, axis=0):
  61. """Expand x to ndim from axis, which can be 0 or -1."""
  62. shape = _add_unit_axes(F.shape(x), ndim, axis == -1)
  63. return F.reshape(x, shape)
  64. def _broadcast_to(x, shape_cur, shape_to, ndim_to):
  65. """Broadcasts x from shape_cur to shape_to."""
  66. size = _tile_size(shape_cur, shape_to, ndim_to)
  67. return F.tile(x, size)
  68. def _broadcast_to_shape(x, shape):
  69. """Broadcasts x from current shape to shape"""
  70. ndim_to = len(shape)
  71. x = _expand(x, ndim_to)
  72. return _broadcast_to(x, F.shape(x), shape, ndim_to)
  73. def _get_size(x, axis=None):
  74. """Get the number of elements along the given axis of tensor x."""
  75. if axis is None or F.tuple_len(axis) == 0:
  76. axis = F.make_range(x.ndim)
  77. nums = 1
  78. for ax in axis:
  79. nums *= x.shape[ax]
  80. return nums
  81. def _check_input_tensor(*tensors):
  82. for tensor in tensors:
  83. if not isinstance(tensor, Tensor):
  84. _raise_type_error('expect Tensor, but got ', F.typeof(tensor))
  85. return True
  86. def _convert_64_to_32(tensor):
  87. """Convert tensor with float64/int64 types to float32/int32."""
  88. if tensor.dtype == mstype.float64:
  89. return tensor.astype("float32")
  90. if tensor.dtype == mstype.int64:
  91. return tensor.astype("int32")
  92. return tensor
  93. def _to_tensor(*args):
  94. """Returns each input as Tensor"""
  95. res = ()
  96. for arg in args:
  97. if isinstance(arg, (int, float, bool, list, tuple)):
  98. arg = _convert_64_to_32(_type_convert(Tensor, arg))
  99. elif not isinstance(arg, Tensor):
  100. _raise_type_error("Expect input to be array like.")
  101. res += (arg,)
  102. if len(res) == 1:
  103. return res[0]
  104. return res
  105. def _get_dtype_from_scalar(*input_numbers):
  106. """
  107. Get the final dtype from series of input numbers, compared with F.typeof, we
  108. return int32/float32 for python int/float instead.
  109. """
  110. bool_flag = True
  111. int_flag = True
  112. for number in input_numbers:
  113. if number is not None:
  114. if not isinstance(number, bool):
  115. bool_flag = False
  116. if not isinstance(number, int):
  117. int_flag = False
  118. if bool_flag:
  119. return mstype.bool_
  120. if int_flag:
  121. return mstype.int32
  122. return mstype.float32
  123. def _isnan(x):
  124. """Computes isnan."""
  125. return F.not_equal(x, x)
  126. def _convert_bool_to_int(tensor):
  127. """Convert tensor with bool type to int32."""
  128. if tensor.dtype == mstype.bool_:
  129. return tensor.astype("int32")
  130. return tensor
  131. def _slice_along_axis(f, axis, slice_start, slice_end):
  132. """
  133. Slice a tensor along a given axis
  134. Args:
  135. f (Tensor): Input Tensor.
  136. axis (int): Specified axis.
  137. slice_start (int): The start of the slice.
  138. slice_end (int): The end of the slice.
  139. Returns:
  140. Sliced tensor.
  141. """
  142. index_start = (0,) * f.ndim
  143. index_end = f.shape
  144. slice_size = slice_end - slice_start
  145. index_start = _tuple_setitem(index_start, axis, slice_start)
  146. index_end = _tuple_setitem(index_end, axis, slice_size)
  147. return F.tensor_slice(f, index_start, index_end)
  148. def _to_tensor_origin_dtype(*args):
  149. """Returns each input as Tensor and remains original dtype."""
  150. res = []
  151. for arg in args:
  152. if isinstance(arg, (int, float, bool, list, tuple)):
  153. arg = _type_convert(Tensor, arg)
  154. elif not isinstance(arg, Tensor):
  155. _raise_type_error("Expect input to be array like.")
  156. res.append(arg)
  157. if len(res) == 1:
  158. return res[0]
  159. return res
  160. def _callable(tensor, obj):
  161. """Returns True if `obj` is a function."""
  162. if F.isconstant(tensor):
  163. return isinstance(obj, types.FunctionType)
  164. return _callable_const(F.typeof(obj))