You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

standard_method.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2020 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """standard_method"""
  18. from dataclasses import dataclass
  19. from ...ops import functional as F
  20. from ...ops import operations as P
  21. from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
  22. zeros_like, ones_like
  23. from ...ops.composite.base import _append
  24. __all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
  25. trans = P.Transpose()
  26. def transpose(x):
  27. """Implementation of `transpose`."""
  28. shape = F.shape(x)
  29. length = F.tuple_len(shape)
  30. perm = F.make_range(0, length)
  31. revert_perm = F.tuple_reversed(perm)
  32. out = trans(x, revert_perm)
  33. return out
  34. def getitem(data, item):
  35. """Implementation of `getitem`."""
  36. return data.__getitem__(item)
  37. def setitem(data, item, value):
  38. """Implementation of `setitem`."""
  39. return data.__setitem__(item, value)
  40. def ms_iter(xs):
  41. """Implementation of `iter`."""
  42. return xs.__ms_iter__()
  43. def ms_next(it):
  44. """Implementation of `next`."""
  45. return it.__ms_next__()
  46. def hasnext(it):
  47. """Implementation of `hasnext`."""
  48. return it.__ms_hasnext__()
  49. def ms_len(data):
  50. """Implementation of `len`."""
  51. return data.__len__()
  52. def floor(x):
  53. """Implementation of `floor`."""
  54. return x.__floor__()
  55. def trunc(x):
  56. """Implementation of `trunc`."""
  57. return x.__trunc__()
  58. def uadd(x):
  59. """Implementation of `uadd`."""
  60. return x.__pos__()
  61. def usub(x):
  62. """Implementation of `usub`."""
  63. return x.__neg__()
  64. def scalar_truediv(x, y):
  65. """Implementation of `scalar_truediv`."""
  66. return x.__truediv__(y)
  67. def scalar_floordiv(x, y):
  68. """Implementation of `scalar_floordiv`."""
  69. return x.__floordiv__(y)
  70. def bool_(x):
  71. """Implementation of `bool`."""
  72. return x.__bool__()
  73. def tensor_bool(x):
  74. """return immedate x, x is a tensor of bool value"""
  75. return x
  76. def and_(x, y):
  77. """Implementation of `and` (`&`)."""
  78. return x.__and__(y)
  79. def or_(x, y):
  80. """Implementation of `or` (`|`)."""
  81. return x.__or__(y)
  82. def matmul(x, y):
  83. """Implementation of `matmul` (`@`)."""
  84. return x.__matmul__(y)
  85. def float_bool(x):
  86. """Implementation of `float_bool`."""
  87. return x != 0.0
  88. def int_bool(x):
  89. """Implementation of `int_bool`."""
  90. return x != 0
  91. def str_bool(x):
  92. """Implementation of `str_bool`."""
  93. if x == "":
  94. return False
  95. return True
  96. def list_bool(x):
  97. """Implementation of `tuple_bool`."""
  98. return len(x) != 0
  99. def tuple_bool(x):
  100. """Implementation of `tuple_bool`."""
  101. return len(x) != 0
  102. def dict_bool(x):
  103. """Implementation of `dict_bool`."""
  104. return len(x) != 0
  105. def none_bool(x):
  106. """Implementation of `none_bool`."""
  107. return False
  108. def float_floordiv(x, y):
  109. """Implementation of `float_floordiv`."""
  110. return floor(x / y)
  111. #############
  112. # Iteration #
  113. #############
  114. @dataclass(frozen=True)
  115. class SequenceIterator:
  116. """
  117. SequenceIterator is a util dataclass for iterating sequence object.
  118. Iterator to use for sequences like List, Array.
  119. """
  120. idx: int
  121. seq: list
  122. @core(ignore_values=True)
  123. def __ms_hasnext__(self):
  124. """Whether the index is past the length of the sequence."""
  125. return self.idx < ms_len(self.seq)
  126. @core(ignore_values=True)
  127. def __ms_next__(self):
  128. """Return the next element and a new iterator."""
  129. return self.seq[self.idx], SequenceIterator(self.idx + 1, self.seq)
  130. def list_iter(xs):
  131. """Iterator for List."""
  132. return SequenceIterator(0, xs)
  133. def array_iter(xs):
  134. """Iterator for Array."""
  135. return SequenceIterator(0, xs)
  136. def tuple_next(xs):
  137. """Next tuple."""
  138. return xs[0], tail(xs)
  139. def tuple_hasnext(xs):
  140. """Whether the tuple is empty or not."""
  141. return len(xs) > 0
  142. def list_next(xs):
  143. """Next list."""
  144. return xs[0], tail(xs)
  145. def list_hasnext(xs):
  146. """Whether the list is empty or not."""
  147. return len(xs) > 0
  148. def list_append(self_, item):
  149. return _append(self_, item)
  150. #################
  151. # Array methods #
  152. #################
  153. def to_array(x):
  154. """Implementation of `to_array`."""
  155. return x.__ms_to_array__()