You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

standard_method.py 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2020 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """standard_method"""
  18. from dataclasses import dataclass
  19. from mindspore.common import dtype as mstype
  20. from ...ops import functional as F
  21. from ...ops import operations as P
  22. from ...ops.primitive import constexpr
  23. from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
  24. zeros_like, ones_like
  25. from ...ops.composite.base import _append
  26. __all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
  27. trans = P.Transpose()
  28. shape_ = P.Shape()
  29. dtype_ = P.DType()
  30. def all_(x, axis=(), keep_dims=False):
  31. """
  32. Check all array elements along a given axis evaluate to True.
  33. Args:
  34. x (Tensor): A Tensor to be reduced.
  35. axis (Union[None, int, tuple(int)): Dimensions of reduction.
  36. keep_dims (bool): Whether to keep the reduced dimensions.
  37. Returns:
  38. Tensor, has the same data type as x.
  39. """
  40. if axis is None:
  41. axis = ()
  42. reduce_all = P.ReduceAll(keep_dims)
  43. return reduce_all(x, axis)
  44. def any_(x, axis=(), keep_dims=False):
  45. """
  46. Check any array element along a given axis evaluate to True.
  47. Args:
  48. x (Tensor): A Tensor to be reduced.
  49. axis (Union[None, int, tuple(int)): Dimensions of reduction.
  50. keep_dims (bool): Whether to keep the reduced dimensions.
  51. Returns:
  52. Tensor, has the same data type as x.
  53. """
  54. if axis is None:
  55. axis = ()
  56. reduce_any = P.ReduceAny(keep_dims)
  57. return reduce_any(x, axis)
  58. def transpose(x):
  59. """Implementation of `transpose`."""
  60. shape = F.shape(x)
  61. length = F.tuple_len(shape)
  62. perm = F.make_range(0, length)
  63. revert_perm = F.tuple_reversed(perm)
  64. out = trans(x, revert_perm)
  65. return out
  66. def getitem(data, item):
  67. """Implementation of `getitem`."""
  68. return data.__getitem__(item)
  69. def setitem(data, item, value):
  70. """Implementation of `setitem`."""
  71. return data.__setitem__(item, value)
  72. def ms_iter(xs):
  73. """Implementation of `iter`."""
  74. return xs.__ms_iter__()
  75. def ms_next(it):
  76. """Implementation of `next`."""
  77. return it.__ms_next__()
  78. def hasnext(it):
  79. """Implementation of `hasnext`."""
  80. return it.__ms_hasnext__()
  81. def ms_len(data):
  82. """Implementation of `len`."""
  83. return data.__len__()
  84. def floor(x):
  85. """Implementation of `floor`."""
  86. return x.__floor__()
  87. def trunc(x):
  88. """Implementation of `trunc`."""
  89. return x.__trunc__()
  90. def uadd(x):
  91. """Implementation of `uadd`."""
  92. return x.__pos__()
  93. def usub(x):
  94. """Implementation of `usub`."""
  95. return x.__neg__()
  96. def scalar_truediv(x, y):
  97. """Implementation of `scalar_truediv`."""
  98. return x.__truediv__(y)
  99. def scalar_floordiv(x, y):
  100. """Implementation of `scalar_floordiv`."""
  101. return x.__floordiv__(y)
  102. def bool_(x):
  103. """Implementation of `bool`."""
  104. return x.__bool__()
  105. def enumerate_(x, start=0):
  106. """Enumerate list or tuple."""
  107. x_type = F.typeof(x)
  108. ret = ()
  109. op_name = "enumerate"
  110. if check_is_tuple_or_list(x_type, op_name, "first input") and check_is_const_int(start, op_name, "start"):
  111. ret = zip(range(start, start + len(x)), x)
  112. return ret
  113. def isinstance_(x, base_type):
  114. """Determine whether x is an instance of base_type."""
  115. x_type = F.typeof(x)
  116. return check_type_same(x_type, base_type)
  117. def while_cond(x):
  118. """For while condtion, if the condition is a tensor, the loop will not be unrolled"""
  119. if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
  120. is_cond = check_is_tensor_bool_cond(F.shape(x))
  121. if is_cond:
  122. return F.cast(x, mstype.bool_)
  123. return x
  124. @constexpr
  125. def check_type_same(x_type, base_type):
  126. """Check x_type is same as base_type."""
  127. if mstype.issubclass_(x_type, base_type):
  128. return True
  129. return False
  130. @constexpr
  131. def check_is_tuple_or_list(x, op_name, arg_name):
  132. """check whether x is list or tuple."""
  133. if isinstance(x, (mstype.list_type, mstype.tuple_type)):
  134. return True
  135. raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list, but got {x}.")
  136. @constexpr
  137. def check_is_const_int(x, op_name, arg_name):
  138. """check whether x is const int."""
  139. if x is None:
  140. raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got not const.")
  141. if not isinstance(x, int):
  142. raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got {x}.")
  143. return True
  144. @constexpr
  145. def check_is_tensor_bool_cond(shp):
  146. """check if tensor is a bool condition"""
  147. if shp in ((), (1,)):
  148. return True
  149. raise ValueError("The truth value of an array with several elements is ambiguous.")
  150. @constexpr
  151. def const_tensor_to_bool(x):
  152. """convert bool tensor to bool condition"""
  153. if x is None:
  154. raise ValueError("Only constant tensor bool can be converted to bool")
  155. x = x.asnumpy()
  156. if x.shape == ():
  157. return bool(x)
  158. if x.shape == (1,):
  159. return bool(x[0])
  160. raise ValueError("The truth value of an array with several elements is ambiguous.")
  161. def tensor_bool(x):
  162. """tensor as conditon, if is constant, return immediate bool value"""
  163. is_cond = check_is_tensor_bool_cond(F.shape(x))
  164. if is_cond and F.isconstant(x):
  165. return const_tensor_to_bool(x)
  166. return F.cast(x, mstype.bool_)
  167. def and_(x, y):
  168. """Implementation of `and` (`&`)."""
  169. return x.__and__(y)
  170. def or_(x, y):
  171. """Implementation of `or` (`|`)."""
  172. return x.__or__(y)
  173. def matmul(x, y):
  174. """Implementation of `matmul` (`@`)."""
  175. return x.__matmul__(y)
  176. def float_bool(x):
  177. """Implementation of `float_bool`."""
  178. return x != 0.0
  179. def int_bool(x):
  180. """Implementation of `int_bool`."""
  181. return x != 0
  182. def str_bool(x):
  183. """Implementation of `str_bool`."""
  184. if x == "":
  185. return False
  186. return True
  187. def list_bool(x):
  188. """Implementation of `tuple_bool`."""
  189. return len(x) != 0
  190. def tuple_bool(x):
  191. """Implementation of `tuple_bool`."""
  192. return len(x) != 0
  193. def dict_bool(x):
  194. """Implementation of `dict_bool`."""
  195. return len(x) != 0
  196. def none_bool(x):
  197. """Implementation of `none_bool`."""
  198. return False
  199. def func_bool(x):
  200. """Implementation of `func_bool`."""
  201. return True
  202. def float_floordiv(x, y):
  203. """Implementation of `float_floordiv`."""
  204. return floor(x / y)
  205. #############
  206. # Iteration #
  207. #############
  208. @dataclass(frozen=True)
  209. class SequenceIterator:
  210. """
  211. SequenceIterator is a util dataclass for iterating sequence object.
  212. Iterator to use for sequences like List, Array.
  213. """
  214. idx: int
  215. seq: list
  216. @core(ignore_values=True)
  217. def __ms_hasnext__(self):
  218. """Whether the index is past the length of the sequence."""
  219. return self.idx < ms_len(self.seq)
  220. @core(ignore_values=True)
  221. def __ms_next__(self):
  222. """Return the next element and a new iterator."""
  223. return self.seq[self.idx], SequenceIterator(self.idx + 1, self.seq)
  224. def list_iter(xs):
  225. """Iterator for List."""
  226. return SequenceIterator(0, xs)
  227. def array_iter(xs):
  228. """Iterator for Array."""
  229. return SequenceIterator(0, xs)
  230. def tuple_next(xs):
  231. """Next tuple."""
  232. return xs[0], tail(xs)
  233. def tuple_hasnext(xs):
  234. """Whether the tuple is empty or not."""
  235. return len(xs) > 0
  236. def list_next(xs):
  237. """Next list."""
  238. return xs[0], tail(xs)
  239. def list_hasnext(xs):
  240. """Whether the list is empty or not."""
  241. return len(xs) > 0
  242. def list_append(self_, item):
  243. return _append(self_, item)
  244. #################
  245. # Array methods #
  246. #################
  247. def to_array(x):
  248. """Implementation of `to_array`."""
  249. return x.__ms_to_array__()