You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

standard_method.py 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2020 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """standard_method"""
  18. from dataclasses import dataclass
  19. from mindspore.common import dtype as mstype
  20. from ...ops import functional as F
  21. from ...ops import operations as P
  22. from ...ops.primitive import constexpr
  23. from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
  24. zeros_like, ones_like
  25. from ...ops.composite.base import _append
  26. __all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
  27. trans = P.Transpose()
  28. def transpose(x):
  29. """Implementation of `transpose`."""
  30. shape = F.shape(x)
  31. length = F.tuple_len(shape)
  32. perm = F.make_range(0, length)
  33. revert_perm = F.tuple_reversed(perm)
  34. out = trans(x, revert_perm)
  35. return out
  36. def getitem(data, item):
  37. """Implementation of `getitem`."""
  38. return data.__getitem__(item)
  39. def setitem(data, item, value):
  40. """Implementation of `setitem`."""
  41. return data.__setitem__(item, value)
  42. def ms_iter(xs):
  43. """Implementation of `iter`."""
  44. return xs.__ms_iter__()
  45. def ms_next(it):
  46. """Implementation of `next`."""
  47. return it.__ms_next__()
  48. def hasnext(it):
  49. """Implementation of `hasnext`."""
  50. return it.__ms_hasnext__()
  51. def ms_len(data):
  52. """Implementation of `len`."""
  53. return data.__len__()
  54. def floor(x):
  55. """Implementation of `floor`."""
  56. return x.__floor__()
  57. def trunc(x):
  58. """Implementation of `trunc`."""
  59. return x.__trunc__()
  60. def uadd(x):
  61. """Implementation of `uadd`."""
  62. return x.__pos__()
  63. def usub(x):
  64. """Implementation of `usub`."""
  65. return x.__neg__()
  66. def scalar_truediv(x, y):
  67. """Implementation of `scalar_truediv`."""
  68. return x.__truediv__(y)
  69. def scalar_floordiv(x, y):
  70. """Implementation of `scalar_floordiv`."""
  71. return x.__floordiv__(y)
  72. def bool_(x):
  73. """Implementation of `bool`."""
  74. return x.__bool__()
  75. def while_cond(x):
  76. """For while condtion, if the condition is a tensor, the loop will not be unrolled"""
  77. if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
  78. is_cond = check_is_tensor_bool_cond(F.shape(x))
  79. if is_cond:
  80. return F.cast(x, mstype.bool_)
  81. return x
  82. @constexpr
  83. def check_is_tensor_bool_cond(shp):
  84. """check if tensor is a bool condition"""
  85. if shp in ((), (1,)):
  86. return True
  87. raise ValueError("tensor as bool condition, its shape should be () or (1,), but got ", shp)
  88. @constexpr
  89. def const_tensor_to_bool(x):
  90. """convert bool tensor to bool condition"""
  91. if x is None:
  92. raise ValueError("Only constant tensor bool can be converted to bool")
  93. x = x.asnumpy()
  94. if x.shape not in ((), (1,)):
  95. raise ValueError("Tensor to bool should input shape () or (1), but got ", x.shape)
  96. if x.shape == ():
  97. value = bool(x)
  98. else:
  99. value = bool(x[0])
  100. return value
  101. def tensor_bool(x):
  102. """tensor as conditon, if is constant, return immediate bool value"""
  103. is_cond = check_is_tensor_bool_cond(F.shape(x))
  104. if is_cond and F.isconstant(x):
  105. return const_tensor_to_bool(x)
  106. return F.cast(x, mstype.bool_)
  107. def and_(x, y):
  108. """Implementation of `and` (`&`)."""
  109. return x.__and__(y)
  110. def or_(x, y):
  111. """Implementation of `or` (`|`)."""
  112. return x.__or__(y)
  113. def matmul(x, y):
  114. """Implementation of `matmul` (`@`)."""
  115. return x.__matmul__(y)
  116. def float_bool(x):
  117. """Implementation of `float_bool`."""
  118. return x != 0.0
  119. def int_bool(x):
  120. """Implementation of `int_bool`."""
  121. return x != 0
  122. def str_bool(x):
  123. """Implementation of `str_bool`."""
  124. if x == "":
  125. return False
  126. return True
  127. def list_bool(x):
  128. """Implementation of `tuple_bool`."""
  129. return len(x) != 0
  130. def tuple_bool(x):
  131. """Implementation of `tuple_bool`."""
  132. return len(x) != 0
  133. def dict_bool(x):
  134. """Implementation of `dict_bool`."""
  135. return len(x) != 0
  136. def none_bool(x):
  137. """Implementation of `none_bool`."""
  138. return False
  139. def float_floordiv(x, y):
  140. """Implementation of `float_floordiv`."""
  141. return floor(x / y)
  142. #############
  143. # Iteration #
  144. #############
  145. @dataclass(frozen=True)
  146. class SequenceIterator:
  147. """
  148. SequenceIterator is a util dataclass for iterating sequence object.
  149. Iterator to use for sequences like List, Array.
  150. """
  151. idx: int
  152. seq: list
  153. @core(ignore_values=True)
  154. def __ms_hasnext__(self):
  155. """Whether the index is past the length of the sequence."""
  156. return self.idx < ms_len(self.seq)
  157. @core(ignore_values=True)
  158. def __ms_next__(self):
  159. """Return the next element and a new iterator."""
  160. return self.seq[self.idx], SequenceIterator(self.idx + 1, self.seq)
  161. def list_iter(xs):
  162. """Iterator for List."""
  163. return SequenceIterator(0, xs)
  164. def array_iter(xs):
  165. """Iterator for Array."""
  166. return SequenceIterator(0, xs)
  167. def tuple_next(xs):
  168. """Next tuple."""
  169. return xs[0], tail(xs)
  170. def tuple_hasnext(xs):
  171. """Whether the tuple is empty or not."""
  172. return len(xs) > 0
  173. def list_next(xs):
  174. """Next list."""
  175. return xs[0], tail(xs)
  176. def list_hasnext(xs):
  177. """Whether the list is empty or not."""
  178. return len(xs) > 0
  179. def list_append(self_, item):
  180. return _append(self_, item)
  181. #################
  182. # Array methods #
  183. #################
  184. def to_array(x):
  185. """Implementation of `to_array`."""
  186. return x.__ms_to_array__()