You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

standard_method.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2020 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """standard_method"""
  18. from dataclasses import dataclass
  19. from mindspore.common import dtype as mstype
  20. from ...ops import functional as F
  21. from ...ops import operations as P
  22. from ...ops.primitive import constexpr
  23. from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
  24. zeros_like, ones_like
  25. from ...ops.composite.base import _append
# Public names of this module: re-exports of objects imported from
# `...ops.composite` above. The functions defined below are not listed here.
__all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
# Module-level primitive instances, referenced by the functions below instead of
# constructing a new primitive at each use (presumably to share one instance).
trans = P.Transpose()  # used by `transpose`
shape_ = P.Shape()  # used by `expand_tensor_as`
reshape_ = P.Reshape()  # used by `view`
dtype_ = P.DType()  # not referenced in the visible part of this file
abs_ = P.Abs()  # not referenced in the visible part of this file
  32. def mean(x, axis=(), keep_dims=False):
  33. """
  34. Reduce a dimension of a tensor by averaging all elements in the dimension.
  35. Args:
  36. axis (Union[None, int, tuple(int)]): Dimensions of reduction,
  37. when axis is None or empty tuple, reduce all dimensions.
  38. Default: (), reduce all dimensions.
  39. keep_dims (bool): Whether to keep the reduced dimensions.
  40. Default : False, don't keep these reduced dimensions.
  41. Returns:
  42. Tensor, has the same data type as x.
  43. """
  44. if axis is None:
  45. axis = ()
  46. reduce_mean = P.ReduceMean(keep_dims)
  47. return reduce_mean(x, axis)
  48. def all_(x, axis=(), keep_dims=False):
  49. """
  50. Check all array elements along a given axis evaluate to True.
  51. Args:
  52. x (Tensor): A Tensor to be reduced.
  53. axis (Union[None, int, tuple(int)): Dimensions of reduction.
  54. keep_dims (bool): Whether to keep the reduced dimensions.
  55. Returns:
  56. Tensor, has the same data type as x.
  57. """
  58. if axis is None:
  59. axis = ()
  60. reduce_all = P.ReduceAll(keep_dims)
  61. return reduce_all(x, axis)
  62. def any_(x, axis=(), keep_dims=False):
  63. """
  64. Check any array element along a given axis evaluate to True.
  65. Args:
  66. x (Tensor): A Tensor to be reduced.
  67. axis (Union[None, int, tuple(int)): Dimensions of reduction.
  68. keep_dims (bool): Whether to keep the reduced dimensions.
  69. Returns:
  70. Tensor, has the same data type as x.
  71. """
  72. if axis is None:
  73. axis = ()
  74. reduce_any = P.ReduceAny(keep_dims)
  75. return reduce_any(x, axis)
def transpose(x, *axis):
    """
    Implementation of `transpose`.

    Args:
        x (Tensor): The tensor to permute.
        *axis: One of three forms:
            - nothing: the dimension order is reversed;
            - a single list/tuple: used as the new order;
            - one index per dimension of `x`: used as the new order.

    Returns:
        Tensor, the result of `trans(x, new_order)`.
    """
    new_order = None
    shape = F.shape(x)
    length = F.tuple_len(shape)
    if not axis:
        # No axes given: build the index range over all dims and reverse it.
        perm = F.make_range(0, length)
        new_order = F.tuple_reversed(perm)
    elif len(axis) == 1:
        # A single argument must be a list/tuple holding the full order.
        # NOTE(review): convert_list_to_tuple raises ValueError for a bare int,
        # so e.g. `x.transpose(0)` on a 1-D tensor lands here and is rejected —
        # confirm whether that is intended.
        new_order = convert_list_to_tuple(axis[0])
    elif len(axis) == length:
        # One index per dimension passed as separate positional arguments.
        new_order = axis
    # NOTE(review): if `axis` has neither 1 nor `length` entries, new_order is
    # still None here and is passed to Transpose as-is; consider raising a
    # clearer error.
    out = trans(x, new_order)
    return out
  90. def getitem(data, item):
  91. """Implementation of `getitem`."""
  92. return data.__getitem__(item)
  93. def setitem(data, item, value):
  94. """Implementation of `setitem`."""
  95. return data.__setitem__(item, value)
  96. def ms_iter(xs):
  97. """Implementation of `iter`."""
  98. return xs.__ms_iter__()
  99. def ms_next(it):
  100. """Implementation of `next`."""
  101. return it.__ms_next__()
  102. def hasnext(it):
  103. """Implementation of `hasnext`."""
  104. return it.__ms_hasnext__()
  105. def ms_len(data):
  106. """Implementation of `len`."""
  107. return data.__len__()
  108. def floor(x):
  109. """Implementation of `floor`."""
  110. return x.__floor__()
  111. def trunc(x):
  112. """Implementation of `trunc`."""
  113. return x.__trunc__()
  114. def uadd(x):
  115. """Implementation of `uadd`."""
  116. return x.__pos__()
  117. def usub(x):
  118. """Implementation of `usub`."""
  119. return x.__neg__()
  120. def scalar_truediv(x, y):
  121. """Implementation of `scalar_truediv`."""
  122. return x.__truediv__(y)
  123. def scalar_floordiv(x, y):
  124. """Implementation of `scalar_floordiv`."""
  125. return x.__floordiv__(y)
  126. def bool_(x):
  127. """Implementation of `bool`."""
  128. return x.__bool__()
def enumerate_(x, start=0):
    """
    Enumerate list or tuple or tensor.

    Args:
        x: A list, tuple or tensor. Any other type makes
            `check_is_tuple_or_list_or_tensor` raise TypeError.
        start (int): Constant integer added to every index; a non-constant or
            non-int value makes `check_is_const_int` raise TypeError. Default: 0.

    Returns:
        For a tensor, a tuple of ``(start + i, x[i])`` pairs for
        ``i in range(x.shape[0])``; for a list/tuple,
        ``zip(range(start, start + len(x)), x)``.
    """
    x_type = F.typeof(x)
    ret = ()
    op_name = "enumerate"
    # Both checks either return True or raise TypeError, so the body always runs
    # when no exception is raised; the `ret = ()` default is never returned.
    if check_is_tuple_or_list_or_tensor(x_type, op_name, "first input") and check_is_const_int(start, op_name, "start"):
        if check_is_tensor(x_type):
            # Tensors are enumerated along their first dimension.
            for i in range(x.shape[0]):
                ret += ((start + i, x[i]),)
        else:
            ret = zip(range(start, start + len(x)), x)
    return ret
  141. def expand_tensor_as(x, y):
  142. """Expand tensor"""
  143. broadcast_to = P.BroadcastTo(shape_(y))
  144. return broadcast_to(x)
  145. def view(x, *shape):
  146. """Reshape tensor, if shape is -1, reshape tensor into one dimension"""
  147. shape = check_view_shape(shape)
  148. return reshape_(x, shape)
  149. def isinstance_(x, base_type):
  150. """Determine whether x is an instance of base_type."""
  151. x_type = F.typeof(x)
  152. return check_type_same(x_type, base_type)
  153. def while_cond(x):
  154. """For while condtion, if the condition is a tensor, the loop will not be unrolled"""
  155. if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
  156. is_cond = check_is_tensor_bool_cond(F.shape(x))
  157. if is_cond:
  158. return F.cast(x, mstype.bool_)
  159. return x
  160. @constexpr
  161. def check_type_same(x_type, base_type):
  162. """Check x_type is same as base_type."""
  163. if mstype.issubclass_(x_type, base_type):
  164. return True
  165. return False
  166. @constexpr
  167. def check_is_tensor(x):
  168. """check whether x is tensor."""
  169. if isinstance(x, mstype.tensor_type):
  170. return True
  171. return False
  172. @constexpr
  173. def check_is_tuple_or_list_or_tensor(x, op_name, arg_name):
  174. """check whether x is list or tuple or tensor."""
  175. if isinstance(x, (mstype.list_type, mstype.tuple_type, mstype.tensor_type)):
  176. return True
  177. raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.")
  178. @constexpr
  179. def check_is_const_int(x, op_name, arg_name):
  180. """check whether x is const int."""
  181. if x is None:
  182. raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got not const.")
  183. if not isinstance(x, int):
  184. raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got {x}.")
  185. return True
  186. @constexpr
  187. def check_is_tensor_bool_cond(shp):
  188. """check if tensor is a bool condition"""
  189. if shp in ((), (1,)):
  190. return True
  191. raise ValueError("The truth value of an array with several elements is ambiguous.")
  192. @constexpr
  193. def const_tensor_to_bool(x):
  194. """convert bool tensor to bool condition"""
  195. if x is None:
  196. raise ValueError("Only constant tensor bool can be converted to bool")
  197. x = x.asnumpy()
  198. if x.shape == ():
  199. return bool(x)
  200. if x.shape == (1,):
  201. return bool(x[0])
  202. raise ValueError("The truth value of an array with several elements is ambiguous.")
  203. @constexpr
  204. def check_view_shape(x):
  205. """Check view function input shape"""
  206. if not x:
  207. raise ValueError("The shape variable should not be empty")
  208. if isinstance(x[0], tuple):
  209. if len(x) != 1:
  210. raise ValueError(f"Only one tuple is needed, but got {x}")
  211. x = x[0]
  212. return x
  213. @constexpr
  214. def convert_list_to_tuple(shp):
  215. """Check the type of the shape, if is list, convert to tuple"""
  216. if not isinstance(shp, (list, tuple)):
  217. raise ValueError(f"The shape variable should be a list or tuple, but got {type(shp)}")
  218. if isinstance(shp, list):
  219. shp = tuple(shp)
  220. return shp
  221. def tensor_bool(x):
  222. """tensor as conditon, if is constant, return immediate bool value"""
  223. is_cond = check_is_tensor_bool_cond(F.shape(x))
  224. if is_cond and F.isconstant(x):
  225. return const_tensor_to_bool(x)
  226. return F.cast(x, mstype.bool_)
  227. def and_(x, y):
  228. """Implementation of `and` (`&`)."""
  229. return x.__and__(y)
  230. def or_(x, y):
  231. """Implementation of `or` (`|`)."""
  232. return x.__or__(y)
  233. def matmul(x, y):
  234. """Implementation of `matmul` (`@`)."""
  235. return x.__matmul__(y)
  236. def float_bool(x):
  237. """Implementation of `float_bool`."""
  238. return x != 0.0
  239. def int_bool(x):
  240. """Implementation of `int_bool`."""
  241. return x != 0
  242. def str_bool(x):
  243. """Implementation of `str_bool`."""
  244. if x == "":
  245. return False
  246. return True
  247. def list_bool(x):
  248. """Implementation of `tuple_bool`."""
  249. return len(x) != 0
  250. def tuple_bool(x):
  251. """Implementation of `tuple_bool`."""
  252. return len(x) != 0
  253. def dict_bool(x):
  254. """Implementation of `dict_bool`."""
  255. return len(x) != 0
  256. def none_bool(x):
  257. """Implementation of `none_bool`."""
  258. return False
  259. def func_bool(x):
  260. """Implementation of `func_bool`."""
  261. return True
  262. def float_floordiv(x, y):
  263. """Implementation of `float_floordiv`."""
  264. return floor(x / y)
  265. #############
  266. # Iteration #
  267. #############
@dataclass(frozen=True)
class SequenceIterator:
    """
    SequenceIterator is a util dataclass for iterating sequence object.
    Iterator to use for sequences like List, Array.

    Instances are frozen: `__ms_next__` returns a new iterator with the index
    advanced by one instead of mutating this one.
    """
    # Index of the next element to return.
    idx: int
    # The sequence being iterated; it is indexed and measured with `ms_len`.
    seq: list

    @core(ignore_values=True)
    def __ms_hasnext__(self):
        """Return whether elements remain, i.e. `idx` is still below the sequence length."""
        return self.idx < ms_len(self.seq)

    @core(ignore_values=True)
    def __ms_next__(self):
        """Return the next element and a new iterator."""
        return self.seq[self.idx], SequenceIterator(self.idx + 1, self.seq)
  284. def list_iter(xs):
  285. """Iterator for List."""
  286. return SequenceIterator(0, xs)
  287. def array_iter(xs):
  288. """Iterator for Array."""
  289. return SequenceIterator(0, xs)
  290. def tuple_next(xs):
  291. """Next tuple."""
  292. return xs[0], tail(xs)
  293. def tuple_hasnext(xs):
  294. """Whether the tuple is empty or not."""
  295. return len(xs) > 0
  296. def list_next(xs):
  297. """Next list."""
  298. return xs[0], tail(xs)
  299. def list_hasnext(xs):
  300. """Whether the list is empty or not."""
  301. return len(xs) > 0
  302. def list_append(self_, item):
  303. return _append(self_, item)
  304. #################
  305. # Array methods #
  306. #################
  307. def to_array(x):
  308. """Implementation of `to_array`."""
  309. return x.__ms_to_array__()