You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

standard_method.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2020 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """standard_method"""
  18. from dataclasses import dataclass
  19. from mindspore.common import dtype as mstype
  20. from mindspore.common._register_for_tensor import tensor_operator_registry
  21. from ...ops import functional as F
  22. from ...ops import operations as P
  23. from ...ops.primitive import constexpr
  24. from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
  25. zeros_like, ones_like
  26. from ...ops.composite.base import _append
  # Names imported from the composite ops package that this module re-exports.
  __all__ = [
      'MultitypeFuncGraph',
      'env_get',
      'hyper_add',
      'zeros_like',
      'ones_like',
  ]

  # Single Transpose primitive instance, created once at import time and
  # reused by `transpose`.
  trans = P.Transpose()
  def transpose(x):
      """Transpose `x` by applying `trans` with the axis indices reversed.

      The permutation is built from the number of entries in `F.shape(x)`:
      the range `0..length` is reversed and handed to the Transpose primitive.
      """
      rank = F.tuple_len(F.shape(x))
      axes = F.make_range(0, rank)
      return trans(x, F.tuple_reversed(axes))
  def getitem(data, item):
      """Fetch `item` from `data` by delegating to `data.__getitem__`."""
      fetched = data.__getitem__(item)
      return fetched
  def setitem(data, item, value):
      """Store `value` under `item` by delegating to `data.__setitem__`.

      Whatever `data.__setitem__` returns is passed back to the caller.
      """
      outcome = data.__setitem__(item, value)
      return outcome
  def ms_iter(xs):
      """Obtain an iterator for `xs` by calling its `__ms_iter__` method."""
      iterator = xs.__ms_iter__()
      return iterator
  def ms_next(it):
      """Advance `it` by calling its `__ms_next__` method and return the result."""
      advanced = it.__ms_next__()
      return advanced
  def hasnext(it):
      """Ask `it` whether more elements remain via its `__ms_hasnext__` method."""
      more = it.__ms_hasnext__()
      return more
  def ms_len(data):
      """Return the length of `data` as reported by `data.__len__`."""
      size = data.__len__()
      return size
  def floor(x):
      """Return the floor of `x` by delegating to `x.__floor__`."""
      floored = x.__floor__()
      return floored
  def trunc(x):
      """Return `x` truncated by delegating to `x.__trunc__`."""
      truncated = x.__trunc__()
      return truncated
  def uadd(x):
      """Apply unary plus to `x` by delegating to `x.__pos__`."""
      positive = x.__pos__()
      return positive
  def usub(x):
      """Apply unary minus to `x` by delegating to `x.__neg__`."""
      negated = x.__neg__()
      return negated
  def scalar_truediv(x, y):
      """Divide `x` by `y` by delegating to `x.__truediv__`."""
      quotient = x.__truediv__(y)
      return quotient
  def scalar_floordiv(x, y):
      """Floor-divide `x` by `y` by delegating to `x.__floordiv__`."""
      quotient = x.__floordiv__(y)
      return quotient
  def bool_(x):
      """Return the truth value of `x` by delegating to `x.__bool__`."""
      truth = x.__bool__()
      return truth
  def enumerate_(x, start=0):
      """Pair the elements of a list or tuple with consecutive indices.

      The type of `x` is validated with `check_is_tuple_or_list` and `start`
      with `check_is_const_int`. When both checks pass, the result is
      `zip(range(start, start + len(x)), x)`; otherwise the empty tuple is
      returned.
      """
      op_name = "enumerate"
      x_type = F.typeof(x)
      ret = ()
      # Nested guards keep the original short-circuit order: the `start`
      # check only runs after the sequence-type check has succeeded.
      if check_is_tuple_or_list(x_type, op_name, "first input"):
          if check_is_const_int(start, op_name, "start"):
              ret = zip(range(start, start + len(x)), x)
      return ret
  def isinstance_(x, base_type):
      """Report whether the type of `x` passes `check_type_same` against `base_type`."""
      return check_type_same(F.typeof(x), base_type)
  def while_cond(x):
      """Prepare the condition of a `while` loop.

      When `x` is a tensor whose shape passes `check_is_tensor_bool_cond`,
      it is cast to `mstype.bool_`; any other value is returned unchanged.
      If the condition is a tensor, the loop will not be unrolled.
      """
      x_is_tensor = F.issubclass_(F.typeof(x), F.typeof(mstype.tensor))
      if x_is_tensor:
          if check_is_tensor_bool_cond(F.shape(x)):
              return F.cast(x, mstype.bool_)
      return x
  @constexpr
  def check_type_same(x_type, base_type):
      """Return the result of `mstype.issubclass_(x_type, base_type)`."""
      matches = mstype.issubclass_(x_type, base_type)
      return matches
  @constexpr
  def check_is_tuple_or_list(x, op_name, arg_name):
      """Validate that `x` is a list type or a tuple type.

      Returns True when `x` is an instance of `mstype.list_type` or
      `mstype.tuple_type`; raises TypeError naming `op_name` and `arg_name`
      otherwise.
      """
      if not isinstance(x, (mstype.list_type, mstype.tuple_type)):
          raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list, but got {x}.")
      return True
  @constexpr
  def check_is_const_int(x, op_name, arg_name):
      """Validate that `x` is a constant int.

      A `None` value is reported as "not const"; any other non-int value is
      reported with its own representation. Returns True for an int and
      raises TypeError in both failure cases.
      """
      if x is None:
          raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got not const.")
      if isinstance(x, int):
          return True
      raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got {x}.")
  @constexpr
  def check_is_tensor_bool_cond(shp):
      """Validate that shape `shp` is usable as a boolean condition.

      Only the shapes `()` and `(1,)` are accepted; any other shape raises
      ValueError.
      """
      if shp not in ((), (1,)):
          raise ValueError("The truth value of an array with several elements is ambiguous.")
      return True
  @constexpr
  def const_tensor_to_bool(x):
      """Convert a constant tensor into a Python bool.

      `x` is converted with `asnumpy()`. An array of shape `()` is converted
      directly, one of shape `(1,)` through its only element. Raises
      ValueError when `x` is None or when the array has any other shape.
      """
      if x is None:
          raise ValueError("Only constant tensor bool can be converted to bool")
      array = x.asnumpy()
      if array.shape not in ((), (1,)):
          raise ValueError("The truth value of an array with several elements is ambiguous.")
      scalar = array if array.shape == () else array[0]
      return bool(scalar)
  def tensor_bool(x):
      """Use a tensor as a condition.

      The shape of `x` is validated first. A constant `x` is converted to an
      immediate bool through `const_tensor_to_bool`; otherwise `x` is cast to
      `mstype.bool_`.
      """
      is_cond = check_is_tensor_bool_cond(F.shape(x))
      if not (is_cond and F.isconstant(x)):
          return F.cast(x, mstype.bool_)
      return const_tensor_to_bool(x)
  def and_(x, y):
      """Compute `x & y` by delegating to `x.__and__`."""
      combined = x.__and__(y)
      return combined
  def or_(x, y):
      """Compute `x | y` by delegating to `x.__or__`."""
      combined = x.__or__(y)
      return combined
  def matmul(x, y):
      """Compute `x @ y` by delegating to `x.__matmul__`."""
      product = x.__matmul__(y)
      return product
  def float_bool(x):
      """Truth value of a float: True when `x` differs from 0.0."""
      is_nonzero = x != 0.0
      return is_nonzero
  def int_bool(x):
      """Truth value of an int: True when `x` differs from 0."""
      is_nonzero = x != 0
      return is_nonzero
  def str_bool(x):
      """Truth value of a string: False for the empty string, True otherwise."""
      result = True
      if x == "":
          result = False
      return result
  def list_bool(x):
      """Truth value of a list: True when its length is not zero."""
      has_items = len(x) != 0
      return has_items
  def tuple_bool(x):
      """Truth value of a tuple: True when its length is not zero."""
      has_items = len(x) != 0
      return has_items
  def dict_bool(x):
      """Truth value of a dict: True when its length is not zero."""
      has_items = len(x) != 0
      return has_items
  def none_bool(x):
      """Truth value of None: always False, regardless of the argument."""
      result = False
      return result
  def float_floordiv(x, y):
      """Floor division for floats: divide `x` by `y`, then apply `floor`."""
      quotient = x / y
      return floor(quotient)
  173. #############
  174. # Iteration #
  175. #############
  @dataclass(frozen=True)
  class SequenceIterator:
      """
      Immutable iterator state over a sequence object.

      Holds the current position `idx` and the sequence `seq`; used for
      sequences like List and Array. Advancing produces a new iterator
      rather than mutating this one.
      """

      idx: int
      seq: list

      @core(ignore_values=True)
      def __ms_hasnext__(self):
          """Return True while `idx` is still smaller than the length of `seq`."""
          position = self.idx
          length = ms_len(self.seq)
          return position < length

      @core(ignore_values=True)
      def __ms_next__(self):
          """Return the element at `idx` and an iterator positioned at `idx + 1`."""
          current = self.seq[self.idx]
          successor = SequenceIterator(self.idx + 1, self.seq)
          return current, successor
  def list_iter(xs):
      """Create a `SequenceIterator` over the list `xs`, starting at index 0."""
      iterator = SequenceIterator(0, xs)
      return iterator
  def array_iter(xs):
      """Create a `SequenceIterator` over the array `xs`, starting at index 0."""
      iterator = SequenceIterator(0, xs)
      return iterator
  def tuple_next(xs):
      """Split the tuple `xs` into its first element and `tail(xs)`."""
      head = xs[0]
      rest = tail(xs)
      return head, rest
  def tuple_hasnext(xs):
      """Return True when the tuple `xs` still has at least one element."""
      non_empty = len(xs) > 0
      return non_empty
  def list_next(xs):
      """Split the list `xs` into its first element and `tail(xs)`."""
      head = xs[0]
      rest = tail(xs)
      return head, rest
  def list_hasnext(xs):
      """Return True when the list `xs` still has at least one element."""
      non_empty = len(xs) > 0
      return non_empty
  def list_append(self_, item):
      """Append `item` to `self_` by delegating to `_append` and return its result."""
      appended = _append(self_, item)
      return appended
  212. #################
  213. # Array methods #
  214. #################
  def to_array(x):
      """Convert `x` to an array by calling its `__ms_to_array__` method."""
      converted = x.__ms_to_array__()
      return converted
  # Register `tensor_bool` under the '__bool__' key of the tensor operator
  # registry at import time.
  tensor_operator_registry.register(
      '__bool__',
      tensor_bool,
  )