# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""standard_method

Python-level implementations of builtin operations (`len`, `iter`, `next`,
`bool`, `isinstance`, `enumerate`, ...) and Tensor methods (`mean`, `astype`,
`reshape`, `transpose`, ...).
"""
from dataclasses import dataclass

from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from ..._checkparam import Validator as validator
from ...ops import functional as F
from ...ops import operations as P
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
    zeros_like, ones_like
from ...ops.composite.base import _append
from ...ops.primitive import constexpr

__all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']

shape_ = P.Shape()
dtype_ = P.DType()
abs_ = P.Abs()
ndim_ = P.Rank()
size_ = P.Size()

# Number of bytes occupied by a single element of each supported dtype;
# looked up by `get_itemsize` (and therefore by `itemsize_`, `nbytes_`, `strides_`).
itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1,
                mstype.float16: 2, mstype.int16: 2, mstype.uint16: 2,
                mstype.float32: 4, mstype.int32: 4, mstype.uint32: 4,
                mstype.float64: 8, mstype.int64: 8, mstype.uint64: 8}


def mean(x, axis=(), keep_dims=False):
    """
    Reduces a dimension of a tensor by averaging all elements in the dimension.

    Args:
        x (Tensor): The tensor to reduce.
        axis (Union[None, int, tuple(int)]): Dimensions of reduction,
            when axis is None or empty tuple, reduce all dimensions.
            Default: (), reduce all dimensions.
        keep_dims (bool): Whether to keep the reduced dimensions.
            Default: False, don't keep these reduced dimensions.

    Returns:
        Tensor, has the same data type as x.
    """
    # `None` is accepted as an alias of `()` (reduce over every dimension).
    if axis is None:
        axis = ()
    reduce_mean = P.ReduceMean(keep_dims)
    return reduce_mean(x, axis)


def all_(x, axis=(), keep_dims=False):
    """
    Check all array elements along a given axis evaluate to True.

    Args:
        x (Tensor): A Tensor to be reduced.
        axis (Union[None, int, tuple(int)]): Dimensions of reduction.
            None or an empty tuple reduces all dimensions. Default: ().
        keep_dims (bool): Whether to keep the reduced dimensions. Default: False.

    Returns:
        Tensor, has the same data type as x.
    """
    if axis is None:
        axis = ()
    reduce_all = P.ReduceAll(keep_dims)
    return reduce_all(x, axis)


def any_(x, axis=(), keep_dims=False):
    """
    Check any array element along a given axis evaluate to True.

    Args:
        x (Tensor): A Tensor to be reduced.
        axis (Union[None, int, tuple(int)]): Dimensions of reduction.
            None or an empty tuple reduces all dimensions. Default: ().
        keep_dims (bool): Whether to keep the reduced dimensions. Default: False.

    Returns:
        Tensor, has the same data type as x.
    """
    if axis is None:
        axis = ()
    reduce_any = P.ReduceAny(keep_dims)
    return reduce_any(x, axis)


def itemsize_(x):
    """
    Return length of one tensor element in bytes.

    Args:
        x (Tensor): Input tensor.

    Returns:
        itemsize(int), looked up from `itemsize_map` by the tensor's dtype.
    """
    return get_itemsize(x.dtype)


def nbytes_(x):
    """
    Return total number of bytes taken by the tensor.

    Args:
        x (Tensor): Input tensor.

    Returns:
        nbytes(int): itemsize multiplied by the product of all dimensions.
    """
    return itemsize_(x) * F.shape_mul(shape_(x))


def strides_(x):
    """
    Return the tuple of bytes to step in each dimension when traversing a tensor.

    The strides are computed for a contiguous row-major layout: the stride of
    dimension i is the itemsize times the product of all dimensions after i.

    Args:
        x (Tensor): Input tensor.

    Returns:
        strides (tuple[int]).
    """
    strides = ()
    ndim = P.Rank()(x)
    tensor_shape = shape_(x)
    for i in F.make_range(0, ndim):
        stride = itemsize_(x)
        for j in F.make_range(i + 1, ndim):
            stride *= tensor_shape[j]
        strides += (stride,)
    return strides


def astype(x, dtype, copy=True):
    """
    Implementation of `astype`.

    Casts `x` to `dtype` (validated by `check_astype_dtype_const`). When
    `copy` is False and `x` already has that dtype, `x` itself is returned.
    """
    dtype = check_astype_dtype_const(dtype)
    if not copy and dtype == x.dtype:
        return x
    return F.cast(x, dtype)


def transpose(x, *axis):
    """Implementation of `transpose`; `axis` is validated against the rank of x."""
    ndim = F.rank(x)
    perm = check_transpose_axis_const(axis, ndim)
    return F.transpose(x, perm)


# `tensor.T` is used as a property in graph mode
T_ = transpose


def reshape(x, *shape):
    """Implementation of `reshape`; `shape` is validated by `check_reshape_shp_const`."""
    new_shape = check_reshape_shp_const(shape)
    return F.reshape(x, new_shape)


def ravel(x):
    """Implementation of `ravel`: reshape x into one dimension."""
    return reshape(x, (-1,))


def flatten(x, order='C'):
    """
    Returns a copy of the array collapsed into one dimension.

    Args:
        order (str, optional): Can choose between `C` and `F`. `C` means to
            flatten in row-major (C-style) order. `F` means to flatten in column-major
            (Fortran-style) order. Only `C` and `F` are supported.

    Returns:
        Tensor, has the same data type as x.
    """
    order = check_flatten_order_const(order)
    if order == 'C':
        return F.reshape(x, (-1,))

    # Column-major flattening == row-major flattening of the fully transposed tensor.
    perm = F.make_range(0, F.rank(x))
    new_order = F.tuple_reversed(perm)
    return F.reshape(F.transpose(x, new_order), (-1,))


def swapaxes(x, axis1, axis2):
    """
    Interchanges two axes of a tensor.

    Args:
        axis1 (int): First axis.
        axis2 (int): Second axis.

    Returns:
        Transposed tensor, has the same data type as the original tensor x.
    """
    axis1, axis2 = check_swapaxes_axis_const((axis1, axis2), x.ndim)

    if axis1 == axis2:
        return x
    if axis1 > axis2:
        axis1, axis2 = axis2, axis1

    # Build the permutation by concatenating slices of (0, ..., ndim - 1):
    # [before axis1] + [axis2] + [between] + [axis1] + [after axis2].
    perm = F.make_range(0, x.ndim)
    new_perm = None
    if axis2 + 1 < x.ndim:
        new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
            perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:]
    else:
        # axis2 is the last axis: there is no trailing segment to append.
        new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
            perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1]

    return F.transpose(x, new_perm)


def squeeze(x, axis=None):
    """
    Removes single-dimensional entries from the shape of a tensor.

    Args:
        axis: Union[None, int, list(int), tuple(int)]. Default is None,
            which removes every dimension of length 1.

    Returns:
        Tensor, with all or a subset of the dimensions of length 1 removed.
    """
    shape = F.shape(x)
    if axis is None:
        return F.squeeze(x)
    # yield squeezed shape based on the axes
    new_shape = prepare_shape_for_squeeze_const(shape, axis)
    return F.reshape(x, new_shape)


def getitem(data, item):
    """Implementation of `getitem`."""
    return data.__getitem__(item)


def setitem(data, item, value):
    """Implementation of `setitem`."""
    return data.__setitem__(item, value)


def ms_iter(xs):
    """Implementation of `iter`."""
    return xs.__ms_iter__()


def ms_next(it):
    """Implementation of `next`."""
    return it.__ms_next__()


def hasnext(it):
    """Implementation of `hasnext`."""
    return it.__ms_hasnext__()


def ms_len(data):
    """Implementation of `len`."""
    return data.__len__()


def floor(x):
    """Implementation of `floor`."""
    return x.__floor__()


def trunc(x):
    """Implementation of `trunc`."""
    return x.__trunc__()


def uadd(x):
    """Implementation of `uadd`."""
    return x.__pos__()


def usub(x):
    """Implementation of `usub`."""
    return x.__neg__()


def scalar_truediv(x, y):
    """Implementation of `scalar_truediv`."""
    return x.__truediv__(y)


def scalar_floordiv(x, y):
    """Implementation of `scalar_floordiv`."""
    return x.__floordiv__(y)


def bool_(x):
    """Implementation of `bool`."""
    return x.__bool__()


def enumerate_(x, start=0):
    """
    Enumerate list or tuple or tensor.

    For a tensor, returns a tuple of `(start + i, x[i])` pairs over its first
    dimension; for a list or tuple, returns `zip(range(start, start + len(x)), x)`.
    A TypeError is raised (by the check helpers) if x is not a list, tuple or
    tensor, or if `start` is not a constant int.
    """
    x_type = F.typeof(x)
    ret = ()
    op_name = "enumerate"
    if (check_is_tuple_or_list_or_tensor(x_type, op_name, "first input") and
            check_is_const_int(start, op_name, "start")):
        if check_is_tensor(x_type):
            for i in range(x.shape[0]):
                ret += ((start + i, x[i]),)
        else:
            ret = zip(range(start, start + len(x)), x)
    return ret


def expand_tensor_as(x, y):
    """Expand tensor: broadcast x to the shape of y."""
    broadcast_to = P.BroadcastTo(shape_(y))
    return broadcast_to(x)


def view(x, *shape):
    """Reshape tensor, if shape is -1, reshape tensor into one dimension"""
    shape = check_view_shape(shape)
    return F.reshape(x, shape)


def isinstance_(x, base_type):
    """Determine whether x is an instance of base_type."""
    x_type = F.typeof(x)
    return check_type_same(x_type, base_type)


def while_cond(x):
    """
    For while condition, if the condition is a tensor, the loop will not be unrolled.

    A tensor condition must have shape () or (1,) and is cast to bool;
    any other condition is returned unchanged.
    """
    if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
        is_cond = check_is_tensor_bool_cond(F.shape(x))
        if is_cond:
            return F.cast(x, mstype.bool_)
    return x


@constexpr
def check_type_same(x_type, base_type):
    """
    Check x_type is same as base_type.

    Args:
        x_type: The MindSpore type of the object being tested.
        base_type: A Python type, or an (optionally nested) tuple of Python types,
            among bool, int, float, str, list, tuple, dict, Tensor and Parameter.

    Returns:
        bool. A Bool type also matches `int`, and a Parameter (ref) type also
        matches `Tensor`, mirroring Python's subclass relationships.

    Raises:
        TypeError: If base_type is not a type / tuple of types, or contains an
            unsupported type.
    """
    pytype_to_mstype = {
        bool: mstype.Bool,
        int: mstype.Int,
        float: mstype.Float,
        str: mstype.String,
        list: mstype.List,
        tuple: mstype.Tuple,
        dict: mstype.Dict,
        Tensor: mstype.tensor_type,
        Parameter: mstype.ref_type
    }

    has_int = False
    has_tensor = False

    def to_target_type(origin_type):
        """Convert a Python type (or nested tuple of them) into MindSpore types."""
        nonlocal has_int, has_tensor
        if isinstance(origin_type, type):
            try:
                ret_type = pytype_to_mstype[origin_type]
            except KeyError:
                # `from None` hides the internal KeyError from the user-facing traceback.
                raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, "
                                f"Tensor, Parameter, or a tuple containing only these types, "
                                f"but got {origin_type}") from None
            if ret_type == mstype.Int:
                has_int = True
            if ret_type == mstype.tensor_type:
                has_tensor = True
            return (ret_type,)
        if isinstance(origin_type, tuple):
            return tuple(to_target_type(i) for i in origin_type)
        raise TypeError(f"The second arg of 'isinstance' must be a type or a tuple of types, "
                        f"but got a {type(origin_type).__name__}")

    target_type = to_target_type(base_type)
    if (isinstance(x_type, mstype.Bool) and has_int) or (isinstance(x_type, mstype.ref_type) and has_tensor):
        return True
    return isinstance(x_type, target_type)


@constexpr
def get_itemsize(x_type):
    """get itemsize from tensor's dtype (raises KeyError for dtypes not in `itemsize_map`)."""
    return itemsize_map[x_type]


@constexpr
def check_is_tensor(x):
    """check whether x is tensor."""
    if isinstance(x, mstype.tensor_type):
        return True
    return False


@constexpr
def check_is_tuple_or_list_or_tensor(x, op_name, arg_name):
    """check whether x is list or tuple or tensor; returns True or raises TypeError."""
    if isinstance(x, (mstype.List, mstype.Tuple, mstype.tensor_type)):
        return True
    raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.")


@constexpr
def check_is_const_int(x, op_name, arg_name):
    """check whether x is const int; returns True or raises TypeError (None means not constant)."""
    if x is None:
        raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got not const.")
    if not isinstance(x, int):
        raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got {x}.")
    return True


@constexpr
def check_is_tensor_bool_cond(shp):
    """check if tensor is a bool condition: shape must be () or (1,), otherwise ValueError."""
    if shp in ((), (1,)):
        return True
    raise ValueError("The truth value of an array with several elements is ambiguous.")


@constexpr
def const_tensor_to_bool(x):
    """
    convert bool tensor to bool condition

    Raises ValueError if x is not constant (None) or has more than one element.
    """
    if x is None:
        raise ValueError("Only constant tensor bool can be converted to bool")
    x = x.asnumpy()
    if x.shape == ():
        return bool(x)
    if x.shape == (1,):
        return bool(x[0])
    raise ValueError("The truth value of an array with several elements is ambiguous.")


@constexpr
def check_view_shape(x):
    """
    Check view function input shape

    Accepts either the shape given as separate arguments, e.g. `(2, 3)`, or a
    single tuple argument, e.g. `((2, 3),)`, and returns the shape tuple.
    """
    if not x:
        raise ValueError("The shape variable should not be empty")
    if isinstance(x[0], tuple):
        if len(x) != 1:
            raise ValueError(f"Only one tuple is needed, but got {x}")
        x = x[0]
    return x


# convert normal param_check functions to constexpr functions
check_astype_dtype_const = constexpr(validator.check_astype_dtype)
check_transpose_axis_const = constexpr(validator.check_transpose_axis)
check_reshape_shp_const = constexpr(validator.check_reshape_shp)
check_flatten_order_const = constexpr(validator.check_flatten_order)
check_swapaxes_axis_const = constexpr(validator.check_swapaxes_axis)
prepare_shape_for_squeeze_const = constexpr(validator.prepare_shape_for_squeeze)


def tensor_bool(x):
    """tensor as condition, if is constant, return immediate bool value"""
    is_cond = check_is_tensor_bool_cond(F.shape(x))
    if is_cond and F.isconstant(x):
        return const_tensor_to_bool(x)
    return F.cast(x, mstype.bool_)


def and_(x, y):
    """Implementation of `and` (`&`)."""
    return x.__and__(y)


def or_(x, y):
    """Implementation of `or` (`|`)."""
    return x.__or__(y)


def matmul(x, y):
    """Implementation of `matmul` (`@`)."""
    return x.__matmul__(y)


def float_bool(x):
    """Implementation of `float_bool`."""
    return x != 0.0


def int_bool(x):
    """Implementation of `int_bool`."""
    return x != 0


def str_bool(x):
    """Implementation of `str_bool`: only the empty string is False."""
    if x == "":
        return False
    return True


def list_bool(x):
    """Implementation of `list_bool`."""
    return len(x) != 0


def tuple_bool(x):
    """Implementation of `tuple_bool`."""
    return len(x) != 0


def dict_bool(x):
    """Implementation of `dict_bool`."""
    return len(x) != 0


def none_bool(x):
    """Implementation of `none_bool`."""
    return False


def func_bool(x):
    """Implementation of `func_bool`."""
    return True


def float_floordiv(x, y):
    """Implementation of `float_floordiv`."""
    return floor(x / y)


#############
# Iteration #
#############


@dataclass(frozen=True)
class SequenceIterator:
    """
    SequenceIterator is a util dataclass for iterating sequence object.

    Iterator to use for sequences like List, Array. It is immutable:
    advancing returns a new iterator instead of mutating this one.
    """

    idx: int  # index of the next element to yield
    seq: list  # the sequence being iterated

    @core(ignore_values=True)
    def __ms_hasnext__(self):
        """Whether the index is still inside the sequence, i.e. another element remains."""
        return self.idx < ms_len(self.seq)

    @core(ignore_values=True)
    def __ms_next__(self):
        """Return the next element and a new iterator."""
        return self.seq[self.idx], SequenceIterator(self.idx + 1, self.seq)


def list_iter(xs):
    """Iterator for List."""
    return SequenceIterator(0, xs)


def array_iter(xs):
    """Iterator for Array."""
    return SequenceIterator(0, xs)


def tuple_next(xs):
    """Next tuple: return the first element and the remaining elements."""
    return xs[0], tail(xs)


def tuple_hasnext(xs):
    """Whether the tuple is non-empty (has a next element)."""
    return len(xs) > 0


def list_next(xs):
    """Next list: return the first element and the remaining elements."""
    return xs[0], tail(xs)


def list_hasnext(xs):
    """Whether the list is non-empty (has a next element)."""
    return len(xs) > 0


def list_append(self_, item):
    """Implementation of `list.append`; delegates to the composite `_append` operation."""
    return _append(self_, item)


#################
# Array methods #
#################


def to_array(x):
    """Implementation of `to_array`."""
    return x.__ms_to_array__()