| @@ -30,13 +30,14 @@ from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, res | |||
| ravel, concatenate, where, atleast_1d, atleast_2d, atleast_3d, | |||
| column_stack, hstack, dstack, vstack, stack, unique, moveaxis, | |||
| tile, broadcast_to, broadcast_arrays, roll, append, split, vsplit, | |||
| flip, flipud, fliplr, hsplit, dsplit, take_along_axis, take, repeat) | |||
| flip, flipud, fliplr, hsplit, dsplit, take_along_axis, take, repeat, | |||
| rot90, select, array_split) | |||
| from .array_creations import copy_ as copy | |||
| from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange, | |||
| linspace, logspace, eye, identity, empty, empty_like, | |||
| ones_like, zeros_like, full_like, diagonal, tril, triu, | |||
| tri, trace, meshgrid, mgrid, ogrid, diagflat, | |||
| diag, diag_indices, ix_) | |||
| diag, diag_indices, ix_, indices, geomspace, vander) | |||
| from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16, | |||
| uint32, uint64, float_, float16, float32, float64, bool_, inf, nan, | |||
| numeric_types, PINF, NINF) | |||
| @@ -45,35 +46,51 @@ from .math_ops import (mean, inner, add, subtract, multiply, divide, true_divide | |||
| matmul, square, sqrt, reciprocal, log, maximum, heaviside, amax, amin, | |||
| hypot, float_power, floor, ptp, deg2rad, rad2deg, count_nonzero, | |||
| positive, negative, clip, floor_divide, remainder, fix, fmod, trunc, | |||
| exp, expm1, cumsum) | |||
| exp, expm1, exp2, kron, promote_types, divmod_, diff, cbrt, | |||
| cross, ceil, trapz, gcd, lcm, convolve, log1p, logaddexp, log2, | |||
| logaddexp2, log10, ediff1d, nansum, nanmean, nanvar, nanstd, cumsum, nancumsum, | |||
| sin, cos, tan, arcsin, arccos, arctan, sinh, cosh, tanh, arcsinh, arccosh, | |||
| arctanh, arctan2, cov) | |||
| from .logic_ops import (not_equal, less_equal, less, greater_equal, greater, equal, isfinite, | |||
| isnan, isinf, isposinf, isneginf, isscalar) | |||
| isnan, isinf, isposinf, isneginf, isscalar, logical_and, logical_not, | |||
| logical_or, logical_xor, in1d, isin, isclose) | |||
| mod = remainder | |||
| fabs = absolute | |||
| divmod = divmod_ # pylint: disable=redefined-builtin | |||
| abs = absolute # pylint: disable=redefined-builtin | |||
| max = amax # pylint: disable=redefined-builtin | |||
| min = amin # pylint: disable=redefined-builtin | |||
| array_ops_module = ['transpose', 'expand_dims', 'squeeze', 'rollaxis', 'swapaxes', 'reshape', | |||
| 'ravel', 'concatenate', 'where', 'atleast_1d', 'atleast_2d', 'atleast_3d', | |||
| 'column_stack', 'hstack', 'dstack', 'vstack', 'stack', 'unique', 'moveaxis', | |||
| 'tile', 'broadcast_to', 'broadcast_arrays', 'append', 'roll', 'split', 'vsplit', | |||
| 'flip', 'flipud', 'fliplr', 'hsplit', 'dsplit', 'take_along_axis', 'take', | |||
| 'repeat'] | |||
| 'repeat', 'rot90', 'select', 'array_split'] | |||
| array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'arange', | |||
| 'linspace', 'logspace', 'eye', 'identity', 'empty', 'empty_like', | |||
| 'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril', 'triu', | |||
| 'tri', 'trace', 'meshgrid', 'mgrid', 'ogrid', 'diagflat', 'diag', | |||
| 'diag_indices', 'ix_', 'cumsum'] | |||
| 'diag_indices', 'ix_', 'indices', 'geomspace', 'vander'] | |||
| math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'true_divide', 'power', | |||
| 'dot', 'outer', 'tensordot', 'absolute', 'std', 'var', 'average', 'not_equal', | |||
| 'minimum', 'matmul', 'square', 'sqrt', 'reciprocal', 'log', 'maximum', | |||
| 'heaviside', 'amax', 'amin', 'hypot', 'float_power', 'floor', 'ptp', 'deg2rad', | |||
| 'rad2deg', 'count_nonzero', 'positive', 'negative', 'clip', 'floor_divide', | |||
| 'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs', 'cumsum'] | |||
| 'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs', 'exp2', 'kron', | |||
| 'promote_types', 'divmod', 'diff', 'cbrt', 'cross', 'ceil', 'trapz', | |||
| 'abs', 'max', 'min', 'gcd', 'lcm', 'log1p', 'logaddexp', 'log2', 'logaddexp2', 'log10', | |||
| 'convolve', 'ediff1d', 'nansum', 'nanmean', 'nanvar', 'nanstd', 'cumsum', | |||
| 'nancumsum', 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan', 'sinh', 'cosh', 'tanh', | |||
| 'arcsinh', 'arccosh', 'arctanh', 'arctan2', 'cov'] | |||
| logic_module = ['not_equal', 'less_equal', 'less', 'greater_equal', 'greater', 'equal', 'isfinite', | |||
| 'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar'] | |||
| 'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar', 'logical_and', 'logical_not', | |||
| 'logical_or', 'logical_xor', 'in1d', 'isin', 'isclose'] | |||
| __all__ = array_ops_module + array_creations_module + math_module + logic_module + numeric_types | |||
| @@ -13,8 +13,6 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """array operations, the function docs are adapted from Numpy API.""" | |||
| from copy import deepcopy | |||
| import numpy as onp | |||
| from ..common import Tensor | |||
| @@ -27,10 +25,11 @@ from .._c_expression import Tensor as Tensor_ | |||
| from .._c_expression.typing import Float | |||
| from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, \ | |||
| _broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar | |||
| _broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar, \ | |||
| _expand | |||
| from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \ | |||
| _check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \ | |||
| _raise_type_error, _expanded_shape, _check_is_float, _iota, \ | |||
| _raise_type_error, _expanded_shape, _tuple_getitem, _check_is_float, _iota, \ | |||
| _type_convert, _canonicalize_axis, _list_comprehensions, _ceil | |||
| from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape, broadcast_to | |||
| from .dtypes import nan | |||
| @@ -49,9 +48,8 @@ def array(obj, dtype=None, copy=True, ndmin=0): | |||
| This function creates tensors from an array-like object. | |||
| Args: | |||
| obj (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in | |||
| any form that can be converted to a `Tensor`. This includes lists, lists of | |||
| tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. | |||
| obj (Union[int, float, bool, list, tuple]): Input data, in any form that | |||
| can be converted to a `Tensor`. This includes Tensor, list, tuple and numbers. | |||
| dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can | |||
| be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type | |||
| of the new tensor will be inferred from obj. Default is :class:`None`. | |||
| @@ -76,17 +74,53 @@ def array(obj, dtype=None, copy=True, ndmin=0): | |||
| >>> print(np.array([1,2,3])) | |||
| [1 2 3] | |||
| """ | |||
| if ndmin > 0: | |||
| # Fall back to original numpy creation. | |||
| if isinstance(obj, Tensor): | |||
| obj = obj.asnumpy() | |||
| return asarray(onp.array(obj, dtype, copy=copy, ndmin=ndmin)) | |||
| res = asarray(obj, dtype) | |||
| if ndmin > res.ndim: | |||
| res = _expand(res, ndmin) | |||
| if copy: | |||
| res = copy_(res) | |||
| elif dtype is not None and dtype != res.dtype: | |||
| res = res.astype(dtype) | |||
| return res | |||
| @constexpr | |||
| def asarray_const(a, dtype=None): | |||
| """Converts the input to tensor. Note here `a` cannot be tensor itself.""" | |||
| _check_input_for_asarray(a) | |||
| if dtype is not None: | |||
| dtype = _check_dtype(dtype) | |||
| if isinstance(a, (float, int, bool)) and dtype is None: | |||
| dtype = _get_dtype_from_scalar(a) | |||
| if isinstance(a, (list, tuple)): | |||
| # Convert all tuple/nested tuples to lists | |||
| a = _deep_list(a) | |||
| # Convert all tensor sub-elements to numpy arrays | |||
| a = _deep_tensor_to_nparray(a) | |||
| a = onp.asarray(a) | |||
| if a.dtype is onp.dtype('object'): | |||
| raise ValueError('Input array must have the same size across all dimensions.') | |||
| # If dtype is not specified, we keep consistent with numpy decision | |||
| # only exceptions are: we use int/float32 | |||
| if dtype is None: | |||
| dtype = mstype.pytype_to_dtype(a.dtype) | |||
| if dtype == mstype.float64: | |||
| dtype = mstype.float32 | |||
| elif dtype == mstype.int64: | |||
| dtype = mstype.int32 | |||
| if not copy: | |||
| return asarray(obj, dtype=dtype) | |||
| if isinstance(a, onp.ndarray) and dtype is None: | |||
| if a.dtype is onp.dtype('object'): | |||
| raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") | |||
| dtype = mstype.pytype_to_dtype(a.dtype) | |||
| a = Tensor.from_numpy(a) | |||
| obj = deepcopy(obj) | |||
| return asarray(obj, dtype=dtype) | |||
| return Tensor(a, dtype=dtype) | |||
| def asarray(a, dtype=None): | |||
| @@ -96,9 +130,8 @@ def asarray(a, dtype=None): | |||
| This function converts tensors from an array-like object. | |||
| Args: | |||
| a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in | |||
| any form that can be converted to a `Tensor`. This includes lists, lists of | |||
| tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. | |||
| a (Union[int, float, bool, list, tuple, Tensor]): Input data, in any form that can | |||
| be converted to a `Tensor`. This includes Tensor, list, tuple and numbers. | |||
| dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can | |||
| be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type | |||
| of the new tensor will be inferred from obj. Default is :class:`None`. | |||
| @@ -118,46 +151,28 @@ def asarray(a, dtype=None): | |||
| >>> print(np.asarray([1,2,3])) | |||
| [1 2 3] | |||
| """ | |||
| _check_input_for_asarray(a) | |||
| if dtype is not None: | |||
| dtype = _check_dtype(dtype) | |||
| if isinstance(a, Tensor): | |||
| if dtype is None or dtype == a.dtype: | |||
| return a | |||
| return a.astype(dtype) | |||
| return asarray_const(a, dtype) | |||
| if isinstance(a, (float, int, bool)) and dtype is None: | |||
| dtype = _get_dtype_from_scalar(a) | |||
| @constexpr | |||
| def asfarray_const(a, dtype=mstype.float32): | |||
| """Converts the input to tensor. Note here `a` cannot be tensor itself.""" | |||
| _check_input_for_asarray(a) | |||
| if isinstance(a, (list, tuple)): | |||
| # Convert all tuple/nested tuples to lists | |||
| a = _deep_list(a) | |||
| # Convert all tensor sub-elements to numpy arrays | |||
| a = _deep_tensor_to_nparray(a) | |||
| a = onp.asarray(a) | |||
| if a.dtype is onp.dtype('object'): | |||
| raise ValueError('Input array must have the same size across all dimensions.') | |||
| # If dtype is not specified, we keep consistent with numpy decision | |||
| # only exceptions are: we use int/float32 | |||
| if dtype is None: | |||
| dtype = mstype.pytype_to_dtype(a.dtype) | |||
| if dtype == mstype.float64: | |||
| dtype = mstype.float32 | |||
| elif dtype == mstype.int64: | |||
| dtype = mstype.int32 | |||
| if isinstance(a, onp.ndarray) and dtype is None: | |||
| if a.dtype is onp.dtype('object'): | |||
| raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") | |||
| dtype = mstype.pytype_to_dtype(a.dtype) | |||
| a = Tensor.from_numpy(a) | |||
| # If a is already a tensor and we don't need to cast dtype, return a | |||
| if isinstance(a, Tensor): | |||
| if dtype is None or dtype == a.dtype: | |||
| return a | |||
| return Tensor(a, dtype=dtype) | |||
| asarray_const = constexpr(asarray) | |||
| return Tensor(a, dtype) | |||
| def asfarray(a, dtype=mstype.float32): | |||
| @@ -167,9 +182,8 @@ def asfarray(a, dtype=mstype.float32): | |||
| If non-float dtype is defined, this function will return a float32 tensor instead. | |||
| Args: | |||
| a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in | |||
| any form that can be converted to a `Tensor`. This includes lists, lists of | |||
| tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. | |||
| a (Union[int, float, bool, list, tuple, Tensor]): Input data, in any form that can | |||
| be converted to a `Tensor`. This includes Tensor, list, tuple and numbers. | |||
| dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can | |||
| be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type | |||
| of the new tensor will be inferred from `a`. Default is :class:`mindspore.float32`. | |||
| @@ -190,27 +204,18 @@ def asfarray(a, dtype=mstype.float32): | |||
| >>> print(np.asfarray([1,2,3])) | |||
| [1. 2. 3.] | |||
| """ | |||
| _check_input_for_asarray(a) | |||
| if dtype is None: | |||
| return asarray(a) | |||
| dtype = _check_dtype(dtype) | |||
| if dtype not in (mstype.float16, mstype.float32, mstype.float64): | |||
| # pylint: disable=consider-using-in | |||
| if dtype != mstype.float16 and dtype != mstype.float32 and dtype != mstype.float64: | |||
| dtype = mstype.float32 | |||
| if isinstance(a, (list, tuple)): | |||
| # Convert all tuple/nested tuples to lists | |||
| a = _deep_list(a) | |||
| # Convert all tensor sub-elements to numpy arrays | |||
| a = _deep_tensor_to_nparray(a) | |||
| a = onp.asarray(a) | |||
| if a.dtype is onp.dtype('object'): | |||
| raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") | |||
| if isinstance(a, onp.ndarray): | |||
| a = Tensor.from_numpy(a) | |||
| if isinstance(a, Tensor): | |||
| return a.astype(dtype) | |||
| return Tensor(a, dtype) | |||
| return asfarray_const(a) | |||
| def copy_(a): | |||
| @@ -218,9 +223,8 @@ def copy_(a): | |||
| Returns a tensor copy of the given object. | |||
| Args: | |||
| a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in | |||
| any form that can be converted to a tensor. This includes lists, lists of | |||
| tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. | |||
| a (Union[int, float, bool, list, tuple, Tensor]): Input data, in any form that can | |||
| be converted to a `Tensor`. This includes Tensor, list, tuple and numbers. | |||
| Returns: | |||
| Tensor, has the same data as `a`. | |||
| @@ -241,8 +245,16 @@ def copy_(a): | |||
| """ | |||
| if not isinstance(a, Tensor): | |||
| a = asarray_const(a) | |||
| return a.copy() | |||
| # The current implementation registers a new memory location for copied tensor by | |||
| # doing some reduandent operations. | |||
| origin_dtype = a.dtype | |||
| if origin_dtype == mstype.bool_: | |||
| return F.logical_not(F.logical_not(a)) | |||
| if origin_dtype != mstype.float64: | |||
| a = a.astype("float32") | |||
| a = a / ones_like(a) | |||
| a = a.astype(origin_dtype) | |||
| return a | |||
| def ones(shape, dtype=mstype.float32): | |||
| """ | |||
| @@ -566,6 +578,65 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): | |||
| return F.tensor_pow(base, linspace_res).astype(dtype) | |||
| def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): | |||
| """ | |||
| Returns numbers spaced evenly on a log scale (a geometric progression). | |||
| This is similar to logspace, but with endpoints specified directly. Each output sample | |||
| is a constant multiple of the previous. | |||
| Args: | |||
| start (Union[int, list(int), tuple(int), tensor]): The starting value of the sequence. | |||
| stop (Union[int, list(int), tuple(int), tensor]): The final value of the sequence, | |||
| unless endpoint is False. In that case, num + 1 values are spaced over the | |||
| interval in log-space, of which all but the last (a sequence of length num) are | |||
| returned. | |||
| num (int, optional): Number of samples to generate. Default is 50. | |||
| endpoint (bool, optional): If True, `stop` is the last sample. Otherwise, it is | |||
| not included. Default is True. | |||
| dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can | |||
| be in format of np.float32, or `float32`.If `dtype` is None, infer the data | |||
| type from other input arguments. Default is None. | |||
| axis (int, optional): The axis in the result to store the samples. Relevant | |||
| only if start or stop is array-like. By default (0), the samples will | |||
| be along a new axis inserted at the beginning. Use -1 to get an axis at the end. | |||
| Default is 0. | |||
| Returns: | |||
| Tensor, with samples equally spaced on a log scale. | |||
| Raises: | |||
| TypeError: If input arguments have types not specified above. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> output = np.geomspace(1, 256, num=9) | |||
| >>> print(output) | |||
| [ 1. 2. 4. 8. 16. 32. 64. 128. 256.] | |||
| >>> output = np.geomspace(1, 256, num=8, endpoint=False) | |||
| >>> print(output) | |||
| [ 1. 2. 4. 8. 16. 32. 64. 128.] | |||
| """ | |||
| start, stop, num, endpoint, dtype, axis = _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis) | |||
| root = num | |||
| if endpoint: | |||
| root -= 1 | |||
| bases = F.tensor_pow(F.tensor_div(stop, start), asarray_const(1/(root))) | |||
| exponents = linspace(zeros(F.shape(bases)), F.fill(F.dtype(bases), F.shape(bases), root), | |||
| num, endpoint=endpoint, dtype=dtype, axis=axis) | |||
| shape = F.shape(bases) | |||
| axis = axis + F.rank(bases) + 1 if axis < 0 else axis | |||
| expanded_shape = _tuple_getitem(shape, axis, False) + (1,) + _tuple_getitem(shape, axis) | |||
| bases = F.reshape(bases, expanded_shape) | |||
| start = F.reshape(start, expanded_shape) | |||
| res = F.tensor_mul(F.tensor_pow(bases, exponents), start) | |||
| if dtype is not None: | |||
| res = F.cast(res, dtype) | |||
| return res | |||
| def eye(N, M=None, k=0, dtype=mstype.float32): | |||
| """ | |||
| Returns a 2-D tensor with ones on the diagnoal and zeros elsewhere. | |||
| @@ -757,7 +828,7 @@ def empty_like(prototype, dtype=None, shape=None): | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))] | |||
| >>> a = np.ones((4,1,2)) | |||
| >>> output = np.empty_like(a) | |||
| >>> print(output) | |||
| # result may vary | |||
| @@ -794,7 +865,7 @@ def ones_like(a, dtype=None, shape=None): | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))] | |||
| >>> a = np.ones((4,1,2)) | |||
| >>> output = np.ones_like(a) | |||
| >>> print(output) | |||
| [[[1. 1.]] | |||
| @@ -832,7 +903,7 @@ def zeros_like(a, dtype=None, shape=None): | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))] | |||
| >>> a = np.ones((4,1,2)) | |||
| >>> output = np.zeros_like(a) | |||
| >>> print(output) | |||
| [[[0. 0.]] | |||
| @@ -871,7 +942,7 @@ def full_like(a, fill_value, dtype=None, shape=None): | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))] | |||
| >>> a = np.ones((4,1,2)) | |||
| >>> output = np.full_like(a, 0.5) | |||
| >>> print(output) | |||
| [[[0.5 0.5]] | |||
| @@ -1175,9 +1246,8 @@ def _index(i, size, Cartesian=True): | |||
| if Cartesian: | |||
| if i == 1: | |||
| return 0 | |||
| if i == 0: | |||
| if size >= 2: | |||
| return 1 | |||
| if i == 0 and size >= 2: | |||
| return 1 | |||
| return i | |||
| @@ -1630,3 +1700,103 @@ def ix_(*args): | |||
| return _raise_value_error('Cross index must be 1 dimensional') | |||
| res += (F.reshape(arr, _expanded_shape(ndim, arr.size, i)),) | |||
| return res | |||
| def vander(x, N=None, increasing=False): | |||
| """ | |||
| Generates a Vandermonde matrix. | |||
| The columns of the output matrix are powers of the input vector. The order of | |||
| the powers is determined by the increasing boolean argument. Specifically, when | |||
| increasing is `False`, the i-th output column is the input vector raised element-wise | |||
| to the power of :math:`N - i - 1`. Such a matrix with a geometric progression in each row | |||
| is named for Alexandre-Theophile Vandermonde. | |||
| Args: | |||
| x (Union[list, tuple, Tensor]): 1-D input array. | |||
| N (int, optional): Number of columns in the output. If N is not specified, a | |||
| square array is returned (``N = len(x)``). | |||
| increasing (bool, optional): Order of the powers of the columns. If True, the | |||
| powers increase from left to right, if False (the default) they are reversed. | |||
| Returns: | |||
| Vandermonde matrix. If `increasing` is `False`, the first column is :math:`x^{(N-1)}`, | |||
| the second :math:`x^{(N-2)}` and so forth. If `increasing` is `True`, the columns are | |||
| :math:`x^0, x^1, ..., x^{(N-1)}`. | |||
| Raises: | |||
| TypeError: If inputs have types not specified above. | |||
| ValueError: If `x` is not 1-D, or `N` < 0. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> print(np.vander([1,2,3,4,5])) | |||
| [[ 1 1 1 1 1] | |||
| [ 16 8 4 2 1] | |||
| [ 81 27 9 3 1] | |||
| [256 64 16 4 1] | |||
| [625 125 25 5 1]] | |||
| """ | |||
| if isinstance(x, (list, tuple)): | |||
| x = asarray_const(x) | |||
| elif not isinstance(x, Tensor): | |||
| _raise_type_error("Input x must be list, tuple or Tensor, but got ", x) | |||
| if x.ndim != 1: | |||
| _raise_value_error("Input x must be 1-D, but got dimension=", x.ndim) | |||
| N = N or x.size | |||
| if not isinstance(N, int): | |||
| _raise_type_error("Input N must be an integer.") | |||
| if N <= 0: | |||
| _raise_value_error("Input N must > 0.") | |||
| if not isinstance(increasing, bool): | |||
| _raise_type_error("increasing must be a bool.") | |||
| exponent = _iota(x.dtype, N, increasing) | |||
| x = F.expand_dims(x, 1) | |||
| exponent = F.expand_dims(exponent, 0) | |||
| return F.tensor_pow(x, exponent) | |||
| def indices(dimensions, dtype=mstype.int32, sparse=False): | |||
| """ | |||
| Returns an array representing the indices of a grid. | |||
| Computes an array where the subarrays contain index values 0, 1, … | |||
| varying only along the corresponding axis. | |||
| Args: | |||
| dimensions (tuple or list of ints): The shape of the grid. | |||
| dtype (data type, optional): Data type of the result. | |||
| sparse (boolean, optional): Defaults to False. Return a sparse | |||
| representation of the grid instead of a dense representation. | |||
| Returns: | |||
| Tensor or tuple of Tensor, If `sparse` is False, returns one array | |||
| of grid indices, ``grid.shape = (len(dimensions),) + tuple(dimensions)``. | |||
| If sparse is True, returns a tuple of arrays, with | |||
| ``grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1)`` with | |||
| ``dimensions[i]`` in the `ith` place | |||
| Raises: | |||
| TypeError: if input dimensions is not a tuple or list. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> grid = np.indices((2, 3)) | |||
| >>> print(indices) | |||
| [Tensor(shape=[2, 3], dtype=Int32, value= | |||
| [[0, 0, 0], | |||
| [1, 1, 1]]), Tensor(shape=[2, 3], dtype=Int32, value= | |||
| [[0, 1, 2], | |||
| [0, 1, 2]])] | |||
| """ | |||
| if not isinstance(dimensions, (tuple, list)): | |||
| _raise_type_error('Shape of the grid must be tuple or list') | |||
| grids = () | |||
| for d in dimensions: | |||
| grids += (arange(d, dtype=dtype),) | |||
| return meshgrid(*grids, sparse=sparse, indexing='ij') | |||
| @@ -24,62 +24,19 @@ from ..ops.primitive import constexpr | |||
| from ..nn import Cell | |||
| from .utils import _convert_list_tensor_to_tuple_tensor, _expand, _broadcast_to_shape, \ | |||
| _check_input_tensor, _broadcast_to | |||
| _check_input_tensor, _broadcast_to, _to_tensor | |||
| from .utils_const import _check_axes_range, _check_start_normalize, \ | |||
| _raise_type_error, _raise_value_error, _infer_out_shape, _empty, _promote, \ | |||
| _check_same_type, _check_axis_valid, _add_unit_axes, _broadcast_tuples, \ | |||
| _check_is_float, _check_axis_in_range, _check_axis_type, _canonicalize_axis, \ | |||
| _list_comprehensions, _check_element_int, _is_shape_empty, _type_convert, \ | |||
| _tuple_getitem, _expanded_shape | |||
| _tuple_getitem, _expanded_shape, _seq_prod, _get_device, _tuple_setitem | |||
| # According to official numpy reference, the dimension of a numpy array must be less | |||
| # than 32 | |||
| MAX_NUMPY_DIMS = 32 | |||
| @constexpr | |||
| def _prepare_shape_for_expand_dims(shape, axes): | |||
| """ | |||
| Creates the expanded new shape based on the shape and given axes | |||
| Args: | |||
| shape (tuple): the shape of the tensor | |||
| axes Union(int, tuple(int), list(int)): the axes with dimensions expanded. | |||
| Returns: | |||
| new_shape(tuple): the shape with dimensions expanded. | |||
| """ | |||
| new_shape = [] | |||
| shape_idx = 0 | |||
| new_shape_length = len(shape) | |||
| # Convert to set | |||
| if isinstance(axes, int): | |||
| new_shape_length += 1 | |||
| if axes >= new_shape_length or axes < -new_shape_length: | |||
| raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {new_shape_length}") | |||
| axes = {axes} | |||
| elif isinstance(axes, (list, tuple)): | |||
| new_shape_length += len(axes) | |||
| for axis in axes: | |||
| if axis >= new_shape_length or axis < -new_shape_length: | |||
| raise ValueError(f"axis {axis} is out of bounds for tensor of dimension {new_shape_length}") | |||
| axes = set(axes) | |||
| else: | |||
| raise TypeError(f"only int, tuple and list are allowed for axes, but got {type(axes)}") | |||
| for new_shape_idx in range(new_shape_length): | |||
| if new_shape_idx in axes or new_shape_idx - new_shape_length in axes: | |||
| new_shape.append(1) | |||
| else: | |||
| new_shape.append(shape[shape_idx]) | |||
| shape_idx += 1 | |||
| return tuple(new_shape) | |||
| def expand_dims(a, axis): | |||
| """ | |||
| Expands the shape of a tensor. | |||
| @@ -109,10 +66,15 @@ def expand_dims(a, axis): | |||
| (1, 2, 2) | |||
| """ | |||
| _check_input_tensor(a) | |||
| shape = F.shape(a) | |||
| # yield expanded shape based on the axes | |||
| new_shape = _prepare_shape_for_expand_dims(shape, axis) | |||
| return F.reshape(a, new_shape) | |||
| if not isinstance(axis, (int, tuple, list)): | |||
| _raise_type_error("axis must be tuple, list or int, but got ", axis) | |||
| if isinstance(axis, int): | |||
| return F.expand_dims(a, axis) | |||
| ndim = a.ndim + len(axis) | |||
| axis = _canonicalize_axis(axis, ndim) | |||
| for ax in axis: | |||
| a = F.expand_dims(a, ax) | |||
| return a | |||
| def squeeze(a, axis=None): | |||
| @@ -1091,6 +1053,9 @@ def roll(a, shift, axis=None): | |||
| Returns: | |||
| Tensor, with the same shape as a. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Raises: | |||
| TypeError: If input arguments have types not specified above. | |||
| ValueError: If axis exceeds `a.ndim`, or `shift` and `axis` cannot broadcast. | |||
| @@ -1212,12 +1177,6 @@ def moveaxis(a, source, destination): | |||
| return F.transpose(a, perm) | |||
| @constexpr | |||
| def _seq_prod(seq1, seq2): | |||
| """Returns the element-wise product of seq1 and seq2.""" | |||
| return tuple(map(lambda x, y: x*y, seq1, seq2)) | |||
| def tile(a, reps): | |||
| """ | |||
| Constructs an array by repeating `a` the number of times given by `reps`. | |||
| @@ -1355,6 +1314,60 @@ def broadcast_arrays(*args): | |||
| return res | |||
| def array_split(x, indices_or_sections, axis=0): | |||
| """ | |||
| Splits a tensor into multiple sub-tensors. | |||
| Note: | |||
| Currently, array_split only supports :class:`mindspore.float32` on ``CPU``. | |||
| The only difference between ``np.split`` and ``np.array_split`` is that | |||
| ``np.array_split`` allows indices_or_sections to be an integer that does not | |||
| equally divide the axis. For a tensor of length l that should be split into | |||
| n sections, it returns :math:`l % n` sub-arrays of size :math:`l//n + 1` and | |||
| the rest of size :math:`l//n`. | |||
| Args: | |||
| x (Tensor): A Tensor to be divided. | |||
| indices_or_sections (Union[int, tuple(int), list(int)]): | |||
| If integer, :math:`N`, the tensor will be divided into | |||
| :math:`N` tensors along axis. | |||
| If tuple(int), list(int) or of sorted integers, | |||
| the entries indicate where along axis the array is split. | |||
| For example, :math:`[2, 3]` would, for :math:`axis=0`, result in | |||
| three sub-tensors :math:`x[:2]`, :math:`x[2:3]`and :math:`x[3:]`. | |||
| If an index exceeds the dimension of the array along axis, | |||
| an empty sub-array is returned correspondingly. | |||
| axis (int): The axis along which to split. Default: 0. | |||
| Returns: | |||
| A list of sub-tensors. | |||
| Raises: | |||
| TypeError: If argument `indices_or_sections` is not integer, | |||
| tuple(int) or list(int) or argument `axis` is not integer. | |||
| ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)`. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> input_x = np.arange(9).astype("float32") | |||
| >>> output = np.array_split(input_x, 4) | |||
| >>> print(output) | |||
| (Tensor(shape=[3], dtype=Float32, | |||
| value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]), | |||
| Tensor(shape=[2], dtype=Float32, | |||
| value= [ 3.00000000e+00, 4.00000000e+00]), | |||
| Tensor(shape=[2], dtype=Float32, | |||
| value= [ 5.00000000e+00, 6.00000000e+00]), | |||
| Tensor(shape=[2], dtype=Float32, | |||
| value= [ 7.00000000e+00, 8.00000000e+00])) | |||
| """ | |||
| return _split(x, indices_or_sections, opname="array_split", axis=axis) | |||
| def split(x, indices_or_sections, axis=0): | |||
| """ | |||
| Splits a tensor into multiple sub-tensors along the given axis. | |||
| @@ -1380,9 +1393,12 @@ def split(x, indices_or_sections, axis=0): | |||
| tuple(int) or list(int) or argument `axis` is not integer. | |||
| ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)`. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> input_x = np.arange(9).astype('float32') | |||
| >>> input_x = np.arange(9).astype("float32") | |||
| >>> output = np.split(input_x, 3) | |||
| >>> print(output) | |||
| (Tensor(shape=[3], dtype=Float32, | |||
| @@ -1392,13 +1408,32 @@ def split(x, indices_or_sections, axis=0): | |||
| Tensor(shape=[3], dtype=Float32, | |||
| value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00])) | |||
| """ | |||
| return _split(x, indices_or_sections, opname="split", axis=axis) | |||
| def _split(x, indices_or_sections, opname, axis=0): | |||
| """Splits a tensor based on ``np.split`` or ``np.array_split``.""" | |||
| _check_input_tensor(x) | |||
| _ = _check_axis_type(axis, True, False, False) | |||
| axis = _canonicalize_axis(axis, x.ndim) | |||
| res = None | |||
| arr_shape = x.shape | |||
| length_along_dim = arr_shape[axis] | |||
| if isinstance(indices_or_sections, int): | |||
| _split = P.Split(axis, indices_or_sections) | |||
| res = _split(x) | |||
| if opname == "split" or length_along_dim % indices_or_sections == 0: | |||
| res = P.Split(axis, indices_or_sections)(x) | |||
| else: | |||
| num_long_tensor = length_along_dim % indices_or_sections | |||
| num_short_tensor = indices_or_sections - num_long_tensor | |||
| length1 = num_long_tensor * (length_along_dim // indices_or_sections + 1) | |||
| length2 = length_along_dim - length1 | |||
| start1 = _list_comprehensions(F.rank(x), 0, True) | |||
| size1 = _tuple_setitem(arr_shape, axis, length1) | |||
| start2 = _tuple_setitem(start1, axis, length1) | |||
| size2 = _tuple_setitem(arr_shape, axis, length2) | |||
| res = P.Split(axis, num_long_tensor)(F.tensor_slice(x, start1, size1)) + \ | |||
| P.Split(axis, num_short_tensor)(F.tensor_slice(x, start2, size2)) | |||
| elif isinstance(indices_or_sections, (list, tuple)) and _check_element_int(indices_or_sections): | |||
| res = _split_sub_tensors(x, indices_or_sections, axis) | |||
| else: | |||
| @@ -1921,7 +1956,6 @@ def repeat(a, repeats, axis=None): | |||
| if repeats == 0: | |||
| return _empty(F.dtype(a), (0,)) | |||
| return C.repeat_elements(a, repeats, axis) | |||
| shape = F.shape(a) | |||
| size = shape[axis] | |||
| if len(repeats) != size: | |||
| @@ -1932,3 +1966,144 @@ def repeat(a, repeats, axis=None): | |||
| if rep != 0: | |||
| repeated_subs.append(C.repeat_elements(sub, rep, axis)) | |||
| return concatenate(repeated_subs, axis) | |||
| def rot90(a, k=1, axes=(0, 1)): | |||
| """ | |||
| Rotates a tensor by 90 degrees in the plane specified by axes. | |||
| Rotation direction is from the first towards the second axis. | |||
| Args: | |||
| a (Tensor): Input tensor of two or more dimensions. | |||
| k (int): Number of times the tensor is rotated by 90 degrees. Default: 1. | |||
| axes (Union[tuple(int), list(int)]): The tensor is rotated in the plane | |||
| defined by the axes. Default: `(0, 1)`. | |||
| Axes must be different and with the shape of `(2,)`. | |||
| Returns: | |||
| Tensor. | |||
| Raises: | |||
| TypeError: if input `a` is not a Tensor or | |||
| the argument `k` is not integer or | |||
| the argument `axes` is not tuple of ints or list of ints. | |||
| ValueError: if any axis is out of range or | |||
| the length of `axes` is not `2`. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = np.arange(24).reshape((2, 3, 4)) | |||
| >>> output = np.rot90(a) | |||
| >>> print(output) | |||
| [[[ 8 9 10 11] | |||
| [20 21 22 23]] | |||
| [[ 4 5 6 7] | |||
| [16 17 18 19]] | |||
| [[ 0 1 2 3] | |||
| [12 13 14 15]]] | |||
| >>> output = np.rot90(a, 3, (1, 2)) | |||
| >>> print(output) | |||
| [[[ 8 4 0] | |||
| [ 9 5 1] | |||
| [10 6 2] | |||
| [11 7 3]] | |||
| [[20 16 12] | |||
| [21 17 13] | |||
| [22 18 14] | |||
| [23 19 15]]] | |||
| """ | |||
| _check_input_tensor(a) | |||
| if not isinstance(k, int): | |||
| _raise_type_error("integer argument expected, but got ", k) | |||
| k = k % 4 if k >= 0 else 4 - (-k % 4) | |||
| if not isinstance(axes, (tuple, list)): | |||
| _raise_type_error("tuple(ints) or list(ints) expected, but got ", axes) | |||
| if len(axes) != 2: | |||
| _raise_value_error("len(axes) must be 2.") | |||
| axis1, axis2 = axes[0], axes[1] | |||
| axis1 = _canonicalize_axis(axis1, a.ndim) | |||
| axis2 = _canonicalize_axis(axis2, a.ndim) | |||
| if axis1 == axis2: | |||
| _raise_value_error('Axes must be different.') | |||
| if k == 0: | |||
| return a | |||
| if k == 2: | |||
| return flip(flip(a, axis1), axis2) | |||
| perm = _list_comprehensions(a.ndim) | |||
| perm[axis1], perm[axis2] = perm[axis2], perm[axis1] | |||
| if k == 1: | |||
| return flip(transpose(a, perm), axis1) | |||
| return flip(transpose(a, perm), axis2) | |||
| def select(condlist, choicelist, default=0): | |||
| """ | |||
| Returns an array drawn from elements in `choicelist`, depending on conditions. | |||
| Args: | |||
| condlist (array_like): The list of conditions which determine from which array | |||
| in `choicelist` the output elements are taken. When multiple conditions are | |||
| satisfied, the first one encountered in `condlist` is used. | |||
| choicelist (array_like): The list of arrays from which the output elements are | |||
| taken. It has to be of the same length as `condlist`. | |||
| default (scalar, optional): The element inserted in output when all conditions | |||
| evaluate to `False`. | |||
| Returns: | |||
| Tensor, the output at position `m` is the `m-th` element of the array in | |||
| `choicelist` where the `m-th` element of the corresponding array in `condlist` | |||
| is `True`. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Raises: | |||
| ValueError: if ``len(condlist) != len(choicelist)``. | |||
| Examples: | |||
| >>> condlist = [[True, True, True, False, False], | |||
| [False, False, True, False, True]] | |||
| >>> choicelist = [[0, 1, 2, 3, 4], [0, 1, 4, 9, 16]] | |||
| >>> output = np.select(condlist, choicelist) | |||
| >>> print(output) | |||
| [ 0 1 2 0 16] | |||
| """ | |||
| condlist, choicelist = _to_tensor(condlist, choicelist) | |||
| shape_cond = F.shape(condlist) | |||
| shape_choice = F.shape(choicelist) | |||
| if F.rank(condlist) == 0 or F.rank(condlist) == 0: | |||
| _raise_value_error('input cannot be scalars') | |||
| case_num = shape_cond[0] | |||
| if shape_choice[0] != case_num: | |||
| _raise_value_error('list of cases must be same length as list of conditions') | |||
| # performs broadcast over the cases in condlist and choicelist | |||
| case_size = _infer_out_shape(shape_cond[1:], shape_choice[1:]) | |||
| shape_broadcasted = (case_num,) + case_size | |||
| ndim = len(shape_broadcasted) | |||
| shape_cond_expanded = ((case_num,) + _list_comprehensions(ndim - F.rank(condlist), 1, True) + | |||
| shape_cond[1:]) | |||
| condlist = _broadcast_to_shape(F.reshape(condlist, shape_cond_expanded), shape_broadcasted) | |||
| shape_choice_expanded = ((case_num,) + _list_comprehensions(ndim - F.rank(choicelist), 1, True) + | |||
| shape_choice[1:]) | |||
| choicelist = _broadcast_to_shape(F.reshape(choicelist, shape_choice_expanded), shape_broadcasted) | |||
| slice_start = _list_comprehensions(ndim - 1, 0, True) | |||
| slice_size = (1,) + case_size | |||
| dtype = F.dtype(choicelist) | |||
| if _get_device() == 'CPU' and not _check_is_float(dtype): | |||
| # F.tensor_slice only supports float on CPU | |||
| choicelist = F.cast(choicelist, mstype.float32) | |||
| default_slice = F.fill(F.dtype(choicelist), slice_size, default) | |||
| for i in range(case_num - 1, -1, -1): | |||
| cond_slice = F.tensor_slice(condlist.astype(mstype.float32), (i,) + slice_start, slice_size) | |||
| choice_slice = F.tensor_slice(choicelist, (i,) + slice_start, slice_size) | |||
| default_slice = F.select(cond_slice.astype(mstype.bool_), choice_slice, default_slice) | |||
| return F.reshape(default_slice, (case_size)).astype(dtype) | |||
| @@ -169,3 +169,16 @@ promotion_rule = { | |||
| (bool_, float32): float32, | |||
| (bool_, float64): float64, | |||
| } | |||
| rule_for_trigonometric = {float16: float16, | |||
| float32: float32, | |||
| float64: float64, | |||
| int8: float16, | |||
| int16: float32, | |||
| int32: float32, | |||
| int64: float32, | |||
| uint8: float16, | |||
| uint16: float32, | |||
| uint32: float32, | |||
| uint64: float32, | |||
| bool_: float16} | |||
| @@ -15,33 +15,29 @@ | |||
| """logical operations, the function docs are adapted from Numpy API.""" | |||
| from .math_ops import _apply_tensor_op | |||
| from ..ops import functional as F | |||
| from ..ops.primitive import constexpr | |||
| from ..common import dtype as mstype | |||
| from ..common import Tensor | |||
| from .._c_expression import typing | |||
| from .array_creations import zeros, ones | |||
| from .utils import _check_input_tensor | |||
| from .math_ops import _apply_tensor_op, absolute | |||
| from .array_creations import zeros, ones, empty | |||
| from .utils import _check_input_tensor, _to_tensor, _isnan | |||
| from .utils_const import _raise_type_error, _is_shape_empty, _infer_out_shape | |||
| def not_equal(x1, x2, out=None, where=True, dtype=None): | |||
| def not_equal(x1, x2, dtype=None): | |||
| """ | |||
| Returns (x1 != x2) element-wise. | |||
| Note: | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, | |||
| and `extobj` are not supported. | |||
| Args: | |||
| x1 (Tensor): First input tensor to be compared. | |||
| x2 (Tensor): Second input tensor to be compared. | |||
| out (Tensor or None, optional), default is None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| @@ -65,33 +61,21 @@ def not_equal(x1, x2, out=None, where=True, dtype=None): | |||
| [False True]] | |||
| """ | |||
| _check_input_tensor(x1, x2) | |||
| return _apply_tensor_op(F.not_equal, x1, x2, out=out, where=where, dtype=dtype) | |||
| return _apply_tensor_op(F.not_equal, x1, x2, dtype=dtype) | |||
| def less_equal(x1, x2, out=None, where=True, dtype=None): | |||
| def less_equal(x1, x2, dtype=None): | |||
| """ | |||
| Returns the truth value of ``(x1 <= x2)`` element-wise. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, | |||
| and `extobj` are not supported. | |||
| Args: | |||
| x1 (Tensor): Input array. | |||
| x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be | |||
| broadcastable to a common shape (which becomes the shape of the output). | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| @@ -113,33 +97,21 @@ def less_equal(x1, x2, out=None, where=True, dtype=None): | |||
| [False True True] | |||
| """ | |||
| _check_input_tensor(x1, x2) | |||
| return _apply_tensor_op(F.tensor_le, x1, x2, out=out, where=where, dtype=dtype) | |||
| return _apply_tensor_op(F.tensor_le, x1, x2, dtype=dtype) | |||
| def less(x1, x2, out=None, where=True, dtype=None): | |||
| def less(x1, x2, dtype=None): | |||
| """ | |||
| Returns the truth value of ``(x1 < x2)`` element-wise. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, | |||
| and `extobj` are not supported. | |||
| Args: | |||
| x1 (Tensor): input array. | |||
| x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be | |||
| broadcastable to a common shape (which becomes the shape of the output). | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| @@ -160,33 +132,21 @@ def less(x1, x2, out=None, where=True, dtype=None): | |||
| >>> print(output) | |||
| [ True False] | |||
| """ | |||
| return _apply_tensor_op(F.tensor_lt, x1, x2, out=out, where=where, dtype=dtype) | |||
| return _apply_tensor_op(F.tensor_lt, x1, x2, dtype=dtype) | |||
| def greater_equal(x1, x2, out=None, where=True, dtype=None): | |||
| def greater_equal(x1, x2, dtype=None): | |||
| """ | |||
| Returns the truth value of ``(x1 >= x2)`` element-wise. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, | |||
| and `extobj` are not supported. | |||
| Args: | |||
| x1 (Tensor): Input array. | |||
| x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be | |||
| broadcastable to a common shape (which becomes the shape of the output). | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| @@ -207,33 +167,21 @@ def greater_equal(x1, x2, out=None, where=True, dtype=None): | |||
| >>> print(output) | |||
| [ True True False] | |||
| """ | |||
| return _apply_tensor_op(F.tensor_ge, x1, x2, out=out, where=where, dtype=dtype) | |||
| return _apply_tensor_op(F.tensor_ge, x1, x2, dtype=dtype) | |||
| def greater(x1, x2, out=None, where=True, dtype=None): | |||
| def greater(x1, x2, dtype=None): | |||
| """ | |||
| Returns the truth value of ``(x1 > x2)`` element-wise. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, | |||
| and `extobj` are not supported. | |||
| Args: | |||
| x1 (Tensor): Input array. | |||
| x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be | |||
| broadcastable to a common shape (which becomes the shape of the output). | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| @@ -254,33 +202,21 @@ def greater(x1, x2, out=None, where=True, dtype=None): | |||
| >>> print(output) | |||
| [ True False] | |||
| """ | |||
| return _apply_tensor_op(F.tensor_gt, x1, x2, out=out, where=where, dtype=dtype) | |||
| return _apply_tensor_op(F.tensor_gt, x1, x2, dtype=dtype) | |||
| def equal(x1, x2, out=None, where=True, dtype=None): | |||
| def equal(x1, x2, dtype=None): | |||
| """ | |||
| Returns the truth value of ``(x1 == x2)`` element-wise. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, | |||
| and `extobj` are not supported. | |||
| Args: | |||
| x1 (Tensor): Input array. | |||
| x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be | |||
| broadcastable to a common shape (which becomes the shape of the output). | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| @@ -301,34 +237,22 @@ def equal(x1, x2, out=None, where=True, dtype=None): | |||
| >>> print(output) | |||
| [ True True False] | |||
| """ | |||
| return _apply_tensor_op(F.equal, x1, x2, out=out, where=where, dtype=dtype) | |||
| return _apply_tensor_op(F.equal, x1, x2, dtype=dtype) | |||
| def isfinite(x, out=None, where=True, dtype=None): | |||
| def isfinite(x, dtype=None): | |||
| """ | |||
| Tests element-wise for finiteness (not infinity or not Not a Number). | |||
| The result is returned as a boolean array. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, | |||
| and `extobj` are not supported. | |||
| On GPU, the supported dtypes are np.float16, and np.float32. | |||
| Args: | |||
| x (Tensor): Input values. | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| @@ -351,37 +275,20 @@ def isfinite(x, out=None, where=True, dtype=None): | |||
| >>> print(output) | |||
| False | |||
| """ | |||
| return _apply_tensor_op(F.isfinite, x, out=out, where=where, dtype=dtype) | |||
| def _isnan(x): | |||
| """Computes isnan without applying keyword arguments.""" | |||
| return F.not_equal(x, x) | |||
| return _apply_tensor_op(F.isfinite, x, dtype=dtype) | |||
| def isnan(x, out=None, where=True, dtype=None): | |||
| def isnan(x, dtype=None): | |||
| """ | |||
| Tests element-wise for NaN and return result as a boolean array. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, | |||
| and `extobj` are not supported. | |||
| Only np.float32 is currently supported. | |||
| Args: | |||
| x (Tensor): Input values. | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| @@ -404,7 +311,7 @@ def isnan(x, out=None, where=True, dtype=None): | |||
| >>> print(output) | |||
| False | |||
| """ | |||
| return _apply_tensor_op(_isnan, x, out=out, where=where, dtype=dtype) | |||
| return _apply_tensor_op(_isnan, x, dtype=dtype) | |||
| def _isinf(x): | |||
| @@ -419,31 +326,19 @@ def _isinf(x): | |||
| return F.cast(res, mstype.bool_) | |||
| def isinf(x, out=None, where=True, dtype=None): | |||
| def isinf(x, dtype=None): | |||
| """ | |||
| Tests element-wise for positive or negative infinity. | |||
| Returns a boolean array of the same shape as `x`, True where ``x == +/-inf``, otherwise False. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, | |||
| and `extobj` are not supported. | |||
| Only np.float32 is currently supported. | |||
| Args: | |||
| x (Tensor): Input values. | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| @@ -466,7 +361,7 @@ def isinf(x, out=None, where=True, dtype=None): | |||
| >>> print(output) | |||
| [ True True False False] | |||
| """ | |||
| return _apply_tensor_op(_isinf, x, out=out, where=where, dtype=dtype) | |||
| return _apply_tensor_op(_isinf, x, dtype=dtype) | |||
| def _is_sign_inf(x, fn): | |||
| @@ -562,7 +457,7 @@ def isscalar(element): | |||
| element (any): Input argument, can be of any type and shape. | |||
| Returns: | |||
| Boolean, True if `element` is a scalar type, False if it is not. | |||
| Boolean, True if `element` is a scalar type, False if it is not. | |||
| Raises: | |||
| TypeError: if the type of `element` is not supported by mindspore parser. | |||
| @@ -587,3 +482,302 @@ def isscalar(element): | |||
| """ | |||
| obj_type = F.typeof(element) | |||
| return not isinstance(obj_type, Tensor) and _isscalar(obj_type) | |||
| def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): | |||
| """ | |||
| Returns a boolean tensor where two tensors are element-wise equal within a tolerance. | |||
| The tolerance values are positive, typically very small numbers. The relative | |||
| difference (:math:`rtol * abs(b)`) and the absolute difference `atol` are added together | |||
| to compare against the absolute difference between `a` and `b`. | |||
| Note: | |||
| For finite values, isclose uses the following equation to test whether two | |||
| floating point values are equivalent. | |||
| :math:`absolute(a - b) <= (atol + rtol * absolute(b))` | |||
| Args: | |||
| a (Union[Tensor, list, tuple]): Input first tensor to compare. | |||
| b (Union[Tensor, list, tuple]): Input second tensor to compare. | |||
| rtol (Number): The relative tolerance parameter (see Note). | |||
| atol (Number): The absolute tolerance parameter (see Note). | |||
| equal_nan (bool): Whether to compare ``NaN`` as equal. If True, ``NaN`` in | |||
| `a` will be considered equal to ``NaN`` in `b` in the output tensor. | |||
| Returns: | |||
| A ``bool`` tensor of where `a` and `b` are equal within the given tolerance. | |||
| Raises: | |||
| TypeError: If inputs have types not specified above. | |||
| Supported Platforms: | |||
| ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> a = np.array([0,1,2,float('inf'),float('inf'),float('nan')]) | |||
| >>> b = np.array([0,1,-2,float('-inf'),float('inf'),float('nan')]) | |||
| >>> print(np.isclose(a, b)) | |||
| [ True True False False True False] | |||
| >>> print(np.isclose(a, b, equal_nan=True)) | |||
| [ True True False False True True] | |||
| """ | |||
| a, b = _to_tensor(a, b) | |||
| if not isinstance(rtol, (int, float, bool)) or not isinstance(atol, (int, float, bool)): | |||
| _raise_type_error("rtol and atol are expected to be numbers.") | |||
| if not isinstance(equal_nan, bool): | |||
| _raise_type_error("equal_nan is expected to be bool.") | |||
| if _is_shape_empty(a.shape) or _is_shape_empty(b.shape): | |||
| return empty(_infer_out_shape(a.shape, b.shape), dtype=mstype.bool_) | |||
| rtol = _to_tensor(rtol).astype("float32") | |||
| atol = _to_tensor(atol).astype("float32") | |||
| res = absolute(a - b) <= (atol + rtol * absolute(b)) | |||
| # infs are treated as equal | |||
| a_posinf = isposinf(a) | |||
| b_posinf = isposinf(b) | |||
| a_neginf = isneginf(a) | |||
| b_neginf = isneginf(b) | |||
| same_inf = F.logical_or(F.logical_and(a_posinf, b_posinf), F.logical_and(a_neginf, b_neginf)) | |||
| diff_inf = F.logical_or(F.logical_and(a_posinf, b_neginf), F.logical_and(a_neginf, b_posinf)) | |||
| res = F.logical_and(F.logical_or(res, same_inf), F.logical_not(diff_inf)) | |||
| both_nan = F.logical_and(_isnan(a), _isnan(b)) | |||
| if equal_nan: | |||
| res = F.logical_or(both_nan, res) | |||
| else: | |||
| res = F.logical_and(F.logical_not(both_nan), res) | |||
| return res | |||
| def in1d(ar1, ar2, invert=False): | |||
| """ | |||
| Tests whether each element of a 1-D array is also present in a second array. | |||
| Returns a boolean array the same length as `ar1` that is True where an element | |||
| of `ar1` is in `ar2` and False otherwise. | |||
| Note: | |||
| Numpy argument `assume_unique` is not supported since the implementation does | |||
| not rely on the uniqueness of the input arrays. | |||
| Args: | |||
| ar1 (array_like): Input array with shape `(M,)`. | |||
| ar2 (array_like): The values against which to test each value of `ar1`. | |||
| invert (boolean, optional): If True, the values in the returned array are | |||
| inverted (that is, False where an element of `ar1` is in `ar2` and True | |||
| otherwise). Default is False. | |||
| Returns: | |||
| Tensor, with shape `(M,)`. The values ``ar1[in1d]`` are in `ar2`. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> test = np.array([0, 1, 2, 5, 0]) | |||
| >>> states = [0, 2] | |||
| >>> mask = np.in1d(test, states) | |||
| >>> print(mask) | |||
| [ True False True False True] | |||
| >>> mask = np.in1d(test, states, invert=True) | |||
| >>> print(mask) | |||
| [False True False True False] | |||
| """ | |||
| ar1, ar2 = _to_tensor(ar1, ar2) | |||
| ar1 = F.expand_dims(ar1.ravel(), -1) | |||
| ar2 = ar2.ravel() | |||
| included = F.equal(ar1, ar2) | |||
| # F.reduce_sum only supports float | |||
| res = F.reduce_sum(included.astype(mstype.float32), -1).astype(mstype.bool_) | |||
| if invert: | |||
| res = F.equal(res, _to_tensor(False)) | |||
| return res | |||
| def isin(element, test_elements, invert=False): | |||
| """ | |||
| Calculates element in `test_elements`, broadcasting over `element` only. Returns a | |||
| boolean array of the same shape as `element` that is True where an element of | |||
| `element` is in `test_elements` and False otherwise. | |||
| Note: | |||
| Numpy argument `assume_unique` is not supported since the implementation does | |||
| not rely on the uniqueness of the input arrays. | |||
| Args: | |||
| element (array_like): Input array. | |||
| test_elements (array_like): The values against which to test each value of | |||
| `element`. | |||
| invert (boolean, optional): If True, the values in the returned array are | |||
| inverted, as if calculating `element` not in `test_elements`. Default is False. | |||
| Returns: | |||
| Tensor, has the same shape as `element`. The values ``element[isin]`` are in | |||
| `test_elements`. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> element = 2*np.arange(4).reshape((2, 2)) | |||
| >>> test_elements = [1, 2, 4, 8] | |||
| >>> mask = np.isin(element, test_elements) | |||
| >>> print(mask) | |||
| [[False True] | |||
| [ True False]] | |||
| >>> mask = np.isin(element, test_elements, invert=True) | |||
| >>> print(mask) | |||
| [[ True False] | |||
| [False True]] | |||
| """ | |||
| res = in1d(element, test_elements, invert=invert) | |||
| return F.reshape(res, F.shape(element)) | |||
| def logical_not(a, dtype=None): | |||
| """ | |||
| Computes the truth value of NOT `a` element-wise. | |||
| Note: | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| Args: | |||
| a (Tensor): The input tensor whose dtype is bool. | |||
| dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the | |||
| output Tensor. | |||
| Returns: | |||
| Tensor or scalar. | |||
| Boolean result with the same shape as `a` of the NOT operation on elements of `a`. | |||
| This is a scalar if `a` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor or its dtype is not bool. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = np.array([True, False]) | |||
| >>> output = np.logical_not(a) | |||
| >>> print(output) | |||
| [False True] | |||
| """ | |||
| return _apply_tensor_op(F.logical_not, a, dtype=dtype) | |||
| def logical_or(x1, x2, dtype=None): | |||
| """ | |||
| Computes the truth value of `x1` OR `x2` element-wise. | |||
| Note: | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, | |||
| and `extobj` are not supported. | |||
| Args: | |||
| x1 (Tensor): Input tensor. | |||
| x2 (Tensor): Input tensor. If ``x1.shape != x2.shape``, they must be | |||
| broadcastable to a common shape (which becomes the shape of the output). | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| Returns: | |||
| Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type | |||
| bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are | |||
| scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> x1 = np.array([True, False]) | |||
| >>> x2 = np.array([False, True]) | |||
| >>> output = np.logical_or(x1, x2) | |||
| >>> print(output) | |||
| [ True True] | |||
| """ | |||
| return _apply_tensor_op(F.logical_or, x1, x2, dtype=dtype) | |||
| def logical_and(x1, x2, dtype=None): | |||
| """ | |||
| Computes the truth value of `x1` AND `x2` element-wise. | |||
| Note: | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, | |||
| and `extobj` are not supported. | |||
| Args: | |||
| x1 (Tensor): Input tensor. | |||
| x2 (Tensor): Input tensor. If ``x1.shape != x2.shape``, they must be | |||
| broadcastable to a common shape (which becomes the shape of the output). | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| Returns: | |||
| Tensor or scalar. | |||
| Boolean result of the logical AND operation applied to the elements of `x1` and `x2`; | |||
| the shape is determined by broadcasting. This is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> x1 = np.array([True, False]) | |||
| >>> x2 = np.array([False, False]) | |||
| >>> output = np.logical_and(x1, x2) | |||
| >>> print(output) | |||
| [False False] | |||
| """ | |||
| return _apply_tensor_op(F.logical_and, x1, x2, dtype=dtype) | |||
| def logical_xor(x1, x2, dtype=None): | |||
| """ | |||
| Computes the truth value of `x1` XOR `x2`, element-wise. | |||
| Note: | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, | |||
| and `extobj` are not supported. | |||
| Args: | |||
| x1 (Tensor): Input tensor. | |||
| x2 (Tensor): Input tensor. If ``x1.shape != x2.shape``, they must be | |||
| broadcastable to a common shape (which becomes the shape of the output). | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| Returns: | |||
| Tensor or scalar. | |||
| Boolean result of the logical AND operation applied to the elements of `x1` and `x2`; | |||
| the shape is determined by broadcasting. This is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> x1 = np.array([True, False]) | |||
| >>> x2 = np.array([False, False]) | |||
| >>> output = np.logical_xor(x1, x2) | |||
| >>> print(output) | |||
| [True False] | |||
| """ | |||
| _check_input_tensor(x1) | |||
| _check_input_tensor(x2) | |||
| y1 = F.logical_or(x1, x2) | |||
| y2 = F.logical_or(F.logical_not(x1), F.logical_not(x2)) | |||
| return _apply_tensor_op(F.logical_and, y1, y2, dtype=dtype) | |||
| @@ -13,14 +13,11 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """internal utility functions""" | |||
| import numpy as onp | |||
| from ..common import Tensor | |||
| from ..ops import functional as F | |||
| from ..common import dtype as mstype | |||
| from .utils_const import _tile_size, _add_unit_axes, _raise_type_error | |||
| from .utils_const import _tile_size, _add_unit_axes, _raise_type_error, _type_convert | |||
| def _deep_list(array_like): | |||
| @@ -56,9 +53,8 @@ def _deep_tensor_to_nparray(array_like): | |||
| def _check_input_for_asarray(array_like): | |||
| """check whether array_like argument is a valid type for np.asarray conversion""" | |||
| if not isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)): | |||
| _raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \ | |||
| "or numpy.ndarray, but got ", array_like) | |||
| if not isinstance(array_like, (Tensor, list, tuple, int, float, bool)): | |||
| _raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`, but got ", array_like) | |||
| def _is_scalar(shape): | |||
| @@ -121,6 +117,20 @@ def _convert_64_to_32(tensor): | |||
| return tensor | |||
| def _to_tensor(*args): | |||
| """Returns each input as Tensor""" | |||
| res = () | |||
| for arg in args: | |||
| if isinstance(arg, (int, float, bool, list, tuple)): | |||
| arg = _convert_64_to_32(_type_convert(Tensor, arg)) | |||
| elif not isinstance(arg, Tensor): | |||
| _raise_type_error("Expect input to be array like.") | |||
| res += (arg,) | |||
| if len(res) == 1: | |||
| return res[0] | |||
| return res | |||
| def _get_dtype_from_scalar(*input_numbers): | |||
| """ | |||
| Get the final dtype from series of input numbers, compared with F.typeof, we | |||
| @@ -139,3 +149,8 @@ def _get_dtype_from_scalar(*input_numbers): | |||
| if int_flag: | |||
| return mstype.int32 | |||
| return mstype.float32 | |||
| def _isnan(x): | |||
| """Computes isnan.""" | |||
| return F.not_equal(x, x) | |||
| @@ -14,7 +14,8 @@ | |||
| # ============================================================================ | |||
| """internal graph-compatible utility functions""" | |||
| import math | |||
| from functools import partial | |||
| from itertools import zip_longest | |||
| from collections import deque | |||
| import mindspore.context as context | |||
| from ..ops import functional as F | |||
| @@ -24,7 +25,7 @@ from ..common import Tensor | |||
| from .._c_expression import Tensor as Tensor_ | |||
| from .._c_expression import typing | |||
| from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map | |||
| from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map, rule_for_trigonometric | |||
| @constexpr | |||
| @@ -110,44 +111,19 @@ def _get_device(): | |||
| return context.get_context('device_target') | |||
| @constexpr | |||
| def _reverse_index(idx, arr): | |||
| """ | |||
| Returns 1 if shape[idx:] is broadcastable to shape_out[idx:], | |||
| 2 situations if the function returns 1: | |||
| - 1. Tensor's shape has 1 at the designated dimension. | |||
| - 2. Tensor's dimension is less than the designated idx. (The Tensor shape | |||
| has been reversed) | |||
| For both cases, 2 tensors are broadcastable. | |||
| otherwise returns the element at position of shape | |||
| """ | |||
| if len(arr) <= idx: | |||
| return 1 | |||
| return arr[-1 - idx] | |||
| @constexpr | |||
| def _infer_out_shape(*shapes): | |||
| """ | |||
| Returns shape of output after broadcasting | |||
| Raises ValueError if shape1 and shape2 cannot be broadcast | |||
| Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast. | |||
| """ | |||
| shapes_unbroadcastable = False | |||
| ndim_max = max(map(len, shapes)) | |||
| shape_out = [0]*ndim_max | |||
| i = 0 | |||
| for i in range(ndim_max): | |||
| shape_out[-1 - i] = max(map(partial(_reverse_index, i), shapes)) | |||
| for shape in shapes: | |||
| if _reverse_index(i, shape) != shape_out[-1 - i]: | |||
| if _reverse_index(i, shape) != 1: | |||
| shapes_unbroadcastable = True | |||
| break | |||
| if shapes_unbroadcastable: | |||
| break | |||
| if not shapes_unbroadcastable: | |||
| return tuple(shape_out) | |||
| raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}') | |||
| shape_out = deque() | |||
| reversed_shapes = map(reversed, shapes) | |||
| for items in zip_longest(*reversed_shapes, fillvalue=1): | |||
| max_size = 0 if 0 in items else max(items) | |||
| if any(item not in (1, max_size) for item in items): | |||
| raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}') | |||
| shape_out.appendleft(max_size) | |||
| return tuple(shape_out) | |||
| @constexpr | |||
| @@ -228,6 +204,21 @@ def _raise_value_error(info, param=None): | |||
| raise ValueError(info + f"{param}") | |||
| @constexpr | |||
| def _raise_runtime_error(info, param=None): | |||
| """ | |||
| Raise RuntimeError in both graph/pynative mode | |||
| Args: | |||
| info(str): info string to display | |||
| param(python obj): any object that can be recognized by graph mode. If is | |||
| not None, then param's value information will be extracted and displayed. | |||
| Default is None. | |||
| """ | |||
| if param is None: | |||
| raise RuntimeError(info) | |||
| raise RuntimeError(info + f"{param}") | |||
| @constexpr | |||
| def _empty(dtype, shape): | |||
| """Returns an uninitialized array with dtype and shape.""" | |||
| @@ -242,6 +233,9 @@ def _promote(dtype1, dtype2): | |||
| return promotion_rule[dtype1, dtype2] | |||
| return promotion_rule[dtype2, dtype1] | |||
| @constexpr | |||
| def _promote_for_trigonometric(dtype): | |||
| return rule_for_trigonometric[dtype] | |||
| @constexpr | |||
| def _max(*args): | |||
| @@ -315,7 +309,7 @@ def _canonicalize_axis(axis, ndim): | |||
| axis = tuple([canonicalizer(axis) for axis in axis]) | |||
| if all(axis.count(el) <= 1 for el in axis): | |||
| return axis if len(axis) > 1 else axis[0] | |||
| return tuple(sorted(axis)) if len(axis) > 1 else axis[0] | |||
| raise ValueError(f"duplicate axes in {axis}.") | |||
| @@ -426,13 +420,37 @@ def _tuple_getitem(tup, idx, startswith=True): | |||
| @constexpr | |||
| def _iota(dtype, num): | |||
| def _tuple_setitem(tup, idx, value): | |||
| """ | |||
| Returns a tuple with specified `idx` set to `value`. | |||
| """ | |||
| tup = list(tup) | |||
| tup[idx] = value | |||
| return tuple(tup) | |||
| @constexpr | |||
| def _iota(dtype, num, increasing=True): | |||
| """Creates a 1-D tensor with value: [0,1,...num-1] and dtype.""" | |||
| # TODO: Change to P.Linspace when the kernel is implemented on CPU. | |||
| return Tensor(list(range(int(num))), dtype) | |||
| if increasing: | |||
| return Tensor(list(range(int(num))), dtype) | |||
| return Tensor(list(range(int(num)-1, -1, -1)), dtype) | |||
| @constexpr | |||
| def _ceil(number): | |||
| """Ceils the number in graph mode.""" | |||
| return math.ceil(number) | |||
| @constexpr | |||
| def _seq_prod(seq1, seq2): | |||
| """Returns the element-wise product of seq1 and seq2.""" | |||
| return tuple(map(lambda x, y: x*y, seq1, seq2)) | |||
| @constexpr | |||
| def _make_tensor(val, dtype): | |||
| """ Returns the tensor with value `val` and dtype `dtype`.""" | |||
| return Tensor(val, dtype) | |||
| @@ -15,6 +15,7 @@ | |||
| """Implementation for internal polymorphism `not equal` operations.""" | |||
| from . import _constexpr_utils as const_utils | |||
| from ...composite import base | |||
| from ... import functional as F | |||
| @@ -41,6 +42,21 @@ def _not_equal_scalar(x, y): | |||
| return not F.scalar_eq(x, y) | |||
| @not_equal.register("mstype", "mstype") | |||
| def _not_equal_mstype(x, y): | |||
| """ | |||
| Determine if two mindspore types are not equal. | |||
| Args: | |||
| x (mstype): first input mindspore type. | |||
| y (mstype): second input mindspore type. | |||
| Returns: | |||
| bool, if x != y return true, x == y return false. | |||
| """ | |||
| return not const_utils.mstype_eq(x, y) | |||
| @not_equal.register("String", "String") | |||
| def _not_equal_string(x, y): | |||
| """ | |||
| @@ -77,6 +77,7 @@ floormod = tensor_mod | |||
| tensor_exp = P.Exp() | |||
| exp = tensor_exp | |||
| tensor_expm1 = P.Expm1() | |||
| tensor_slice = P.Slice() | |||
| strided_slice = P.StridedSlice() | |||
| same_type_shape = P.SameTypeShape() | |||
| check_bprop = P.CheckBprop() | |||
| @@ -94,6 +95,22 @@ tensor_slice = P.Slice() | |||
| maximum = P.Maximum() | |||
| minimum = P.Minimum() | |||
| floor = P.Floor() | |||
| logical_not = P.LogicalNot() | |||
| logical_or = P.LogicalOr() | |||
| logical_and = P.LogicalAnd() | |||
| sin = P.Sin() | |||
| cos = P.Cos() | |||
| tan = P.Tan() | |||
| asin = P.Asin() | |||
| acos = P.ACos() | |||
| atan = P.Atan() | |||
| sinh = P.Sinh() | |||
| cosh = P.Cosh() | |||
| tanh = P.Tanh() | |||
| asinh = P.Asinh() | |||
| acosh = P.Acosh() | |||
| atanh = P.Atanh() | |||
| atan2 = P.Atan2() | |||
| scalar_to_array = P.ScalarToArray() | |||
| scalar_to_tensor = P.ScalarToTensor() | |||
| @@ -2560,7 +2560,7 @@ class Acosh(PrimitiveWithInfer): | |||
| TypeError: If `input_x` is not a Tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| ``Ascend`` ``GPU`` | |||
| Examples: | |||
| >>> acosh = ops.Acosh() | |||
| @@ -2637,7 +2637,7 @@ class Asinh(PrimitiveWithInfer): | |||
| TypeError: If `input_x` is not a Tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| ``Ascend`` ``GPU`` | |||
| Examples: | |||
| >>> asinh = ops.Asinh() | |||
| @@ -20,7 +20,7 @@ import numpy as onp | |||
| import mindspore.numpy as mnp | |||
| from .utils import rand_int, rand_bool, match_array, match_res, match_meta, \ | |||
| match_all_arrays | |||
| match_all_arrays, run_multi_test, to_tensor | |||
| class Cases(): | |||
| @@ -40,8 +40,8 @@ class Cases(): | |||
| self.array_sets = [1, 1.1, True, [1, 0, True], [1, 1.0, 2], (1,), | |||
| [(1, 2, 3), (4, 5, 6)], onp.random.random( # pylint: disable=no-member | |||
| (100, 100)).astype(onp.float32), | |||
| onp.random.random((100, 100)).astype(onp.bool)] | |||
| (100, 100)).astype(onp.float32).tolist(), | |||
| onp.random.random((100, 100)).astype(onp.bool).tolist()] | |||
| self.arrs = [ | |||
| rand_int(2), | |||
| @@ -138,8 +138,8 @@ def test_asarray(): | |||
| expected = mnp.asarray(array, test_case.mnp_dtypes[i]).asnumpy() | |||
| match_array(actual, expected, error=7) | |||
| # Additional tests for nested tensor/numpy_array mixture | |||
| mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] | |||
| # Additional tests for nested tensor mixture | |||
| mnp_input = [(mnp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] | |||
| onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]] | |||
| actual = onp.asarray(onp_input) | |||
| @@ -168,11 +168,11 @@ def test_array(): | |||
| assert arr4 is arr5 | |||
| # Additional tests for nested tensor/numpy_array mixture | |||
| mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] | |||
| mnp_input = [(mnp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] | |||
| onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]] | |||
| actual = onp.asarray(onp_input) | |||
| expected = mnp.asarray(mnp_input).asnumpy() | |||
| actual = onp.array(onp_input) | |||
| expected = mnp.array(mnp_input).asnumpy() | |||
| match_array(actual, expected, error=7) | |||
| @@ -202,11 +202,11 @@ def test_asfarray(): | |||
| match_array(actual, expected, error=7) | |||
| # Additional tests for nested tensor/numpy_array mixture | |||
| mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] | |||
| mnp_input = [(mnp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] | |||
| onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]] | |||
| actual = onp.asarray(onp_input) | |||
| expected = mnp.asarray(mnp_input).asnumpy() | |||
| actual = onp.asfarray(onp_input) | |||
| expected = mnp.asfarray(mnp_input).asnumpy() | |||
| match_array(actual, expected, error=7) | |||
| @@ -373,14 +373,14 @@ def test_linspace(): | |||
| stop = onp.random.random([1, 5, 1]).astype("float32") | |||
| actual = onp.linspace(start, stop, num=20, retstep=True, | |||
| endpoint=False, dtype=onp.float32) | |||
| expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20, | |||
| expected = mnp.linspace(to_tensor(start), to_tensor(stop), num=20, | |||
| retstep=True, endpoint=False) | |||
| match_array(actual[0], expected[0].asnumpy(), error=6) | |||
| match_array(actual[1], expected[1].asnumpy(), error=6) | |||
| actual = onp.linspace(start, stop, num=20, retstep=True, | |||
| endpoint=False, dtype=onp.int16) | |||
| expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20, | |||
| expected = mnp.linspace(to_tensor(start), to_tensor(stop), num=20, | |||
| retstep=True, endpoint=False, dtype=mnp.int16) | |||
| match_array(actual[0], expected[0].asnumpy(), error=6) | |||
| match_array(actual[1], expected[1].asnumpy(), error=6) | |||
| @@ -388,7 +388,7 @@ def test_linspace(): | |||
| for axis in range(2): | |||
| actual = onp.linspace(start, stop, num=20, retstep=False, | |||
| endpoint=False, dtype=onp.float32, axis=axis) | |||
| expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20, | |||
| expected = mnp.linspace(to_tensor(start), to_tensor(stop), num=20, | |||
| retstep=False, endpoint=False, dtype=mnp.float32, axis=axis) | |||
| match_array(actual, expected.asnumpy(), error=6) | |||
| @@ -510,18 +510,18 @@ def test_full_like(): | |||
| for mnp_proto, onp_proto in zip(test_case.mnp_prototypes, test_case.onp_prototypes): | |||
| shape = onp.zeros_like(onp_proto).shape | |||
| fill_value = rand_int() | |||
| actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy() | |||
| actual = mnp.full_like(mnp_proto, to_tensor(fill_value)).asnumpy() | |||
| expected = onp.full_like(onp_proto, fill_value) | |||
| match_array(actual, expected) | |||
| for i in range(len(shape) - 1, 0, -1): | |||
| fill_value = rand_int(*shape[i:]) | |||
| actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy() | |||
| actual = mnp.full_like(mnp_proto, to_tensor(fill_value)).asnumpy() | |||
| expected = onp.full_like(onp_proto, fill_value) | |||
| match_array(actual, expected) | |||
| fill_value = rand_int(1, *shape[i + 1:]) | |||
| actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy() | |||
| actual = mnp.full_like(mnp_proto, to_tensor(fill_value)).asnumpy() | |||
| expected = onp.full_like(onp_proto, fill_value) | |||
| match_array(actual, expected) | |||
| @@ -549,6 +549,21 @@ def test_tri_triu_tril(): | |||
| match_array(mnp.tri(64, 64, -10).asnumpy(), onp.tri(64, 64, -10)) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_nancumsum(): | |||
| x = rand_int(2, 3, 4, 5) | |||
| x[0][2][1][3] = onp.nan | |||
| x[1][0][2][4] = onp.nan | |||
| x[1][1][1][1] = onp.nan | |||
| match_res(mnp.nancumsum, onp.nancumsum, x) | |||
| match_res(mnp.nancumsum, onp.nancumsum, x, axis=-2) | |||
| match_res(mnp.nancumsum, onp.nancumsum, x, axis=0) | |||
| match_res(mnp.nancumsum, onp.nancumsum, x, axis=3) | |||
| def mnp_diagonal(arr): | |||
| return mnp.diagonal(arr, offset=2, axis1=-1, axis2=0) | |||
| @@ -653,7 +668,7 @@ def test_meshgrid(): | |||
| (2, 3), 9), onp.full((4, 5, 6), 7)) | |||
| for i in range(len(xi)): | |||
| arrs = xi[i:] | |||
| mnp_arrs = map(mnp.asarray, arrs) | |||
| mnp_arrs = map(to_tensor, arrs) | |||
| for mnp_res, onp_res in zip(mnp_meshgrid(*mnp_arrs), onp_meshgrid(*arrs)): | |||
| match_all_arrays(mnp_res, onp_res) | |||
| @@ -750,6 +765,68 @@ def test_ix_(): | |||
| match_res(mnp_ix_, onp_ix_, *test_arrs) | |||
| def mnp_indices(): | |||
| a = mnp.indices((2, 3)) | |||
| b = mnp.indices((2, 3, 4), sparse=True) | |||
| return a, b | |||
| def onp_indices(): | |||
| a = onp.indices((2, 3)) | |||
| b = onp.indices((2, 3, 4), sparse=True) | |||
| return a, b | |||
| def test_indices(): | |||
| run_multi_test(mnp_indices, onp_indices, ()) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_geomspace(): | |||
| start = onp.arange(1, 7).reshape(2, 3) | |||
| end = [1000, 2000, 3000] | |||
| match_array(mnp.geomspace(1, 256, num=9).asnumpy(), | |||
| onp.geomspace(1, 256, num=9), error=1) | |||
| match_array(mnp.geomspace(1, 256, num=8, endpoint=False).asnumpy(), | |||
| onp.geomspace(1, 256, num=8, endpoint=False), error=1) | |||
| match_array(mnp.geomspace(to_tensor(start), end, num=4).asnumpy(), | |||
| onp.geomspace(start, end, num=4), error=1) | |||
| match_array(mnp.geomspace(to_tensor(start), end, num=4, endpoint=False).asnumpy(), | |||
| onp.geomspace(start, end, num=4, endpoint=False), error=1) | |||
| match_array(mnp.geomspace(to_tensor(start), end, num=4, axis=-1).asnumpy(), | |||
| onp.geomspace(start, end, num=4, axis=-1), error=1) | |||
| match_array(mnp.geomspace(to_tensor(start), end, num=4, endpoint=False, axis=-1).asnumpy(), | |||
| onp.geomspace(start, end, num=4, endpoint=False, axis=-1), error=1) | |||
| start = onp.arange(1, 1 + 2*3*4*5).reshape(2, 3, 4, 5) | |||
| end = [1000, 2000, 3000, 4000, 5000] | |||
| for i in range(-5, 5): | |||
| match_array(mnp.geomspace(to_tensor(start), end, num=4, axis=i).asnumpy(), | |||
| onp.geomspace(start, end, num=4, axis=i), error=1) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_vander(): | |||
| arrs = [rand_int(i + 3) for i in range(3)] | |||
| for i in range(3): | |||
| mnp_vander = mnp.vander(to_tensor(arrs[i])) | |||
| onp_vander = onp.vander(arrs[i]) | |||
| match_all_arrays(mnp_vander, onp_vander) | |||
| mnp_vander = mnp.vander(to_tensor(arrs[i]), N=2, increasing=True) | |||
| onp_vander = onp.vander(arrs[i], N=2, increasing=True) | |||
| match_all_arrays(mnp_vander, onp_vander) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @@ -23,7 +23,7 @@ import mindspore.numpy as mnp | |||
| from mindspore.nn import Cell | |||
| from .utils import rand_int, run_non_kw_test, check_all_results, match_array, \ | |||
| rand_bool, match_res, run_multi_test | |||
| rand_bool, match_res, run_multi_test, to_tensor | |||
| class Cases(): | |||
| @@ -139,7 +139,7 @@ def onp_transpose(input_array): | |||
| @pytest.mark.env_onecard | |||
| def test_transpose(): | |||
| onp_array = onp.random.random((3, 4, 5)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_array = to_tensor(onp_array) | |||
| o_transposed = onp_transpose(onp_array) | |||
| m_transposed = mnp_transpose(mnp_array) | |||
| check_all_results(o_transposed, m_transposed) | |||
| @@ -170,7 +170,7 @@ def onp_expand_dims(input_array): | |||
| @pytest.mark.env_onecard | |||
| def test_expand_dims(): | |||
| onp_array = onp.random.random((3, 4, 5)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_array = to_tensor(onp_array) | |||
| o_expanded = onp_expand_dims(onp_array) | |||
| m_expanded = mnp_expand_dims(mnp_array) | |||
| check_all_results(o_expanded, m_expanded) | |||
| @@ -205,13 +205,13 @@ def onp_squeeze(input_array): | |||
| @pytest.mark.env_onecard | |||
| def test_squeeze(): | |||
| onp_array = onp.random.random((1, 3, 1, 4, 2)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_array = to_tensor(onp_array) | |||
| o_squeezed = onp_squeeze(onp_array) | |||
| m_squeezed = mnp_squeeze(mnp_array) | |||
| check_all_results(o_squeezed, m_squeezed) | |||
| onp_array = onp.random.random((1, 1, 1, 1, 1)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_array = to_tensor(onp_array) | |||
| o_squeezed = onp_squeeze(onp_array) | |||
| m_squeezed = mnp_squeeze(mnp_array) | |||
| check_all_results(o_squeezed, m_squeezed) | |||
| @@ -246,7 +246,7 @@ def onp_rollaxis(input_array): | |||
| @pytest.mark.env_onecard | |||
| def test_rollaxis(): | |||
| onp_array = onp.random.random((3, 4, 5)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_array = to_tensor(onp_array) | |||
| o_rolled = onp_rollaxis(onp_array) | |||
| m_rolled = mnp_rollaxis(mnp_array) | |||
| check_all_results(o_rolled, m_rolled) | |||
| @@ -281,7 +281,7 @@ def onp_swapaxes(input_array): | |||
| @pytest.mark.env_onecard | |||
| def test_swapaxes(): | |||
| onp_array = onp.random.random((3, 4, 5)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_array = to_tensor(onp_array) | |||
| o_swaped = onp_swapaxes(onp_array) | |||
| m_swaped = mnp_swapaxes(mnp_array) | |||
| check_all_results(o_swaped, m_swaped) | |||
| @@ -324,7 +324,7 @@ def onp_reshape(input_array): | |||
| @pytest.mark.env_onecard | |||
| def test_reshape(): | |||
| onp_array = onp.random.random((2, 3, 4)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_array = to_tensor(onp_array) | |||
| o_reshaped = onp_reshape(onp_array) | |||
| m_reshaped = mnp_reshape(mnp_array) | |||
| check_all_results(o_reshaped, m_reshaped) | |||
| @@ -349,7 +349,7 @@ def onp_ravel(input_array): | |||
| @pytest.mark.env_onecard | |||
| def test_ravel(): | |||
| onp_array = onp.random.random((2, 3, 4)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_array = to_tensor(onp_array) | |||
| o_ravel = onp_ravel(onp_array) | |||
| m_ravel = mnp_ravel(mnp_array).asnumpy() | |||
| match_array(o_ravel, m_ravel) | |||
| @@ -380,7 +380,7 @@ def onp_concatenate(input_array): | |||
| @pytest.mark.env_onecard | |||
| def test_concatenate(): | |||
| onp_array = onp.random.random((5, 4, 3, 2)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_array = to_tensor(onp_array) | |||
| o_concatenate = onp_concatenate(onp_array) | |||
| m_concatenate = mnp_concatenate(mnp_array) | |||
| check_all_results(o_concatenate, m_concatenate) | |||
| @@ -407,8 +407,8 @@ def onp_append(arr1, arr2): | |||
| def test_append(): | |||
| onp_array = onp.random.random((4, 3, 2)).astype('float32') | |||
| onp_value = onp.random.random((4, 3, 2)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_value = mnp.asarray(onp_value) | |||
| mnp_array = to_tensor(onp_array) | |||
| mnp_value = to_tensor(onp_value) | |||
| onp_res = onp_append(onp_array, onp_value) | |||
| mnp_res = mnp_append(mnp_array, mnp_value) | |||
| check_all_results(onp_res, mnp_res) | |||
| @@ -424,13 +424,13 @@ def construct_arrays(n=1, ndim=1, axis=None, low=1, high=5): | |||
| onp_array1 = onp.random.randint( | |||
| low=low, high=high, size=shape).astype(onp.float32) | |||
| onp_array_lst.append(onp_array1) | |||
| mnp_array_lst.append(mnp.asarray(onp_array1)) | |||
| mnp_array_lst.append(to_tensor(onp_array1)) | |||
| if axis is not None and axis < ndim: | |||
| new_shape[axis] += onp.random.randint(2) | |||
| onp_array2 = onp.random.randint( | |||
| low=low, high=high, size=new_shape).astype(onp.float32) | |||
| onp_array_lst.append(onp_array2) | |||
| mnp_array_lst.append(mnp.asarray(onp_array2)) | |||
| mnp_array_lst.append(to_tensor(onp_array2)) | |||
| return onp_array_lst, mnp_array_lst | |||
| # Test np.xstack | |||
| @@ -656,7 +656,7 @@ def onp_ndarray_flatten(input_array): | |||
| @pytest.mark.env_onecard | |||
| def test_ndarray_flatten(): | |||
| onp_array = onp.random.random((3, 4, 5)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_array = to_tensor(onp_array) | |||
| o_flatten = onp_ndarray_flatten(onp_array) | |||
| m_flatten = mnp_ndarray_flatten(mnp_array) | |||
| check_all_results(o_flatten, m_flatten) | |||
| @@ -687,7 +687,7 @@ def onp_ndarray_transpose(input_array): | |||
| @pytest.mark.env_onecard | |||
| def test_ndarray_transpose(): | |||
| onp_array = onp.random.random((3, 4, 5)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_array = to_tensor(onp_array) | |||
| o_transposed = onp_ndarray_transpose(onp_array) | |||
| m_transposed = mnp_ndarray_transpose(mnp_array) | |||
| check_all_results(o_transposed, m_transposed) | |||
| @@ -716,7 +716,7 @@ def onp_ndarray_astype(input_array): | |||
| @pytest.mark.env_onecard | |||
| def test_ndarray_astype(): | |||
| onp_array = onp.random.random((3, 4, 5)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_array = to_tensor(onp_array) | |||
| o_astype = onp_ndarray_astype(onp_array) | |||
| m_astype = mnp_ndarray_astype(mnp_array) | |||
| for arr1, arr2 in zip(o_astype, m_astype): | |||
| @@ -747,7 +747,7 @@ def mnp_concatenate_type_promotion(mnp_array1, mnp_array2, mnp_array3, mnp_array | |||
| @pytest.mark.env_onecard | |||
| def test_concatenate_type_promotion(): | |||
| onp_array = onp.random.random((5, 1)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_array = to_tensor(onp_array) | |||
| onp_array1 = onp_array.astype(onp.float16) | |||
| onp_array2 = onp_array.astype(onp.bool_) | |||
| onp_array3 = onp_array.astype(onp.float32) | |||
| @@ -1049,7 +1049,7 @@ def test_split(): | |||
| onp_arrs = [ | |||
| onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32') | |||
| ] | |||
| mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] | |||
| mnp_arrs = [to_tensor(arr) for arr in onp_arrs] | |||
| for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): | |||
| o_split = onp_split(onp_arr) | |||
| m_split = mnp_split(mnp_arr) | |||
| @@ -1058,6 +1058,36 @@ def test_split(): | |||
| match_array(expect, actual.asnumpy()) | |||
| def mnp_array_split(input_tensor): | |||
| a = mnp.array_split(input_tensor, indices_or_sections=4, axis=2) | |||
| b = mnp.array_split(input_tensor, indices_or_sections=3, axis=1) | |||
| c = mnp.array_split(input_tensor, indices_or_sections=6) | |||
| return a, b, c | |||
| def onp_array_split(input_array): | |||
| a = onp.array_split(input_array, indices_or_sections=4, axis=2) | |||
| b = onp.array_split(input_array, indices_or_sections=3, axis=1) | |||
| c = onp.array_split(input_array, indices_or_sections=6) | |||
| return a, b, c | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_array_split(): | |||
| onp_arr = onp.random.randint(1, 5, size=(9, 7, 13)).astype('float32') | |||
| mnp_arr = to_tensor(onp_arr) | |||
| o_split = onp_split(onp_arr) | |||
| m_split = mnp_split(mnp_arr) | |||
| for expect_lst, actual_lst in zip(o_split, m_split): | |||
| for expect, actual in zip(expect_lst, actual_lst): | |||
| match_array(expect, actual.asnumpy()) | |||
| def mnp_vsplit(input_tensor): | |||
| a = mnp.vsplit(input_tensor, indices_or_sections=3) | |||
| b = mnp.vsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10)) | |||
| @@ -1082,7 +1112,7 @@ def test_vsplit(): | |||
| onp_arrs = [ | |||
| onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32') | |||
| ] | |||
| mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] | |||
| mnp_arrs = [to_tensor(arr) for arr in onp_arrs] | |||
| for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): | |||
| o_vsplit = onp_vsplit(onp_arr) | |||
| m_vsplit = mnp_vsplit(mnp_arr) | |||
| @@ -1115,7 +1145,7 @@ def test_hsplit(): | |||
| onp_arrs = [ | |||
| onp.random.randint(1, 5, size=(4, 9, 5)).astype('float32') | |||
| ] | |||
| mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] | |||
| mnp_arrs = [to_tensor(arr) for arr in onp_arrs] | |||
| for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): | |||
| o_hsplit = onp_hsplit(onp_arr) | |||
| m_hsplit = mnp_hsplit(mnp_arr) | |||
| @@ -1148,7 +1178,7 @@ def test_dsplit(): | |||
| onp_arrs = [ | |||
| onp.random.randint(1, 5, size=(5, 4, 9)).astype('float32') | |||
| ] | |||
| mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] | |||
| mnp_arrs = [to_tensor(arr) for arr in onp_arrs] | |||
| for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): | |||
| o_dsplit = onp_dsplit(onp_arr) | |||
| m_dsplit = mnp_dsplit(mnp_arr) | |||
| @@ -1248,6 +1278,29 @@ def test_repeat(): | |||
| run_multi_test(mnp_repeat, onp_repeat, (x,)) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_select(): | |||
| choicelist = rand_int(2, 3, 4, 5) | |||
| condlist = choicelist > 2 | |||
| match_res(mnp.select, onp.select, condlist, choicelist) | |||
| match_res(mnp.select, onp.select, condlist, choicelist, default=10) | |||
| condlist = rand_bool(5, 4, 1, 3) | |||
| choicelist = rand_int(5, 3) | |||
| match_res(mnp.select, onp.select, condlist, choicelist) | |||
| match_res(mnp.select, onp.select, condlist, choicelist, default=10) | |||
| condlist = rand_bool(3, 1, 7) | |||
| choicelist = rand_int(3, 5, 2, 1) | |||
| match_res(mnp.select, onp.select, condlist, choicelist) | |||
| match_res(mnp.select, onp.select, condlist, choicelist, default=10) | |||
| class ReshapeExpandSqueeze(Cell): | |||
| def __init__(self): | |||
| super(ReshapeExpandSqueeze, self).__init__() | |||
| @@ -1333,7 +1386,7 @@ def test_swapaxes_exception(): | |||
| @pytest.mark.env_onecard | |||
| def test_tensor_flatten(): | |||
| lst = [[1.0, 2.0], [3.0, 4.0]] | |||
| tensor_list = mnp.asarray(lst) | |||
| tensor_list = to_tensor(lst) | |||
| assert tensor_list.flatten().asnumpy().tolist() == [1.0, 2.0, 3.0, 4.0] | |||
| assert tensor_list.flatten(order='F').asnumpy().tolist() == [ | |||
| 1.0, 3.0, 2.0, 4.0] | |||
| @@ -1347,7 +1400,7 @@ def test_tensor_flatten(): | |||
| @pytest.mark.env_onecard | |||
| def test_tensor_reshape(): | |||
| lst = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] | |||
| tensor_list = mnp.asarray(lst) | |||
| tensor_list = to_tensor(lst) | |||
| with pytest.raises(TypeError): | |||
| tensor_list = tensor_list.reshape({0, 1, 2}) | |||
| with pytest.raises(ValueError): | |||
| @@ -1364,7 +1417,7 @@ def test_tensor_reshape(): | |||
| @pytest.mark.env_onecard | |||
| def test_tensor_squeeze(): | |||
| lst = [[[1.0], [2.0], [3.0]]] | |||
| tensor_list = mnp.asarray(lst) | |||
| tensor_list = to_tensor(lst) | |||
| with pytest.raises(TypeError): | |||
| tensor_list = tensor_list.squeeze(1.2) | |||
| with pytest.raises(ValueError): | |||
| @@ -1381,7 +1434,7 @@ def test_tensor_squeeze(): | |||
| @pytest.mark.env_onecard | |||
| def test_tensor_ravel(): | |||
| lst = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]] | |||
| tensor_list = mnp.asarray(lst) | |||
| tensor_list = to_tensor(lst) | |||
| assert tensor_list.ravel().shape == (8,) | |||
| assert tensor_list.ravel().asnumpy().tolist() == [ | |||
| 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] | |||
| @@ -1395,9 +1448,47 @@ def test_tensor_ravel(): | |||
| @pytest.mark.env_onecard | |||
| def test_tensor_swapaxes(): | |||
| lst = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] | |||
| tensor_list = mnp.asarray(lst) | |||
| tensor_list = to_tensor(lst) | |||
| with pytest.raises(TypeError): | |||
| tensor_list = tensor_list.swapaxes(0, (1,)) | |||
| with pytest.raises(ValueError): | |||
| tensor_list = tensor_list.swapaxes(0, 3) | |||
| assert tensor_list.swapaxes(0, 1).shape == (3, 2) | |||
| def mnp_rot90(input_tensor): | |||
| a = mnp.rot90(input_tensor) | |||
| b = mnp.rot90(input_tensor, 2) | |||
| c = mnp.rot90(input_tensor, 3) | |||
| d = mnp.rot90(input_tensor, 4) | |||
| e = mnp.rot90(input_tensor, 5, (0, -1)) | |||
| f = mnp.rot90(input_tensor, 1, (2, 0)) | |||
| g = mnp.rot90(input_tensor, -3, (-1, -2)) | |||
| h = mnp.rot90(input_tensor, 3, (2, 1)) | |||
| return a, b, c, d, e, f, g, h | |||
| def onp_rot90(input_array): | |||
| a = onp.rot90(input_array) | |||
| b = onp.rot90(input_array, 2) | |||
| c = onp.rot90(input_array, 3) | |||
| d = onp.rot90(input_array, 4) | |||
| e = onp.rot90(input_array, 5, (0, -1)) | |||
| f = onp.rot90(input_array, 1, (2, 0)) | |||
| g = onp.rot90(input_array, -3, (-1, -2)) | |||
| h = onp.rot90(input_array, 3, (2, 1)) | |||
| return a, b, c, d, e, f, g, h | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_rot90(): | |||
| onp_array = rand_int(3, 4, 5).astype('float32') | |||
| mnp_array = to_tensor(onp_array) | |||
| o_rot = onp_rot90(onp_array) | |||
| m_rot = mnp_rot90(mnp_array) | |||
| check_all_results(o_rot, m_rot) | |||
| @@ -19,7 +19,8 @@ import numpy as onp | |||
| import mindspore.numpy as mnp | |||
| from .utils import rand_int, run_binop_test, match_res | |||
| from .utils import rand_int, rand_bool, run_binop_test, run_logical_test, match_res, \ | |||
| match_all_arrays, to_tensor | |||
| class Cases(): | |||
| @@ -55,6 +56,15 @@ class Cases(): | |||
| rand_int(8, 1, 6, 1) | |||
| ] | |||
| # Boolean arrays | |||
| self.boolean_arrs = [ | |||
| rand_bool(), | |||
| rand_bool(5), | |||
| rand_bool(6, 1), | |||
| rand_bool(7, 1, 5), | |||
| rand_bool(8, 1, 6, 1) | |||
| ] | |||
| # array which contains infs and nans | |||
| self.infs = onp.array([[1.0, onp.nan], [onp.inf, onp.NINF], [2.3, -4.5], [onp.nan, 0.0]]) | |||
| @@ -246,10 +256,147 @@ def test_isneginf(): | |||
| match_res(mnp_isneginf, onp_isneginf, test_case.infs) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_isscalar(): | |||
| assert mnp.isscalar(1) == onp.isscalar(1) | |||
| assert mnp.isscalar(2.3) == onp.isscalar(2.3) | |||
| assert mnp.isscalar([4.5]) == onp.isscalar([4.5]) | |||
| assert mnp.isscalar(False) == onp.isscalar(False) | |||
| assert mnp.isscalar(mnp.array(True)) == onp.isscalar(onp.array(True)) | |||
| assert mnp.isscalar(to_tensor(True)) == onp.isscalar(onp.array(True)) | |||
| assert mnp.isscalar('numpy') == onp.isscalar('numpy') | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_isclose(): | |||
| a = [0, 1, 2, float('inf'), float('inf'), float('nan')] | |||
| b = [0, 1, -2, float('-inf'), float('inf'), float('nan')] | |||
| match_all_arrays(mnp.isclose(a, b), onp.isclose(a, b)) | |||
| match_all_arrays(mnp.isclose(a, b, equal_nan=True), onp.isclose(a, b, equal_nan=True)) | |||
| a = rand_int(2, 3, 4, 5) | |||
| diff = (onp.random.random((2, 3, 4, 5)).astype("float32") - 0.5) / 1000 | |||
| b = a + diff | |||
| match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b), atol=1e-3), onp.isclose(a, b, atol=1e-3)) | |||
| match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b), atol=1e-3, rtol=1e-4), | |||
| onp.isclose(a, b, atol=1e-3, rtol=1e-4)) | |||
| match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b), atol=1e-2, rtol=1e-6), | |||
| onp.isclose(a, b, atol=1e-2, rtol=1e-6)) | |||
| a = rand_int(2, 3, 4, 5) | |||
| b = rand_int(4, 5) | |||
| match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b)), onp.isclose(a, b)) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_in1d(): | |||
| xi = [rand_int(), rand_int(1), rand_int(10)] | |||
| yi = [rand_int(), rand_int(1), rand_int(10)] | |||
| for x in xi: | |||
| for y in yi: | |||
| match_res(mnp.in1d, onp.in1d, x, y) | |||
| match_res(mnp.in1d, onp.in1d, x, y, invert=True) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_isin(): | |||
| xi = [rand_int(), rand_int(1), rand_int(10), rand_int(2, 3)] | |||
| yi = [rand_int(), rand_int(1), rand_int(10), rand_int(2, 3)] | |||
| for x in xi: | |||
| for y in yi: | |||
| match_res(mnp.in1d, onp.in1d, x, y) | |||
| match_res(mnp.in1d, onp.in1d, x, y, invert=True) | |||
| def mnp_logical_or(x1, x2): | |||
| return mnp.logical_or(x1, x2) | |||
| def onp_logical_or(x1, x2): | |||
| return onp.logical_or(x1, x2) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_logical_or(): | |||
| run_logical_test(mnp_logical_or, onp_logical_or, test_case) | |||
| def mnp_logical_xor(x1, x2): | |||
| return mnp.logical_xor(x1, x2) | |||
| def onp_logical_xor(x1, x2): | |||
| return onp.logical_xor(x1, x2) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_logical_xor(): | |||
| run_logical_test(mnp_logical_xor, onp_logical_xor, test_case) | |||
| def mnp_logical_and(x1, x2): | |||
| return mnp.logical_and(x1, x2) | |||
| def onp_logical_and(x1, x2): | |||
| return onp.logical_and(x1, x2) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_logical_and(): | |||
| run_logical_test(mnp_logical_and, onp_logical_and, test_case) | |||
| def mnp_logical_not(x): | |||
| return mnp.logical_not(x) | |||
| def onp_logical_not(x): | |||
| return onp.logical_not(x) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_logical_not(): | |||
| for arr in test_case.boolean_arrs: | |||
| expected = onp_logical_not(arr) | |||
| actual = mnp_logical_not(to_tensor(arr)) | |||
| onp.testing.assert_equal(actual.asnumpy().tolist(), expected.tolist()) | |||
| @@ -15,6 +15,7 @@ | |||
| """utility functions for mindspore.numpy st tests""" | |||
| import functools | |||
| import numpy as onp | |||
| from mindspore import Tensor | |||
| import mindspore.numpy as mnp | |||
| @@ -90,7 +91,9 @@ def rand_bool(*shape): | |||
| def match_res(mnp_fn, onp_fn, *arrs, **kwargs): | |||
| """Checks results from applying mnp_fn and onp_fn on arrs respectively""" | |||
| mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs) | |||
| dtype = kwargs.get('dtype', mnp.float32) | |||
| kwargs.pop('dtype', None) | |||
| mnp_arrs = map(functools.partial(Tensor, dtype=dtype), arrs) | |||
| error = kwargs.get('error', 0) | |||
| kwargs.pop('error', None) | |||
| mnp_res = mnp_fn(*mnp_arrs, **kwargs) | |||
| @@ -151,15 +154,32 @@ def run_unary_test(mnp_fn, onp_fn, test_case, error=0): | |||
| def run_multi_test(mnp_fn, onp_fn, arrs, error=0): | |||
| mnp_arrs = map(mnp.asarray, arrs) | |||
| mnp_arrs = map(Tensor, arrs) | |||
| for actual, expected in zip(mnp_fn(*mnp_arrs), onp_fn(*arrs)): | |||
| match_array(actual.asnumpy(), expected, error) | |||
| match_all_arrays(actual, expected, error) | |||
| def run_single_test(mnp_fn, onp_fn, arr, error=0): | |||
| mnp_arr = mnp.asarray(arr) | |||
| mnp_arr = Tensor(arr) | |||
| for actual, expected in zip(mnp_fn(mnp_arr), onp_fn(arr)): | |||
| if isinstance(expected, tuple): | |||
| for actual_arr, expected_arr in zip(actual, expected): | |||
| match_array(actual_arr.asnumpy(), expected_arr, error) | |||
| match_array(actual.asnumpy(), expected, error) | |||
| def run_logical_test(mnp_fn, onp_fn, test_case): | |||
| for x1 in test_case.boolean_arrs: | |||
| for x2 in test_case.boolean_arrs: | |||
| match_res(mnp_fn, onp_fn, x1, x2, dtype=mnp.bool_) | |||
| def to_tensor(obj, dtype=None): | |||
| if dtype is None: | |||
| res = Tensor(obj) | |||
| if res.dtype == mnp.float64: | |||
| res = res.astype(mnp.float32) | |||
| if res.dtype == mnp.int64: | |||
| res = res.astype(mnp.int32) | |||
| else: | |||
| res = Tensor(obj, dtype) | |||
| return res | |||