
Add new np interfaces

tags/v1.2.0-rc1
yanglf1121 5 years ago
commit 72b365c24b
16 changed files with 4187 additions and 1060 deletions
  1. mindspore/numpy/__init__.py  +25 -8
  2. mindspore/numpy/array_creations.py  +246 -76
  3. mindspore/numpy/array_ops.py  +234 -59
  4. mindspore/numpy/dtypes.py  +13 -0
  5. mindspore/numpy/logic_ops.py  +342 -148
  6. mindspore/numpy/math_ops.py  +1977 -574
  7. mindspore/numpy/utils.py  +22 -7
  8. mindspore/numpy/utils_const.py  +57 -39
  9. mindspore/ops/composite/multitype_ops/not_equal_impl.py  +16 -0
  10. mindspore/ops/functional.py  +17 -0
  11. mindspore/ops/operations/math_ops.py  +2 -2
  12. tests/st/numpy_native/test_array_creations.py  +95 -18
  13. tests/st/numpy_native/test_array_ops.py  +118 -27
  14. tests/st/numpy_native/test_logic_ops.py  +149 -2
  15. tests/st/numpy_native/test_math_ops.py  +850 -96
  16. tests/st/numpy_native/utils.py  +24 -4

mindspore/numpy/__init__.py  +25 -8

@@ -30,13 +30,14 @@ from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, res
ravel, concatenate, where, atleast_1d, atleast_2d, atleast_3d,
column_stack, hstack, dstack, vstack, stack, unique, moveaxis,
tile, broadcast_to, broadcast_arrays, roll, append, split, vsplit,
flip, flipud, fliplr, hsplit, dsplit, take_along_axis, take, repeat)
flip, flipud, fliplr, hsplit, dsplit, take_along_axis, take, repeat,
rot90, select, array_split)
from .array_creations import copy_ as copy
from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange,
linspace, logspace, eye, identity, empty, empty_like,
ones_like, zeros_like, full_like, diagonal, tril, triu,
tri, trace, meshgrid, mgrid, ogrid, diagflat,
diag, diag_indices, ix_)
diag, diag_indices, ix_, indices, geomspace, vander)
from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16,
uint32, uint64, float_, float16, float32, float64, bool_, inf, nan,
numeric_types, PINF, NINF)
@@ -45,35 +46,51 @@ from .math_ops import (mean, inner, add, subtract, multiply, divide, true_divide
matmul, square, sqrt, reciprocal, log, maximum, heaviside, amax, amin,
hypot, float_power, floor, ptp, deg2rad, rad2deg, count_nonzero,
positive, negative, clip, floor_divide, remainder, fix, fmod, trunc,
exp, expm1, cumsum)
exp, expm1, exp2, kron, promote_types, divmod_, diff, cbrt,
cross, ceil, trapz, gcd, lcm, convolve, log1p, logaddexp, log2,
logaddexp2, log10, ediff1d, nansum, nanmean, nanvar, nanstd, cumsum, nancumsum,
sin, cos, tan, arcsin, arccos, arctan, sinh, cosh, tanh, arcsinh, arccosh,
arctanh, arctan2, cov)
from .logic_ops import (not_equal, less_equal, less, greater_equal, greater, equal, isfinite,
isnan, isinf, isposinf, isneginf, isscalar)
isnan, isinf, isposinf, isneginf, isscalar, logical_and, logical_not,
logical_or, logical_xor, in1d, isin, isclose)


mod = remainder
fabs = absolute
divmod = divmod_ # pylint: disable=redefined-builtin
abs = absolute # pylint: disable=redefined-builtin
max = amax # pylint: disable=redefined-builtin
min = amin # pylint: disable=redefined-builtin



array_ops_module = ['transpose', 'expand_dims', 'squeeze', 'rollaxis', 'swapaxes', 'reshape',
'ravel', 'concatenate', 'where', 'atleast_1d', 'atleast_2d', 'atleast_3d',
'column_stack', 'hstack', 'dstack', 'vstack', 'stack', 'unique', 'moveaxis',
'tile', 'broadcast_to', 'broadcast_arrays', 'append', 'roll', 'split', 'vsplit',
'flip', 'flipud', 'fliplr', 'hsplit', 'dsplit', 'take_along_axis', 'take',
'repeat']
'repeat', 'rot90', 'select', 'array_split']


array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'arange',
'linspace', 'logspace', 'eye', 'identity', 'empty', 'empty_like',
'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril', 'triu',
'tri', 'trace', 'meshgrid', 'mgrid', 'ogrid', 'diagflat', 'diag',
'diag_indices', 'ix_', 'cumsum']
'diag_indices', 'ix_', 'indices', 'geomspace', 'vander']


math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'true_divide', 'power',
'dot', 'outer', 'tensordot', 'absolute', 'std', 'var', 'average', 'not_equal',
'minimum', 'matmul', 'square', 'sqrt', 'reciprocal', 'log', 'maximum',
'heaviside', 'amax', 'amin', 'hypot', 'float_power', 'floor', 'ptp', 'deg2rad',
'rad2deg', 'count_nonzero', 'positive', 'negative', 'clip', 'floor_divide',
'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs', 'cumsum']
'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs', 'exp2', 'kron',
'promote_types', 'divmod', 'diff', 'cbrt', 'cross', 'ceil', 'trapz',
'abs', 'max', 'min', 'gcd', 'lcm', 'log1p', 'logaddexp', 'log2', 'logaddexp2', 'log10',
'convolve', 'ediff1d', 'nansum', 'nanmean', 'nanvar', 'nanstd', 'cumsum',
'nancumsum', 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan', 'sinh', 'cosh', 'tanh',
'arcsinh', 'arccosh', 'arctanh', 'arctan2', 'cov']


logic_module = ['not_equal', 'less_equal', 'less', 'greater_equal', 'greater', 'equal', 'isfinite',
'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar']
'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar', 'logical_and', 'logical_not',
'logical_or', 'logical_xor', 'in1d', 'isin', 'isclose']


__all__ = array_ops_module + array_creations_module + math_module + logic_module + numeric_types
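
With the new aliases above, the module now shadows the Python built-ins `abs`, `max`, `min` and `divmod` inside its own namespace. A minimal usage sketch, assuming the aliases land exactly as listed:

import mindspore.numpy as np
a = np.asarray([-2, 0, 3])
print(np.abs(a))   # alias of absolute -> [2 0 3]
print(np.max(a))   # alias of amax -> 3
print(np.min(a))   # alias of amin -> -2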




mindspore/numpy/array_creations.py  +246 -76

@@ -13,8 +13,6 @@
# limitations under the License.
# ============================================================================
"""array operations, the function docs are adapted from Numpy API."""
from copy import deepcopy

import numpy as onp


from ..common import Tensor
@@ -27,10 +25,11 @@ from .._c_expression import Tensor as Tensor_
from .._c_expression.typing import Float


from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, \
_broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar
_broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar, \
_expand
from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \
_check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \
_raise_type_error, _expanded_shape, _check_is_float, _iota, \
_raise_type_error, _expanded_shape, _tuple_getitem, _check_is_float, _iota, \
_type_convert, _canonicalize_axis, _list_comprehensions, _ceil
from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape, broadcast_to
from .dtypes import nan
@@ -49,9 +48,8 @@ def array(obj, dtype=None, copy=True, ndmin=0):
This function creates tensors from an array-like object.


Args:
obj (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in
any form that can be converted to a `Tensor`. This includes lists, lists of
tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray.
obj (Union[int, float, bool, list, tuple]): Input data, in any form that
can be converted to a `Tensor`. This includes Tensor, list, tuple and numbers.
dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can
be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type
of the new tensor will be inferred from obj. Default is :class:`None`.
@@ -76,17 +74,53 @@ def array(obj, dtype=None, copy=True, ndmin=0):
>>> print(np.array([1,2,3]))
[1 2 3]
"""
if ndmin > 0:
# Fall back to original numpy creation.
if isinstance(obj, Tensor):
obj = obj.asnumpy()
return asarray(onp.array(obj, dtype, copy=copy, ndmin=ndmin))
res = asarray(obj, dtype)
if ndmin > res.ndim:
res = _expand(res, ndmin)

if copy:
res = copy_(res)
elif dtype is not None and dtype != res.dtype:
res = res.astype(dtype)

return res


@constexpr
def asarray_const(a, dtype=None):
"""Converts the input to tensor. Note here `a` cannot be tensor itself."""
_check_input_for_asarray(a)

if dtype is not None:
dtype = _check_dtype(dtype)

if isinstance(a, (float, int, bool)) and dtype is None:
dtype = _get_dtype_from_scalar(a)

if isinstance(a, (list, tuple)):
# Convert all tuple/nested tuples to lists
a = _deep_list(a)
# Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a)
a = onp.asarray(a)
if a.dtype is onp.dtype('object'):
raise ValueError('Input array must have the same size across all dimensions.')
# If dtype is not specified, we keep consistent with numpy decision
# only exceptions are: we use int/float32
if dtype is None:
dtype = mstype.pytype_to_dtype(a.dtype)
if dtype == mstype.float64:
dtype = mstype.float32
elif dtype == mstype.int64:
dtype = mstype.int32


if not copy:
return asarray(obj, dtype=dtype)
if isinstance(a, onp.ndarray) and dtype is None:
if a.dtype is onp.dtype('object'):
raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
dtype = mstype.pytype_to_dtype(a.dtype)
a = Tensor.from_numpy(a)


obj = deepcopy(obj)
return asarray(obj, dtype=dtype)
return Tensor(a, dtype=dtype)
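
Note the dtype defaulting above: when no dtype is given, values that NumPy would infer as 64-bit are narrowed to their 32-bit counterparts. A small sketch of the intended behavior:

import mindspore.numpy as np
t = np.asarray([1.0, 2.0])   # NumPy would infer float64 here
print(t.dtype)               # narrowed to float32 by the rule above
i = np.asarray([1, 2])
print(i.dtype)               # int64 narrowed to int32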




def asarray(a, dtype=None):
@@ -96,9 +130,8 @@ def asarray(a, dtype=None):
This function converts tensors from an array-like object.


Args:
a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in
any form that can be converted to a `Tensor`. This includes lists, lists of
tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray.
a (Union[int, float, bool, list, tuple, Tensor]): Input data, in any form that can
be converted to a `Tensor`. This includes Tensor, list, tuple and numbers.
dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can
be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type
of the new tensor will be inferred from obj. Default is :class:`None`.
@@ -118,46 +151,28 @@ def asarray(a, dtype=None):
>>> print(np.asarray([1,2,3]))
[1 2 3]
"""
_check_input_for_asarray(a)

if dtype is not None:
dtype = _check_dtype(dtype)
if isinstance(a, Tensor):
if dtype is None or dtype == a.dtype:
return a
return a.astype(dtype)
return asarray_const(a, dtype)


if isinstance(a, (float, int, bool)) and dtype is None:
dtype = _get_dtype_from_scalar(a)


@constexpr
def asfarray_const(a, dtype=mstype.float32):
"""Converts the input to tensor. Note here `a` cannot be tensor itself."""
_check_input_for_asarray(a)
if isinstance(a, (list, tuple)):
# Convert all tuple/nested tuples to lists
a = _deep_list(a)
# Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a)
a = onp.asarray(a)
if a.dtype is onp.dtype('object'):
raise ValueError('Input array must have the same size across all dimensions.')
# If dtype is not specified, we keep consistent with numpy decision
# only exceptions are: we use int/float32
if dtype is None:
dtype = mstype.pytype_to_dtype(a.dtype)
if dtype == mstype.float64:
dtype = mstype.float32
elif dtype == mstype.int64:
dtype = mstype.int32

if isinstance(a, onp.ndarray) and dtype is None:
if a.dtype is onp.dtype('object'):
raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
dtype = mstype.pytype_to_dtype(a.dtype)
a = Tensor.from_numpy(a)


# If a is already a tensor and we don't need to cast dtype, return a
if isinstance(a, Tensor):
if dtype is None or dtype == a.dtype:
return a

return Tensor(a, dtype=dtype)


asarray_const = constexpr(asarray)
return Tensor(a, dtype)




def asfarray(a, dtype=mstype.float32):
@@ -167,9 +182,8 @@ def asfarray(a, dtype=mstype.float32):
If non-float dtype is defined, this function will return a float32 tensor instead.


Args:
a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in
any form that can be converted to a `Tensor`. This includes lists, lists of
tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray.
a (Union[int, float, bool, list, tuple, Tensor]): Input data, in any form that can
be converted to a `Tensor`. This includes Tensor, list, tuple and numbers.
dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can
be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type
of the new tensor will be inferred from `a`. Default is :class:`mindspore.float32`.
@@ -190,27 +204,18 @@ def asfarray(a, dtype=mstype.float32):
>>> print(np.asfarray([1,2,3]))
[1. 2. 3.]
"""
_check_input_for_asarray(a)

if dtype is None:
return asarray(a)


dtype = _check_dtype(dtype)
if dtype not in (mstype.float16, mstype.float32, mstype.float64):
# pylint: disable=consider-using-in
if dtype != mstype.float16 and dtype != mstype.float32 and dtype != mstype.float64:
dtype = mstype.float32


if isinstance(a, (list, tuple)):
# Convert all tuple/nested tuples to lists
a = _deep_list(a)
# Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a)
a = onp.asarray(a)
if a.dtype is onp.dtype('object'):
raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
if isinstance(a, onp.ndarray):
a = Tensor.from_numpy(a)
if isinstance(a, Tensor):
return a.astype(dtype)


return Tensor(a, dtype)
return asfarray_const(a)




def copy_(a):
@@ -218,9 +223,8 @@ def copy_(a):
Returns a tensor copy of the given object.


Args:
a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in
any form that can be converted to a tensor. This includes lists, lists of
tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray.
a (Union[int, float, bool, list, tuple, Tensor]): Input data, in any form that can
be converted to a `Tensor`. This includes Tensor, list, tuple and numbers.


Returns:
Tensor, has the same data as `a`.
@@ -241,8 +245,16 @@ def copy_(a):
""" """
if not isinstance(a, Tensor):
a = asarray_const(a)
return a.copy()

# The current implementation registers a new memory location for the copied tensor by
# doing some redundant operations.
origin_dtype = a.dtype
if origin_dtype == mstype.bool_:
return F.logical_not(F.logical_not(a))
if origin_dtype != mstype.float64:
a = a.astype("float32")
a = a / ones_like(a)
a = a.astype(origin_dtype)
return a
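
The double negation and division above are value-preserving, but they force the runtime to allocate a fresh buffer for the result, which is what makes this a copy rather than an alias. A usage sketch (`copy_` is exported as `np.copy`):

import mindspore.numpy as np
a = np.ones((2, 2))
b = np.copy(a)
print(b.shape, b.dtype)   # (2, 2) Float32; same values as a, separate storage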


def ones(shape, dtype=mstype.float32):
"""
@@ -566,6 +578,65 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
return F.tensor_pow(base, linspace_res).astype(dtype)




def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
"""
Returns numbers spaced evenly on a log scale (a geometric progression).

This is similar to logspace, but with endpoints specified directly. Each output sample
is a constant multiple of the previous.

Args:
start (Union[int, list(int), tuple(int), tensor]): The starting value of the sequence.
stop (Union[int, list(int), tuple(int), tensor]): The final value of the sequence,
unless endpoint is False. In that case, num + 1 values are spaced over the
interval in log-space, of which all but the last (a sequence of length num) are
returned.
num (int, optional): Number of samples to generate. Default is 50.
endpoint (bool, optional): If True, `stop` is the last sample. Otherwise, it is
not included. Default is True.
dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can
be in format of np.float32, or `float32`. If `dtype` is None, infer the data
type from other input arguments. Default is None.
axis (int, optional): The axis in the result to store the samples. Relevant
only if start or stop is array-like. By default (0), the samples will
be along a new axis inserted at the beginning. Use -1 to get an axis at the end.
Default is 0.

Returns:
Tensor, with samples equally spaced on a log scale.

Raises:
TypeError: If input arguments have types not specified above.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> output = np.geomspace(1, 256, num=9)
>>> print(output)
[ 1. 2. 4. 8. 16. 32. 64. 128. 256.]
>>> output = np.geomspace(1, 256, num=8, endpoint=False)
>>> print(output)
[ 1. 2. 4. 8. 16. 32. 64. 128.]
"""
start, stop, num, endpoint, dtype, axis = _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis)
root = num
if endpoint:
root -= 1
bases = F.tensor_pow(F.tensor_div(stop, start), asarray_const(1/(root)))
exponents = linspace(zeros(F.shape(bases)), F.fill(F.dtype(bases), F.shape(bases), root),
num, endpoint=endpoint, dtype=dtype, axis=axis)
shape = F.shape(bases)
axis = axis + F.rank(bases) + 1 if axis < 0 else axis
expanded_shape = _tuple_getitem(shape, axis, False) + (1,) + _tuple_getitem(shape, axis)
bases = F.reshape(bases, expanded_shape)
start = F.reshape(start, expanded_shape)
res = F.tensor_mul(F.tensor_pow(bases, exponents), start)
if dtype is not None:
res = F.cast(res, dtype)
return res
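
The construction reduces to a scalar recurrence: each sample is the previous one times a fixed ratio ``(stop/start)**(1/(num-1))``. A worked check of the docstring example in plain Python:

start, stop, num = 1, 256, 9
ratio = (stop / start) ** (1 / (num - 1))          # 2.0
samples = [start * ratio ** i for i in range(num)]
print(samples)   # [1.0, 2.0, 4.0, ..., 256.0], matching np.geomspace(1, 256, num=9)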


def eye(N, M=None, k=0, dtype=mstype.float32):
"""
Returns a 2-D tensor with ones on the diagonal and zeros elsewhere.
@@ -757,7 +828,7 @@ def empty_like(prototype, dtype=None, shape=None):


Examples:
>>> import mindspore.numpy as np
>>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))]
>>> a = np.ones((4,1,2))
>>> output = np.empty_like(a)
>>> print(output)
# result may vary
@@ -794,7 +865,7 @@ def ones_like(a, dtype=None, shape=None):


Examples:
>>> import mindspore.numpy as np
>>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))]
>>> a = np.ones((4,1,2))
>>> output = np.ones_like(a)
>>> print(output)
[[[1. 1.]]
@@ -832,7 +903,7 @@ def zeros_like(a, dtype=None, shape=None):


Examples:
>>> import mindspore.numpy as np
>>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))]
>>> a = np.ones((4,1,2))
>>> output = np.zeros_like(a)
>>> print(output)
[[[0. 0.]]
@@ -871,7 +942,7 @@ def full_like(a, fill_value, dtype=None, shape=None):


Examples:
>>> import mindspore.numpy as np
>>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))]
>>> a = np.ones((4,1,2))
>>> output = np.full_like(a, 0.5)
>>> print(output)
[[[0.5 0.5]]
@@ -1175,9 +1246,8 @@ def _index(i, size, Cartesian=True):
if Cartesian:
if i == 1:
return 0
if i == 0:
if size >= 2:
return 1
if i == 0 and size >= 2:
return 1
return i




@@ -1630,3 +1700,103 @@ def ix_(*args):
return _raise_value_error('Cross index must be 1 dimensional')
res += (F.reshape(arr, _expanded_shape(ndim, arr.size, i)),)
return res


def vander(x, N=None, increasing=False):
"""
Generates a Vandermonde matrix.

The columns of the output matrix are powers of the input vector. The order of
the powers is determined by the increasing boolean argument. Specifically, when
increasing is `False`, the i-th output column is the input vector raised element-wise
to the power of :math:`N - i - 1`. Such a matrix with a geometric progression in each row
is named for Alexandre-Theophile Vandermonde.

Args:
x (Union[list, tuple, Tensor]): 1-D input array.
N (int, optional): Number of columns in the output. If N is not specified, a
square array is returned (``N = len(x)``).
increasing (bool, optional): Order of the powers of the columns. If True, the
powers increase from left to right, if False (the default) they are reversed.

Returns:
Vandermonde matrix. If `increasing` is `False`, the first column is :math:`x^{(N-1)}`,
the second :math:`x^{(N-2)}` and so forth. If `increasing` is `True`, the columns are
:math:`x^0, x^1, ..., x^{(N-1)}`.

Raises:
TypeError: If inputs have types not specified above.
ValueError: If `x` is not 1-D, or `N` < 0.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> print(np.vander([1,2,3,4,5]))
[[ 1 1 1 1 1]
[ 16 8 4 2 1]
[ 81 27 9 3 1]
[256 64 16 4 1]
[625 125 25 5 1]]
"""
if isinstance(x, (list, tuple)):
x = asarray_const(x)
elif not isinstance(x, Tensor):
_raise_type_error("Input x must be list, tuple or Tensor, but got ", x)
if x.ndim != 1:
_raise_value_error("Input x must be 1-D, but got dimension=", x.ndim)
N = N or x.size
if not isinstance(N, int):
_raise_type_error("Input N must be an integer.")
if N <= 0:
_raise_value_error("Input N must be > 0.")
if not isinstance(increasing, bool):
_raise_type_error("increasing must be a bool.")
exponent = _iota(x.dtype, N, increasing)
x = F.expand_dims(x, 1)
exponent = F.expand_dims(exponent, 0)
return F.tensor_pow(x, exponent)
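
The body broadcasts a column of inputs against a row of exponents. An equivalent sketch in plain NumPy, using the decreasing default:

import numpy as onp
x = onp.array([1, 2, 3])
N = 3
exponent = onp.arange(N)[::-1]            # [2, 1, 0] when increasing is False
print(x[:, None] ** exponent[None, :])    # rows are [x**2, x**1, x**0]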


def indices(dimensions, dtype=mstype.int32, sparse=False):
"""
Returns an array representing the indices of a grid.

Computes an array where the subarrays contain index values 0, 1, …
varying only along the corresponding axis.

Args:
dimensions (tuple or list of ints): The shape of the grid.
dtype (data type, optional): Data type of the result.
sparse (boolean, optional): Defaults to False. Return a sparse
representation of the grid instead of a dense representation.

Returns:
Tensor or tuple of Tensor. If `sparse` is False, returns one array
of grid indices, ``grid.shape = (len(dimensions),) + tuple(dimensions)``.
If sparse is True, returns a tuple of arrays, with
``grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1)`` with
``dimensions[i]`` in the `ith` place.

Raises:
TypeError: if input dimensions is not a tuple or list.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> grid = np.indices((2, 3))
>>> print(grid)
[Tensor(shape=[2, 3], dtype=Int32, value=
[[0, 0, 0],
[1, 1, 1]]), Tensor(shape=[2, 3], dtype=Int32, value=
[[0, 1, 2],
[0, 1, 2]])]
"""
if not isinstance(dimensions, (tuple, list)):
_raise_type_error('Shape of the grid must be tuple or list')
grids = ()
for d in dimensions:
grids += (arange(d, dtype=dtype),)
return meshgrid(*grids, sparse=sparse, indexing='ij')
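
With ``sparse=True`` the grids stay broadcastable rather than fully materialized. A short sketch:

grid = np.indices((2, 3), sparse=True)
# grid[0].shape == (2, 1) and grid[1].shape == (1, 3);
# broadcasting the two against each other recovers the dense (2, 3) grids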

mindspore/numpy/array_ops.py  +234 -59

@@ -24,62 +24,19 @@ from ..ops.primitive import constexpr
from ..nn import Cell


from .utils import _convert_list_tensor_to_tuple_tensor, _expand, _broadcast_to_shape, \
_check_input_tensor, _broadcast_to
_check_input_tensor, _broadcast_to, _to_tensor
from .utils_const import _check_axes_range, _check_start_normalize, \
_raise_type_error, _raise_value_error, _infer_out_shape, _empty, _promote, \
_check_same_type, _check_axis_valid, _add_unit_axes, _broadcast_tuples, \
_check_is_float, _check_axis_in_range, _check_axis_type, _canonicalize_axis, \
_list_comprehensions, _check_element_int, _is_shape_empty, _type_convert, \
_tuple_getitem, _expanded_shape
_tuple_getitem, _expanded_shape, _seq_prod, _get_device, _tuple_setitem


# According to official numpy reference, the dimension of a numpy array must be less
# than 32
MAX_NUMPY_DIMS = 32




@constexpr
def _prepare_shape_for_expand_dims(shape, axes):
"""
Creates the expanded new shape based on the shape and given axes

Args:
shape (tuple): the shape of the tensor
axes Union(int, tuple(int), list(int)): the axes with dimensions expanded.

Returns:
new_shape(tuple): the shape with dimensions expanded.
"""

new_shape = []
shape_idx = 0
new_shape_length = len(shape)

# Convert to set
if isinstance(axes, int):
new_shape_length += 1
if axes >= new_shape_length or axes < -new_shape_length:
raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {new_shape_length}")
axes = {axes}

elif isinstance(axes, (list, tuple)):
new_shape_length += len(axes)
for axis in axes:
if axis >= new_shape_length or axis < -new_shape_length:
raise ValueError(f"axis {axis} is out of bounds for tensor of dimension {new_shape_length}")
axes = set(axes)

else:
raise TypeError(f"only int, tuple and list are allowed for axes, but got {type(axes)}")

for new_shape_idx in range(new_shape_length):
if new_shape_idx in axes or new_shape_idx - new_shape_length in axes:
new_shape.append(1)
else:
new_shape.append(shape[shape_idx])
shape_idx += 1
return tuple(new_shape)


def expand_dims(a, axis):
"""
Expands the shape of a tensor.
@@ -109,10 +66,15 @@ def expand_dims(a, axis):
(1, 2, 2)
"""
_check_input_tensor(a)
shape = F.shape(a)
# yield expanded shape based on the axes
new_shape = _prepare_shape_for_expand_dims(shape, axis)
return F.reshape(a, new_shape)
if not isinstance(axis, (int, tuple, list)):
_raise_type_error("axis must be tuple, list or int, but got ", axis)
if isinstance(axis, int):
return F.expand_dims(a, axis)
ndim = a.ndim + len(axis)
axis = _canonicalize_axis(axis, ndim)
for ax in axis:
a = F.expand_dims(a, ax)
return a
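
A shape-only sketch of the new multi-axis path:

x = np.ones((2, 2))
print(np.expand_dims(x, 0).shape)        # (1, 2, 2): single int axis
print(np.expand_dims(x, (0, 2)).shape)   # (1, 2, 1, 2): axes inserted one by one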




def squeeze(a, axis=None):
@@ -1091,6 +1053,9 @@ def roll(a, shift, axis=None):
Returns:
Tensor, with the same shape as a.


Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Raises:
TypeError: If input arguments have types not specified above.
ValueError: If axis exceeds `a.ndim`, or `shift` and `axis` cannot broadcast.
@@ -1212,12 +1177,6 @@ def moveaxis(a, source, destination):
return F.transpose(a, perm)




@constexpr
def _seq_prod(seq1, seq2):
"""Returns the element-wise product of seq1 and seq2."""
return tuple(map(lambda x, y: x*y, seq1, seq2))


def tile(a, reps):
"""
Constructs an array by repeating `a` the number of times given by `reps`.
@@ -1355,6 +1314,60 @@ def broadcast_arrays(*args):
return res




def array_split(x, indices_or_sections, axis=0):
"""
Splits a tensor into multiple sub-tensors.

Note:
Currently, array_split only supports :class:`mindspore.float32` on ``CPU``.

The only difference between ``np.split`` and ``np.array_split`` is that
``np.array_split`` allows indices_or_sections to be an integer that does not
equally divide the axis. For a tensor of length l that should be split into
n sections, it returns :math:`l % n` sub-arrays of size :math:`l//n + 1` and
the rest of size :math:`l//n`.

Args:
x (Tensor): A Tensor to be divided.
indices_or_sections (Union[int, tuple(int), list(int)]):
If integer, :math:`N`, the tensor will be divided into
:math:`N` tensors along axis.
If tuple(int) or list(int) of sorted integers,
the entries indicate where along axis the array is split.
For example, :math:`[2, 3]` would, for :math:`axis=0`, result in
three sub-tensors :math:`x[:2]`, :math:`x[2:3]` and :math:`x[3:]`.
If an index exceeds the dimension of the array along axis,
an empty sub-array is returned correspondingly.
axis (int): The axis along which to split. Default: 0.

Returns:
A list of sub-tensors.

Raises:
TypeError: If argument `indices_or_sections` is not integer,
tuple(int) or list(int) or argument `axis` is not integer.
ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)`.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> input_x = np.arange(9).astype("float32")
>>> output = np.array_split(input_x, 4)
>>> print(output)
(Tensor(shape=[3], dtype=Float32,
value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]),
Tensor(shape=[2], dtype=Float32,
value= [ 3.00000000e+00, 4.00000000e+00]),
Tensor(shape=[2], dtype=Float32,
value= [ 5.00000000e+00, 6.00000000e+00]),
Tensor(shape=[2], dtype=Float32,
value= [ 7.00000000e+00, 8.00000000e+00]))
"""
return _split(x, indices_or_sections, opname="array_split", axis=axis)
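
A quick check of the size rule from the note above, applied to the docstring example (l = 9 elements into n = 4 sections):

l, n = 9, 4
sizes = [l // n + 1] * (l % n) + [l // n] * (n - l % n)
print(sizes)   # [3, 2, 2, 2], matching the sub-tensor shapes printed above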


def split(x, indices_or_sections, axis=0):
"""
Splits a tensor into multiple sub-tensors along the given axis.
@@ -1380,9 +1393,12 @@ def split(x, indices_or_sections, axis=0):
tuple(int) or list(int) or argument `axis` is not integer.
ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)`.


Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> input_x = np.arange(9).astype('float32')
>>> input_x = np.arange(9).astype("float32")
>>> output = np.split(input_x, 3)
>>> print(output)
(Tensor(shape=[3], dtype=Float32,
@@ -1392,13 +1408,32 @@ def split(x, indices_or_sections, axis=0):
Tensor(shape=[3], dtype=Float32,
value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
""" """
return _split(x, indices_or_sections, opname="split", axis=axis)


def _split(x, indices_or_sections, opname, axis=0):
"""Splits a tensor based on ``np.split`` or ``np.array_split``."""
_check_input_tensor(x)
_ = _check_axis_type(axis, True, False, False)
axis = _canonicalize_axis(axis, x.ndim)
res = None
arr_shape = x.shape
length_along_dim = arr_shape[axis]
if isinstance(indices_or_sections, int):
_split = P.Split(axis, indices_or_sections)
res = _split(x)
if opname == "split" or length_along_dim % indices_or_sections == 0:
res = P.Split(axis, indices_or_sections)(x)
else:
num_long_tensor = length_along_dim % indices_or_sections
num_short_tensor = indices_or_sections - num_long_tensor
length1 = num_long_tensor * (length_along_dim // indices_or_sections + 1)
length2 = length_along_dim - length1
start1 = _list_comprehensions(F.rank(x), 0, True)
size1 = _tuple_setitem(arr_shape, axis, length1)
start2 = _tuple_setitem(start1, axis, length1)
size2 = _tuple_setitem(arr_shape, axis, length2)
res = P.Split(axis, num_long_tensor)(F.tensor_slice(x, start1, size1)) + \
P.Split(axis, num_short_tensor)(F.tensor_slice(x, start2, size2))

elif isinstance(indices_or_sections, (list, tuple)) and _check_element_int(indices_or_sections):
res = _split_sub_tensors(x, indices_or_sections, axis)
else:
@@ -1921,7 +1956,6 @@ def repeat(a, repeats, axis=None):
if repeats == 0:
return _empty(F.dtype(a), (0,))
return C.repeat_elements(a, repeats, axis)

shape = F.shape(a)
size = shape[axis]
if len(repeats) != size:
@@ -1932,3 +1966,144 @@ def repeat(a, repeats, axis=None):
if rep != 0:
repeated_subs.append(C.repeat_elements(sub, rep, axis))
return concatenate(repeated_subs, axis)


def rot90(a, k=1, axes=(0, 1)):
"""
Rotates a tensor by 90 degrees in the plane specified by axes.
Rotation direction is from the first towards the second axis.

Args:
a (Tensor): Input tensor of two or more dimensions.
k (int): Number of times the tensor is rotated by 90 degrees. Default: 1.
axes (Union[tuple(int), list(int)]): The tensor is rotated in the plane
defined by the axes. Default: `(0, 1)`.
Axes must be different and with the shape of `(2,)`.

Returns:
Tensor.

Raises:
TypeError: if input `a` is not a Tensor or
the argument `k` is not integer or
the argument `axes` is not tuple of ints or list of ints.
ValueError: if any axis is out of range or
the length of `axes` is not `2`.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> a = np.arange(24).reshape((2, 3, 4))
>>> output = np.rot90(a)
>>> print(output)
[[[ 8 9 10 11]
[20 21 22 23]]
[[ 4 5 6 7]
[16 17 18 19]]
[[ 0 1 2 3]
[12 13 14 15]]]
>>> output = np.rot90(a, 3, (1, 2))
>>> print(output)
[[[ 8 4 0]
[ 9 5 1]
[10 6 2]
[11 7 3]]
[[20 16 12]
[21 17 13]
[22 18 14]
[23 19 15]]]
"""
_check_input_tensor(a)

if not isinstance(k, int):
_raise_type_error("integer argument expected, but got ", k)
k = k % 4  # Python's modulo already maps negative k into [0, 4); the old expression returned 4 for k in {-4, -8, ...}

if not isinstance(axes, (tuple, list)):
_raise_type_error("tuple(ints) or list(ints) expected, but got ", axes)
if len(axes) != 2:
_raise_value_error("len(axes) must be 2.")
axis1, axis2 = axes[0], axes[1]
axis1 = _canonicalize_axis(axis1, a.ndim)
axis2 = _canonicalize_axis(axis2, a.ndim)
if axis1 == axis2:
_raise_value_error('Axes must be different.')

if k == 0:
return a
if k == 2:
return flip(flip(a, axis1), axis2)
perm = _list_comprehensions(a.ndim)
perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
if k == 1:
return flip(transpose(a, perm), axis1)
return flip(transpose(a, perm), axis2)
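
A parity sketch against plain NumPy for the 2-D case: one 90-degree rotation is a transpose of the two axes followed by a flip along the first, which is exactly the ``k == 1`` branch above.

import numpy as onp
x = onp.arange(4).reshape(2, 2)
assert (onp.rot90(x) == onp.flip(onp.transpose(x), 0)).all()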


def select(condlist, choicelist, default=0):
"""
Returns an array drawn from elements in `choicelist`, depending on conditions.

Args:
condlist (array_like): The list of conditions which determine from which array
in `choicelist` the output elements are taken. When multiple conditions are
satisfied, the first one encountered in `condlist` is used.
choicelist (array_like): The list of arrays from which the output elements are
taken. It has to be of the same length as `condlist`.
default (scalar, optional): The element inserted in output when all conditions
evaluate to `False`.

Returns:
Tensor, the output at position `m` is the `m-th` element of the array in
`choicelist` where the `m-th` element of the corresponding array in `condlist`
is `True`.


Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Raises:
ValueError: if ``len(condlist) != len(choicelist)``.

Examples:
>>> condlist = [[True, True, True, False, False],
[False, False, True, False, True]]
>>> choicelist = [[0, 1, 2, 3, 4], [0, 1, 4, 9, 16]]
>>> output = np.select(condlist, choicelist)
>>> print(output)
[ 0 1 2 0 16]
"""
condlist, choicelist = _to_tensor(condlist, choicelist)
shape_cond = F.shape(condlist)
shape_choice = F.shape(choicelist)
if F.rank(condlist) == 0 or F.rank(choicelist) == 0:
_raise_value_error('input cannot be scalars')
case_num = shape_cond[0]
if shape_choice[0] != case_num:
_raise_value_error('list of cases must be same length as list of conditions')

# performs broadcast over the cases in condlist and choicelist
case_size = _infer_out_shape(shape_cond[1:], shape_choice[1:])
shape_broadcasted = (case_num,) + case_size
ndim = len(shape_broadcasted)
shape_cond_expanded = ((case_num,) + _list_comprehensions(ndim - F.rank(condlist), 1, True) +
shape_cond[1:])
condlist = _broadcast_to_shape(F.reshape(condlist, shape_cond_expanded), shape_broadcasted)
shape_choice_expanded = ((case_num,) + _list_comprehensions(ndim - F.rank(choicelist), 1, True) +
shape_choice[1:])
choicelist = _broadcast_to_shape(F.reshape(choicelist, shape_choice_expanded), shape_broadcasted)

slice_start = _list_comprehensions(ndim - 1, 0, True)
slice_size = (1,) + case_size
dtype = F.dtype(choicelist)
if _get_device() == 'CPU' and not _check_is_float(dtype):
# F.tensor_slice only supports float on CPU
choicelist = F.cast(choicelist, mstype.float32)
default_slice = F.fill(F.dtype(choicelist), slice_size, default)
for i in range(case_num - 1, -1, -1):
cond_slice = F.tensor_slice(condlist.astype(mstype.float32), (i,) + slice_start, slice_size)
choice_slice = F.tensor_slice(choicelist, (i,) + slice_start, slice_size)
default_slice = F.select(cond_slice.astype(mstype.bool_), choice_slice, default_slice)
return F.reshape(default_slice, (case_size)).astype(dtype)
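
The reverse loop implements the usual ``select`` fold: conditions are applied last-to-first, so the earliest matching condition wins. An equivalent dense sketch in plain NumPy:

import numpy as onp
condlist = [[True, True, True, False, False], [False, False, True, False, True]]
choicelist = [[0, 1, 2, 3, 4], [0, 1, 4, 9, 16]]
out = onp.full(5, 0)                                 # start from the default
for cond, choice in zip(condlist[::-1], choicelist[::-1]):
    out = onp.where(cond, choice, out)
print(out)                                           # [ 0  1  2  0 16]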

mindspore/numpy/dtypes.py  +13 -0

@@ -169,3 +169,16 @@ promotion_rule = {
(bool_, float32): float32,
(bool_, float64): float64,
}

rule_for_trigonometric = {float16: float16,
float32: float32,
float64: float64,
int8: float16,
int16: float32,
int32: float32,
int64: float32,
uint8: float16,
uint16: float32,
uint32: float32,
uint64: float32,
bool_: float16}
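
This table pins the compute dtype for the new trigonometric functions: floats pass through unchanged, while integer and bool inputs are promoted to a float type wide enough to hold them. A sketch, assuming `sin` consults this rule:

import mindspore.numpy as np
x = np.asarray([0, 1, 2])     # int32 input
print(np.sin(x).dtype)        # float32, per rule_for_trigonometric[int32]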

mindspore/numpy/logic_ops.py  +342 -148

@@ -15,33 +15,29 @@
"""logical operations, the function docs are adapted from Numpy API."""




from .math_ops import _apply_tensor_op
from ..ops import functional as F
from ..ops.primitive import constexpr
from ..common import dtype as mstype
from ..common import Tensor
from .._c_expression import typing


from .array_creations import zeros, ones
from .utils import _check_input_tensor
from .math_ops import _apply_tensor_op, absolute
from .array_creations import zeros, ones, empty
from .utils import _check_input_tensor, _to_tensor, _isnan
from .utils_const import _raise_type_error, _is_shape_empty, _infer_out_shape




def not_equal(x1, x2, out=None, where=True, dtype=None):
def not_equal(x1, x2, dtype=None):
"""
Returns (x1 != x2) element-wise.


Note:
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
and `extobj` are not supported.

Args:
x1 (Tensor): First input tensor to be compared.
x2 (Tensor): Second input tensor to be compared.
out (Tensor or None, optional), default is None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.


@@ -65,33 +61,21 @@ def not_equal(x1, x2, out=None, where=True, dtype=None):
[False True]]
"""
_check_input_tensor(x1, x2)
return _apply_tensor_op(F.not_equal, x1, x2, out=out, where=where, dtype=dtype)
return _apply_tensor_op(F.not_equal, x1, x2, dtype=dtype)
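
A minimal call under the trimmed signature; `out` and `where` are gone, and only `dtype` remains as a keyword (a sketch):

import mindspore.numpy as np
a = np.asarray([1, 2])
b = np.asarray([1, 3])
print(np.not_equal(a, b))                   # [False  True]
print(np.not_equal(a, b, dtype=np.int32))   # result cast to int32: [0 1]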




def less_equal(x1, x2, out=None, where=True, dtype=None):
def less_equal(x1, x2, dtype=None):
"""
Returns the truth value of ``(x1 <= x2)`` element-wise.


Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
and `extobj` are not supported.


Args:
x1 (Tensor): Input array.
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
broadcastable to a common shape (which becomes the shape of the output).
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.


@@ -113,33 +97,21 @@ def less_equal(x1, x2, out=None, where=True, dtype=None):
[False True True]
"""
_check_input_tensor(x1, x2)
return _apply_tensor_op(F.tensor_le, x1, x2, out=out, where=where, dtype=dtype)
return _apply_tensor_op(F.tensor_le, x1, x2, dtype=dtype)




def less(x1, x2, out=None, where=True, dtype=None):
def less(x1, x2, dtype=None):
"""
Returns the truth value of ``(x1 < x2)`` element-wise.


Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
and `extobj` are not supported.


Args:
x1 (Tensor): Input array.
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
broadcastable to a common shape (which becomes the shape of the output).
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.


@@ -160,33 +132,21 @@ def less(x1, x2, out=None, where=True, dtype=None):
>>> print(output)
[ True False]
"""
return _apply_tensor_op(F.tensor_lt, x1, x2, out=out, where=where, dtype=dtype)
return _apply_tensor_op(F.tensor_lt, x1, x2, dtype=dtype)




def greater_equal(x1, x2, out=None, where=True, dtype=None):
def greater_equal(x1, x2, dtype=None):
"""
Returns the truth value of ``(x1 >= x2)`` element-wise.


Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
and `extobj` are not supported.


Args:
x1 (Tensor): Input array.
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
broadcastable to a common shape (which becomes the shape of the output).
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.


@@ -207,33 +167,21 @@ def greater_equal(x1, x2, out=None, where=True, dtype=None):
>>> print(output)
[ True True False]
"""
return _apply_tensor_op(F.tensor_ge, x1, x2, out=out, where=where, dtype=dtype)
return _apply_tensor_op(F.tensor_ge, x1, x2, dtype=dtype)




def greater(x1, x2, out=None, where=True, dtype=None):
def greater(x1, x2, dtype=None):
"""
Returns the truth value of ``(x1 > x2)`` element-wise.


Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
and `extobj` are not supported.


Args:
x1 (Tensor): Input array.
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
broadcastable to a common shape (which becomes the shape of the output).
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.


@@ -254,33 +202,21 @@ def greater(x1, x2, out=None, where=True, dtype=None):
>>> print(output)
[ True False]
"""
return _apply_tensor_op(F.tensor_gt, x1, x2, out=out, where=where, dtype=dtype)
return _apply_tensor_op(F.tensor_gt, x1, x2, dtype=dtype)




def equal(x1, x2, out=None, where=True, dtype=None):
def equal(x1, x2, dtype=None):
"""
Returns the truth value of ``(x1 == x2)`` element-wise.


Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
and `extobj` are not supported.


Args:
x1 (Tensor): Input array.
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
broadcastable to a common shape (which becomes the shape of the output).
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.


@@ -301,34 +237,22 @@ def equal(x1, x2, out=None, where=True, dtype=None):
>>> print(output)
[ True True False]
"""
return _apply_tensor_op(F.equal, x1, x2, out=out, where=where, dtype=dtype)
return _apply_tensor_op(F.equal, x1, x2, dtype=dtype)




def isfinite(x, out=None, where=True, dtype=None):
def isfinite(x, dtype=None):
"""
Tests element-wise for finiteness (not infinity and not Not a Number).


The result is returned as a boolean array.


Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
and `extobj` are not supported.
On GPU, the supported dtypes are np.float16, and np.float32.


Args:
x (Tensor): Input values.
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.


@@ -351,37 +275,20 @@ def isfinite(x, out=None, where=True, dtype=None):
>>> print(output)
False
"""
return _apply_tensor_op(F.isfinite, x, out=out, where=where, dtype=dtype)


def _isnan(x):
"""Computes isnan without applying keyword arguments."""
return F.not_equal(x, x)
return _apply_tensor_op(F.isfinite, x, dtype=dtype)




def isnan(x, out=None, where=True, dtype=None):
def isnan(x, dtype=None):
"""
Tests element-wise for NaN and return result as a boolean array.


Note: Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
and `extobj` are not supported.
Only np.float32 is currently supported. Only np.float32 is currently supported.


Args: Args:
x (Tensor): Input values. x (Tensor): Input values.
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor. output Tensor.


@@ -404,7 +311,7 @@ def isnan(x, out=None, where=True, dtype=None):
>>> print(output) >>> print(output)
False False
""" """
return _apply_tensor_op(_isnan, x, out=out, where=where, dtype=dtype)
return _apply_tensor_op(_isnan, x, dtype=dtype)




def _isinf(x):
@@ -419,31 +326,19 @@ def _isinf(x):
return F.cast(res, mstype.bool_)




def isinf(x, out=None, where=True, dtype=None):
def isinf(x, dtype=None):
""" """
Tests element-wise for positive or negative infinity. Tests element-wise for positive or negative infinity.


Returns a boolean array of the same shape as `x`, True where ``x == +/-inf``, otherwise False. Returns a boolean array of the same shape as `x`, True where ``x == +/-inf``, otherwise False.


Note: Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
and `extobj` are not supported.
Only np.float32 is currently supported. Only np.float32 is currently supported.


Args: Args:
x (Tensor): Input values. x (Tensor): Input values.
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor. output Tensor.


@@ -466,7 +361,7 @@ def isinf(x, out=None, where=True, dtype=None):
>>> print(output) >>> print(output)
[ True True False False] [ True True False False]
""" """
return _apply_tensor_op(_isinf, x, out=out, where=where, dtype=dtype)
return _apply_tensor_op(_isinf, x, dtype=dtype)




def _is_sign_inf(x, fn):
@@ -562,7 +457,7 @@ def isscalar(element):
element (any): Input argument, can be of any type and shape.

Returns:
Boolean, True if `element` is a scalar type, False if it is not.

Raises:
TypeError: if the type of `element` is not supported by mindspore parser.
@@ -587,3 +482,302 @@ def isscalar(element):
"""
obj_type = F.typeof(element)
return not isinstance(obj_type, Tensor) and _isscalar(obj_type)


def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
"""
Returns a boolean tensor where two tensors are element-wise equal within a tolerance.

The tolerance values are positive, typically very small numbers. The relative
difference (:math:`rtol * abs(b)`) and the absolute difference `atol` are added together
to compare against the absolute difference between `a` and `b`.

Note:
For finite values, isclose uses the following equation to test whether two
floating point values are equivalent.
:math:`absolute(a - b) <= (atol + rtol * absolute(b))`

Args:
a (Union[Tensor, list, tuple]): Input first tensor to compare.
b (Union[Tensor, list, tuple]): Input second tensor to compare.
rtol (Number): The relative tolerance parameter (see Note).
atol (Number): The absolute tolerance parameter (see Note).
equal_nan (bool): Whether to compare ``NaN`` as equal. If True, ``NaN`` in
`a` will be considered equal to ``NaN`` in `b` in the output tensor.

Returns:
A ``bool`` tensor of where `a` and `b` are equal within the given tolerance.

Raises:
TypeError: If inputs have types not specified above.

Supported Platforms:
``GPU`` ``CPU``

Examples:
>>> a = np.array([0,1,2,float('inf'),float('inf'),float('nan')])
>>> b = np.array([0,1,-2,float('-inf'),float('inf'),float('nan')])
>>> print(np.isclose(a, b))
[ True True False False True False]
>>> print(np.isclose(a, b, equal_nan=True))
[ True True False False True True]
"""
a, b = _to_tensor(a, b)
if not isinstance(rtol, (int, float, bool)) or not isinstance(atol, (int, float, bool)):
_raise_type_error("rtol and atol are expected to be numbers.")
if not isinstance(equal_nan, bool):
_raise_type_error("equal_nan is expected to be bool.")

if _is_shape_empty(a.shape) or _is_shape_empty(b.shape):
return empty(_infer_out_shape(a.shape, b.shape), dtype=mstype.bool_)
rtol = _to_tensor(rtol).astype("float32")
atol = _to_tensor(atol).astype("float32")
res = absolute(a - b) <= (atol + rtol * absolute(b))
# infs are treated as equal
a_posinf = isposinf(a)
b_posinf = isposinf(b)
a_neginf = isneginf(a)
b_neginf = isneginf(b)
same_inf = F.logical_or(F.logical_and(a_posinf, b_posinf), F.logical_and(a_neginf, b_neginf))
diff_inf = F.logical_or(F.logical_and(a_posinf, b_neginf), F.logical_and(a_neginf, b_posinf))
res = F.logical_and(F.logical_or(res, same_inf), F.logical_not(diff_inf))
both_nan = F.logical_and(_isnan(a), _isnan(b))
if equal_nan:
res = F.logical_or(both_nan, res)
else:
res = F.logical_and(F.logical_not(both_nan), res)
return res
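For reference, the tolerance test above reduces to a single inequality on scalars; a minimal plain-Python sketch (`scalar_isclose` is a hypothetical helper, not part of this change):

def scalar_isclose(a, b, rtol=1e-05, atol=1e-08):
    # same inequality isclose applies element-wise: |a - b| <= atol + rtol * |b|
    return abs(a - b) <= (atol + rtol * abs(b))

assert scalar_isclose(1.0, 1.0 + 5e-9)   # difference absorbed by atol
assert not scalar_isclose(1.0, 1.1)      # relative difference too large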


def in1d(ar1, ar2, invert=False):
"""
Tests whether each element of a 1-D array is also present in a second array.

Returns a boolean array the same length as `ar1` that is True where an element
of `ar1` is in `ar2` and False otherwise.

Note:
Numpy argument `assume_unique` is not supported since the implementation does
not rely on the uniqueness of the input arrays.

Args:
ar1 (array_like): Input array with shape `(M,)`.
ar2 (array_like): The values against which to test each value of `ar1`.
invert (boolean, optional): If True, the values in the returned array are
inverted (that is, False where an element of `ar1` is in `ar2` and True
otherwise). Default is False.

Returns:
Tensor, with shape `(M,)`. The values ``ar1[in1d]`` are in `ar2`.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> test = np.array([0, 1, 2, 5, 0])
>>> states = [0, 2]
>>> mask = np.in1d(test, states)
>>> print(mask)
[ True False True False True]
>>> mask = np.in1d(test, states, invert=True)
>>> print(mask)
[False True False True False]
"""
ar1, ar2 = _to_tensor(ar1, ar2)
ar1 = F.expand_dims(ar1.ravel(), -1)
ar2 = ar2.ravel()
included = F.equal(ar1, ar2)
# F.reduce_sum only supports float
res = F.reduce_sum(included.astype(mstype.float32), -1).astype(mstype.bool_)
if invert:
res = F.equal(res, _to_tensor(False))
return res
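The membership test above relies on broadcasting: `ar1` becomes a column, is compared against the flattened `ar2`, and matches are reduced along the last axis. The same idea in plain NumPy (illustrative only, not the graph-mode code):

import numpy as onp

ar1 = onp.array([0, 1, 2, 5, 0])
ar2 = onp.array([0, 2])
# (5, 1) column == (2,) row -> (5, 2) boolean matrix, then any() per row
mask = (ar1[:, None] == ar2.ravel()).any(axis=-1)
print(mask)  # [ True False  True False  True]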


def isin(element, test_elements, invert=False):
"""
Calculates element in `test_elements`, broadcasting over `element` only. Returns a
boolean array of the same shape as `element` that is True where an element of
`element` is in `test_elements` and False otherwise.

Note:
Numpy argument `assume_unique` is not supported since the implementation does
not rely on the uniqueness of the input arrays.

Args:
element (array_like): Input array.
test_elements (array_like): The values against which to test each value of
`element`.
invert (boolean, optional): If True, the values in the returned array are
inverted, as if calculating `element` not in `test_elements`. Default is False.

Returns:
Tensor, has the same shape as `element`. The values ``element[isin]`` are in
`test_elements`.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> element = 2*np.arange(4).reshape((2, 2))
>>> test_elements = [1, 2, 4, 8]
>>> mask = np.isin(element, test_elements)
>>> print(mask)
[[False True]
[ True False]]
>>> mask = np.isin(element, test_elements, invert=True)
>>> print(mask)
[[ True False]
[False True]]
"""
res = in1d(element, test_elements, invert=invert)
return F.reshape(res, F.shape(element))
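Since `isin` is `in1d` followed by a reshape back to the input's shape, the equivalence can be checked with NumPy's own functions (a standalone sanity check):

import numpy as onp

element = 2 * onp.arange(4).reshape((2, 2))
test_elements = [1, 2, 4, 8]
# in1d flattens the input, so restoring the original shape gives isin
assert (onp.isin(element, test_elements) ==
        onp.in1d(element, test_elements).reshape(element.shape)).all()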


def logical_not(a, dtype=None):
"""
Computes the truth value of NOT `a` element-wise.

Note:
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are
not supported.

Args:
a (Tensor): The input tensor whose dtype is bool.
dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the
output Tensor.

Returns:
Tensor or scalar.
Boolean result with the same shape as `a` of the NOT operation on elements of `a`.
This is a scalar if `a` is a scalar.

Raises:
TypeError: if the input is not a tensor or its dtype is not bool.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> a = np.array([True, False])
>>> output = np.logical_not(a)
>>> print(output)
[False True]
"""
return _apply_tensor_op(F.logical_not, a, dtype=dtype)


def logical_or(x1, x2, dtype=None):
"""
Computes the truth value of `x1` OR `x2` element-wise.

Note:
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
and `extobj` are not supported.

Args:
x1 (Tensor): Input tensor.
x2 (Tensor): Input tensor. If ``x1.shape != x2.shape``, they must be
broadcastable to a common shape (which becomes the shape of the output).
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.

Returns:
Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type
bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are
scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> x1 = np.array([True, False])
>>> x2 = np.array([False, True])
>>> output = np.logical_or(x1, x2)
>>> print(output)
[ True True]
"""
return _apply_tensor_op(F.logical_or, x1, x2, dtype=dtype)


def logical_and(x1, x2, dtype=None):
"""
Computes the truth value of `x1` AND `x2` element-wise.

Note:
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
and `extobj` are not supported.

Args:
x1 (Tensor): Input tensor.
x2 (Tensor): Input tensor. If ``x1.shape != x2.shape``, they must be
broadcastable to a common shape (which becomes the shape of the output).
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.

Returns:
Tensor or scalar.
Boolean result of the logical AND operation applied to the elements of `x1` and `x2`;
the shape is determined by broadcasting. This is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> x1 = np.array([True, False])
>>> x2 = np.array([False, False])
>>> output = np.logical_and(x1, x2)
>>> print(output)
[False False]
"""
return _apply_tensor_op(F.logical_and, x1, x2, dtype=dtype)


def logical_xor(x1, x2, dtype=None):
"""
Computes the truth value of `x1` XOR `x2`, element-wise.

Note:
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
and `extobj` are not supported.

Args:
x1 (Tensor): Input tensor.
x2 (Tensor): Input tensor. If ``x1.shape != x2.shape``, they must be
broadcastable to a common shape (which becomes the shape of the output).
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.

Returns:
Tensor or scalar.
Boolean result of the logical XOR operation applied to the elements of `x1` and `x2`;
the shape is determined by broadcasting. This is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> x1 = np.array([True, False])
>>> x2 = np.array([False, False])
>>> output = np.logical_xor(x1, x2)
>>> print(output)
[ True False]
"""
_check_input_tensor(x1)
_check_input_tensor(x2)
y1 = F.logical_or(x1, x2)
y2 = F.logical_or(F.logical_not(x1), F.logical_not(x2))
return _apply_tensor_op(F.logical_and, y1, y2, dtype=dtype)
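The composition used here, ``(x1 OR x2) AND (NOT x1 OR NOT x2)``, is True exactly when the inputs differ; a plain-Python truth-table check:

for x1 in (False, True):
    for x2 in (False, True):
        composed = (x1 or x2) and ((not x1) or (not x2))
        assert composed == (x1 != x2)  # matches logical xor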

+ 1977
- 574
mindspore/numpy/math_ops.py
File diff suppressed because it is too large
View File


+ 22
- 7
mindspore/numpy/utils.py View File

@@ -13,14 +13,11 @@
# limitations under the License.
# ============================================================================
"""internal utility functions"""

import numpy as onp

from ..common import Tensor
from ..ops import functional as F
from ..common import dtype as mstype


from .utils_const import _tile_size, _add_unit_axes, _raise_type_error
from .utils_const import _tile_size, _add_unit_axes, _raise_type_error, _type_convert




def _deep_list(array_like):
@@ -56,9 +53,8 @@ def _deep_tensor_to_nparray(array_like):

def _check_input_for_asarray(array_like):
"""check whether array_like argument is a valid type for np.asarray conversion"""
if not isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)):
_raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \
"or numpy.ndarray, but got ", array_like)
if not isinstance(array_like, (Tensor, list, tuple, int, float, bool)):
_raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`, but got ", array_like)




def _is_scalar(shape):
@@ -121,6 +117,20 @@ def _convert_64_to_32(tensor):
return tensor




def _to_tensor(*args):
"""Returns each input as Tensor"""
res = ()
for arg in args:
if isinstance(arg, (int, float, bool, list, tuple)):
arg = _convert_64_to_32(_type_convert(Tensor, arg))
elif not isinstance(arg, Tensor):
_raise_type_error("Expect input to be array like.")
res += (arg,)
if len(res) == 1:
return res[0]
return res


def _get_dtype_from_scalar(*input_numbers):
"""
Get the final dtype from series of input numbers, compared with F.typeof, we
@@ -139,3 +149,8 @@ def _get_dtype_from_scalar(*input_numbers):
if int_flag:
return mstype.int32
return mstype.float32


def _isnan(x):
"""Computes isnan."""
return F.not_equal(x, x)
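`_isnan` exploits the IEEE-754 rule that NaN is the only value that compares unequal to itself; a plain-Python check of the trick:

nan = float("nan")
assert nan != nan          # NaN never equals itself
assert not (1.0 != 1.0)    # ordinary values compare equal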

+ 57
- 39
mindspore/numpy/utils_const.py View File

@@ -14,7 +14,8 @@
# ============================================================================
"""internal graph-compatible utility functions"""
import math
from functools import partial
from itertools import zip_longest
from collections import deque

import mindspore.context as context
from ..ops import functional as F
@@ -24,7 +25,7 @@ from ..common import Tensor
from .._c_expression import Tensor as Tensor_
from .._c_expression import typing


from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map
from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map, rule_for_trigonometric




@constexpr
@@ -110,44 +111,19 @@ def _get_device():
return context.get_context('device_target')




@constexpr
def _reverse_index(idx, arr):
"""
Returns 1 if shape[idx:] is broadcastable to shape_out[idx:],
2 situations if the function returns 1:
- 1. Tensor's shape has 1 at the designated dimension.
- 2. Tensor's dimension is less than the designated idx. (The Tensor shape
has been reversed)
For both cases, 2 tensors are broadcastable.
otherwise returns the element at position of shape
"""
if len(arr) <= idx:
return 1
return arr[-1 - idx]


@constexpr
def _infer_out_shape(*shapes):
"""
Returns shape of output after broadcasting
Raises ValueError if shape1 and shape2 cannot be broadcast
Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
"""
shapes_unbroadcastable = False
ndim_max = max(map(len, shapes))
shape_out = [0]*ndim_max
i = 0
for i in range(ndim_max):
shape_out[-1 - i] = max(map(partial(_reverse_index, i), shapes))
for shape in shapes:
if _reverse_index(i, shape) != shape_out[-1 - i]:
if _reverse_index(i, shape) != 1:
shapes_unbroadcastable = True
break
if shapes_unbroadcastable:
break
if not shapes_unbroadcastable:
return tuple(shape_out)
raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
shape_out = deque()
reversed_shapes = map(reversed, shapes)
for items in zip_longest(*reversed_shapes, fillvalue=1):
max_size = 0 if 0 in items else max(items)
if any(item not in (1, max_size) for item in items):
raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
shape_out.appendleft(max_size)
return tuple(shape_out)
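The rewritten inference walks the shapes right-to-left, padding missing leading dimensions with 1 and requiring each size to be 1 or the common maximum. The same loop in plain Python (a standalone sketch; `broadcast_shape` is a hypothetical name):

from itertools import zip_longest

def broadcast_shape(*shapes):
    out = []
    for items in zip_longest(*map(reversed, shapes), fillvalue=1):
        max_size = 0 if 0 in items else max(items)
        if any(item not in (1, max_size) for item in items):
            raise ValueError(f"operands could not be broadcast together with shapes {shapes}")
        out.append(max_size)
    return tuple(reversed(out))  # built low-dimension first, so flip

assert broadcast_shape((8, 1, 6, 1), (7, 1, 5)) == (8, 7, 6, 5)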




@constexpr
@@ -228,6 +204,21 @@ def _raise_value_error(info, param=None):
raise ValueError(info + f"{param}")




@constexpr
def _raise_runtime_error(info, param=None):
"""
Raise RuntimeError in both graph/pynative mode

Args:
info(str): info string to display
param(python obj): any object that can be recognized by graph mode. If is
not None, then param's value information will be extracted and displayed.
Default is None.
"""
if param is None:
raise RuntimeError(info)
raise RuntimeError(info + f"{param}")

@constexpr
def _empty(dtype, shape):
"""Returns an uninitialized array with dtype and shape."""
@@ -242,6 +233,9 @@ def _promote(dtype1, dtype2):
return promotion_rule[dtype1, dtype2]
return promotion_rule[dtype2, dtype1]


@constexpr
def _promote_for_trigonometric(dtype):
return rule_for_trigonometric[dtype]


@constexpr
def _max(*args):
@@ -315,7 +309,7 @@ def _canonicalize_axis(axis, ndim):

axis = tuple([canonicalizer(axis) for axis in axis])
if all(axis.count(el) <= 1 for el in axis):
return axis if len(axis) > 1 else axis[0]
return tuple(sorted(axis)) if len(axis) > 1 else axis[0]
raise ValueError(f"duplicate axes in {axis}.")




@@ -426,13 +420,37 @@ def _tuple_getitem(tup, idx, startswith=True):




@constexpr
def _iota(dtype, num):
def _tuple_setitem(tup, idx, value):
"""
Returns a tuple with specified `idx` set to `value`.
"""
tup = list(tup)
tup[idx] = value
return tuple(tup)


@constexpr
def _iota(dtype, num, increasing=True):
"""Creates a 1-D tensor with value: [0,1,...num-1] and dtype.""" """Creates a 1-D tensor with value: [0,1,...num-1] and dtype."""
# TODO: Change to P.Linspace when the kernel is implemented on CPU. # TODO: Change to P.Linspace when the kernel is implemented on CPU.
return Tensor(list(range(int(num))), dtype)
if increasing:
return Tensor(list(range(int(num))), dtype)
return Tensor(list(range(int(num)-1, -1, -1)), dtype)
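The new `increasing` flag lets `_iota` emit either ramp; in plain Python the two branches correspond to:

num = 4
print(list(range(num)))              # [0, 1, 2, 3]  increasing=True
print(list(range(num - 1, -1, -1)))  # [3, 2, 1, 0]  increasing=False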




@constexpr
def _ceil(number):
"""Ceils the number in graph mode."""
return math.ceil(number)


@constexpr
def _seq_prod(seq1, seq2):
"""Returns the element-wise product of seq1 and seq2."""
return tuple(map(lambda x, y: x*y, seq1, seq2))


@constexpr
def _make_tensor(val, dtype):
""" Returns the tensor with value `val` and dtype `dtype`."""
return Tensor(val, dtype)
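`_seq_prod` multiplies two shape-like tuples element-wise (useful, for example, when combining tile counts with block sizes; its actual call sites are outside this hunk). A plain-Python check:

assert tuple(map(lambda x, y: x * y, (2, 3), (4, 5))) == (8, 15)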

+ 16
- 0
mindspore/ops/composite/multitype_ops/not_equal_impl.py View File

@@ -15,6 +15,7 @@


"""Implementation for internal polymorphism `not equal` operations.""" """Implementation for internal polymorphism `not equal` operations."""


from . import _constexpr_utils as const_utils
from ...composite import base from ...composite import base
from ... import functional as F from ... import functional as F


@@ -41,6 +42,21 @@ def _not_equal_scalar(x, y):
return not F.scalar_eq(x, y)




@not_equal.register("mstype", "mstype")
def _not_equal_mstype(x, y):
"""
Determine if two mindspore types are not equal.

Args:
x (mstype): first input mindspore type.
y (mstype): second input mindspore type.

Returns:
bool, True if x != y, otherwise False.
"""
return not const_utils.mstype_eq(x, y)


@not_equal.register("String", "String") @not_equal.register("String", "String")
def _not_equal_string(x, y): def _not_equal_string(x, y):
""" """


+ 17
- 0
mindspore/ops/functional.py View File

@@ -77,6 +77,7 @@ floormod = tensor_mod
tensor_exp = P.Exp()
exp = tensor_exp
tensor_expm1 = P.Expm1()
tensor_slice = P.Slice()
strided_slice = P.StridedSlice()
same_type_shape = P.SameTypeShape()
check_bprop = P.CheckBprop()
@@ -94,6 +95,22 @@ tensor_slice = P.Slice()
maximum = P.Maximum()
minimum = P.Minimum()
floor = P.Floor()
logical_not = P.LogicalNot()
logical_or = P.LogicalOr()
logical_and = P.LogicalAnd()
sin = P.Sin()
cos = P.Cos()
tan = P.Tan()
asin = P.Asin()
acos = P.ACos()
atan = P.Atan()
sinh = P.Sinh()
cosh = P.Cosh()
tanh = P.Tanh()
asinh = P.Asinh()
acosh = P.Acosh()
atanh = P.Atanh()
atan2 = P.Atan2()


scalar_to_array = P.ScalarToArray()
scalar_to_tensor = P.ScalarToTensor()


+ 2
- 2
mindspore/ops/operations/math_ops.py View File

@@ -2560,7 +2560,7 @@ class Acosh(PrimitiveWithInfer):
TypeError: If `input_x` is not a Tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
``Ascend`` ``GPU``

Examples:
>>> acosh = ops.Acosh()
@@ -2637,7 +2637,7 @@ class Asinh(PrimitiveWithInfer):
TypeError: If `input_x` is not a Tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
``Ascend`` ``GPU``

Examples:
>>> asinh = ops.Asinh()


+ 95
- 18
tests/st/numpy_native/test_array_creations.py View File

@@ -20,7 +20,7 @@ import numpy as onp
import mindspore.numpy as mnp

from .utils import rand_int, rand_bool, match_array, match_res, match_meta, \
match_all_arrays
match_all_arrays, run_multi_test, to_tensor




class Cases():
@@ -40,8 +40,8 @@ class Cases():

self.array_sets = [1, 1.1, True, [1, 0, True], [1, 1.0, 2], (1,),
[(1, 2, 3), (4, 5, 6)], onp.random.random( # pylint: disable=no-member
(100, 100)).astype(onp.float32),
onp.random.random((100, 100)).astype(onp.bool)]
(100, 100)).astype(onp.float32).tolist(),
onp.random.random((100, 100)).astype(onp.bool).tolist()]


self.arrs = [
rand_int(2),
@@ -138,8 +138,8 @@ def test_asarray():
expected = mnp.asarray(array, test_case.mnp_dtypes[i]).asnumpy()
match_array(actual, expected, error=7)


# Additional tests for nested tensor/numpy_array mixture
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
# Additional tests for nested tensor mixture
mnp_input = [(mnp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]

actual = onp.asarray(onp_input)
@@ -168,11 +168,11 @@ def test_array():
assert arr4 is arr5


# Additional tests for nested tensor/numpy_array mixture
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
mnp_input = [(mnp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]


actual = onp.asarray(onp_input)
expected = mnp.asarray(mnp_input).asnumpy()
actual = onp.array(onp_input)
expected = mnp.array(mnp_input).asnumpy()
match_array(actual, expected, error=7)

@@ -202,11 +202,11 @@ def test_asfarray():
match_array(actual, expected, error=7)

# Additional tests for nested tensor/numpy_array mixture
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
mnp_input = [(mnp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]


actual = onp.asarray(onp_input)
expected = mnp.asarray(mnp_input).asnumpy()
actual = onp.asfarray(onp_input)
expected = mnp.asfarray(mnp_input).asnumpy()
match_array(actual, expected, error=7)

@@ -373,14 +373,14 @@ def test_linspace():
stop = onp.random.random([1, 5, 1]).astype("float32")
actual = onp.linspace(start, stop, num=20, retstep=True,
endpoint=False, dtype=onp.float32)
expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20,
expected = mnp.linspace(to_tensor(start), to_tensor(stop), num=20,
retstep=True, endpoint=False)
match_array(actual[0], expected[0].asnumpy(), error=6)
match_array(actual[1], expected[1].asnumpy(), error=6)


actual = onp.linspace(start, stop, num=20, retstep=True,
endpoint=False, dtype=onp.int16)
expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20,
expected = mnp.linspace(to_tensor(start), to_tensor(stop), num=20,
retstep=True, endpoint=False, dtype=mnp.int16)
match_array(actual[0], expected[0].asnumpy(), error=6)
match_array(actual[1], expected[1].asnumpy(), error=6)
@@ -388,7 +388,7 @@ def test_linspace():
for axis in range(2):
actual = onp.linspace(start, stop, num=20, retstep=False,
endpoint=False, dtype=onp.float32, axis=axis)
expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20,
expected = mnp.linspace(to_tensor(start), to_tensor(stop), num=20,
retstep=False, endpoint=False, dtype=mnp.float32, axis=axis)
match_array(actual, expected.asnumpy(), error=6)


@@ -510,18 +510,18 @@ def test_full_like():
for mnp_proto, onp_proto in zip(test_case.mnp_prototypes, test_case.onp_prototypes):
shape = onp.zeros_like(onp_proto).shape
fill_value = rand_int()
actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy()
actual = mnp.full_like(mnp_proto, to_tensor(fill_value)).asnumpy()
expected = onp.full_like(onp_proto, fill_value)
match_array(actual, expected)

for i in range(len(shape) - 1, 0, -1):
fill_value = rand_int(*shape[i:])
actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy()
actual = mnp.full_like(mnp_proto, to_tensor(fill_value)).asnumpy()
expected = onp.full_like(onp_proto, fill_value)
match_array(actual, expected)

fill_value = rand_int(1, *shape[i + 1:])
actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy()
actual = mnp.full_like(mnp_proto, to_tensor(fill_value)).asnumpy()
expected = onp.full_like(onp_proto, fill_value)
match_array(actual, expected)

@@ -549,6 +549,21 @@ def test_tri_triu_tril():
match_array(mnp.tri(64, 64, -10).asnumpy(), onp.tri(64, 64, -10))




@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_nancumsum():
x = rand_int(2, 3, 4, 5)
x[0][2][1][3] = onp.nan
x[1][0][2][4] = onp.nan
x[1][1][1][1] = onp.nan
match_res(mnp.nancumsum, onp.nancumsum, x)
match_res(mnp.nancumsum, onp.nancumsum, x, axis=-2)
match_res(mnp.nancumsum, onp.nancumsum, x, axis=0)
match_res(mnp.nancumsum, onp.nancumsum, x, axis=3)
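`nancumsum` treats each NaN as zero in the running sum, which is what the comparison against `onp.nancumsum` exercises; a quick NumPy illustration:

import numpy as onp

x = onp.array([1.0, onp.nan, 2.0])
print(onp.nancumsum(x))  # [1. 1. 3.] -- the NaN contributes nothing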


def mnp_diagonal(arr):
return mnp.diagonal(arr, offset=2, axis1=-1, axis2=0)


@@ -653,7 +668,7 @@ def test_meshgrid():
(2, 3), 9), onp.full((4, 5, 6), 7))
for i in range(len(xi)):
arrs = xi[i:]
mnp_arrs = map(mnp.asarray, arrs)
mnp_arrs = map(to_tensor, arrs)
for mnp_res, onp_res in zip(mnp_meshgrid(*mnp_arrs), onp_meshgrid(*arrs)):
match_all_arrays(mnp_res, onp_res)


@@ -750,6 +765,68 @@ def test_ix_():
match_res(mnp_ix_, onp_ix_, *test_arrs)




def mnp_indices():
a = mnp.indices((2, 3))
b = mnp.indices((2, 3, 4), sparse=True)
return a, b


def onp_indices():
a = onp.indices((2, 3))
b = onp.indices((2, 3, 4), sparse=True)
return a, b


def test_indices():
run_multi_test(mnp_indices, onp_indices, ())


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_geomspace():
start = onp.arange(1, 7).reshape(2, 3)
end = [1000, 2000, 3000]
match_array(mnp.geomspace(1, 256, num=9).asnumpy(),
onp.geomspace(1, 256, num=9), error=1)
match_array(mnp.geomspace(1, 256, num=8, endpoint=False).asnumpy(),
onp.geomspace(1, 256, num=8, endpoint=False), error=1)
match_array(mnp.geomspace(to_tensor(start), end, num=4).asnumpy(),
onp.geomspace(start, end, num=4), error=1)
match_array(mnp.geomspace(to_tensor(start), end, num=4, endpoint=False).asnumpy(),
onp.geomspace(start, end, num=4, endpoint=False), error=1)
match_array(mnp.geomspace(to_tensor(start), end, num=4, axis=-1).asnumpy(),
onp.geomspace(start, end, num=4, axis=-1), error=1)
match_array(mnp.geomspace(to_tensor(start), end, num=4, endpoint=False, axis=-1).asnumpy(),
onp.geomspace(start, end, num=4, endpoint=False, axis=-1), error=1)

start = onp.arange(1, 1 + 2*3*4*5).reshape(2, 3, 4, 5)
end = [1000, 2000, 3000, 4000, 5000]
for i in range(-5, 5):
match_array(mnp.geomspace(to_tensor(start), end, num=4, axis=i).asnumpy(),
onp.geomspace(start, end, num=4, axis=i), error=1)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vander():
arrs = [rand_int(i + 3) for i in range(3)]
for i in range(3):
mnp_vander = mnp.vander(to_tensor(arrs[i]))
onp_vander = onp.vander(arrs[i])
match_all_arrays(mnp_vander, onp_vander)
mnp_vander = mnp.vander(to_tensor(arrs[i]), N=2, increasing=True)
onp_vander = onp.vander(arrs[i], N=2, increasing=True)
match_all_arrays(mnp_vander, onp_vander)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training


+ 118
- 27
tests/st/numpy_native/test_array_ops.py View File

@@ -23,7 +23,7 @@ import mindspore.numpy as mnp
from mindspore.nn import Cell

from .utils import rand_int, run_non_kw_test, check_all_results, match_array, \
rand_bool, match_res, run_multi_test
rand_bool, match_res, run_multi_test, to_tensor




class Cases():
@@ -139,7 +139,7 @@ def onp_transpose(input_array):
@pytest.mark.env_onecard
def test_transpose():
onp_array = onp.random.random((3, 4, 5)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_array = to_tensor(onp_array)
o_transposed = onp_transpose(onp_array)
m_transposed = mnp_transpose(mnp_array)
check_all_results(o_transposed, m_transposed)
@@ -170,7 +170,7 @@ def onp_expand_dims(input_array):
@pytest.mark.env_onecard
def test_expand_dims():
onp_array = onp.random.random((3, 4, 5)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_array = to_tensor(onp_array)
o_expanded = onp_expand_dims(onp_array)
m_expanded = mnp_expand_dims(mnp_array)
check_all_results(o_expanded, m_expanded)
@@ -205,13 +205,13 @@ def onp_squeeze(input_array):
@pytest.mark.env_onecard
def test_squeeze():
onp_array = onp.random.random((1, 3, 1, 4, 2)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_array = to_tensor(onp_array)
o_squeezed = onp_squeeze(onp_array)
m_squeezed = mnp_squeeze(mnp_array)
check_all_results(o_squeezed, m_squeezed)

onp_array = onp.random.random((1, 1, 1, 1, 1)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_array = to_tensor(onp_array)
o_squeezed = onp_squeeze(onp_array)
m_squeezed = mnp_squeeze(mnp_array)
check_all_results(o_squeezed, m_squeezed)
@@ -246,7 +246,7 @@ def onp_rollaxis(input_array):
@pytest.mark.env_onecard
def test_rollaxis():
onp_array = onp.random.random((3, 4, 5)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_array = to_tensor(onp_array)
o_rolled = onp_rollaxis(onp_array)
m_rolled = mnp_rollaxis(mnp_array)
check_all_results(o_rolled, m_rolled)
@@ -281,7 +281,7 @@ def onp_swapaxes(input_array):
@pytest.mark.env_onecard
def test_swapaxes():
onp_array = onp.random.random((3, 4, 5)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_array = to_tensor(onp_array)
o_swaped = onp_swapaxes(onp_array)
m_swaped = mnp_swapaxes(mnp_array)
check_all_results(o_swaped, m_swaped)
@@ -324,7 +324,7 @@ def onp_reshape(input_array):
@pytest.mark.env_onecard
def test_reshape():
onp_array = onp.random.random((2, 3, 4)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_array = to_tensor(onp_array)
o_reshaped = onp_reshape(onp_array)
m_reshaped = mnp_reshape(mnp_array)
check_all_results(o_reshaped, m_reshaped)
@@ -349,7 +349,7 @@ def onp_ravel(input_array):
@pytest.mark.env_onecard
def test_ravel():
onp_array = onp.random.random((2, 3, 4)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_array = to_tensor(onp_array)
o_ravel = onp_ravel(onp_array)
m_ravel = mnp_ravel(mnp_array).asnumpy()
match_array(o_ravel, m_ravel)
@@ -380,7 +380,7 @@ def onp_concatenate(input_array):
@pytest.mark.env_onecard
def test_concatenate():
onp_array = onp.random.random((5, 4, 3, 2)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_array = to_tensor(onp_array)
o_concatenate = onp_concatenate(onp_array)
m_concatenate = mnp_concatenate(mnp_array)
check_all_results(o_concatenate, m_concatenate)
@@ -407,8 +407,8 @@ def onp_append(arr1, arr2):
def test_append():
onp_array = onp.random.random((4, 3, 2)).astype('float32')
onp_value = onp.random.random((4, 3, 2)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_value = mnp.asarray(onp_value)
mnp_array = to_tensor(onp_array)
mnp_value = to_tensor(onp_value)
onp_res = onp_append(onp_array, onp_value)
mnp_res = mnp_append(mnp_array, mnp_value)
check_all_results(onp_res, mnp_res)
@@ -424,13 +424,13 @@ def construct_arrays(n=1, ndim=1, axis=None, low=1, high=5):
onp_array1 = onp.random.randint(
low=low, high=high, size=shape).astype(onp.float32)
onp_array_lst.append(onp_array1)
mnp_array_lst.append(mnp.asarray(onp_array1))
mnp_array_lst.append(to_tensor(onp_array1))
if axis is not None and axis < ndim:
new_shape[axis] += onp.random.randint(2)
onp_array2 = onp.random.randint(
low=low, high=high, size=new_shape).astype(onp.float32)
onp_array_lst.append(onp_array2)
mnp_array_lst.append(mnp.asarray(onp_array2))
mnp_array_lst.append(to_tensor(onp_array2))
return onp_array_lst, mnp_array_lst

# Test np.xstack
@@ -656,7 +656,7 @@ def onp_ndarray_flatten(input_array):
@pytest.mark.env_onecard
def test_ndarray_flatten():
onp_array = onp.random.random((3, 4, 5)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_array = to_tensor(onp_array)
o_flatten = onp_ndarray_flatten(onp_array)
m_flatten = mnp_ndarray_flatten(mnp_array)
check_all_results(o_flatten, m_flatten)
@@ -687,7 +687,7 @@ def onp_ndarray_transpose(input_array):
@pytest.mark.env_onecard
def test_ndarray_transpose():
onp_array = onp.random.random((3, 4, 5)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_array = to_tensor(onp_array)
o_transposed = onp_ndarray_transpose(onp_array)
m_transposed = mnp_ndarray_transpose(mnp_array)
check_all_results(o_transposed, m_transposed)
@@ -716,7 +716,7 @@ def onp_ndarray_astype(input_array):
@pytest.mark.env_onecard
def test_ndarray_astype():
onp_array = onp.random.random((3, 4, 5)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_array = to_tensor(onp_array)
o_astype = onp_ndarray_astype(onp_array)
m_astype = mnp_ndarray_astype(mnp_array)
for arr1, arr2 in zip(o_astype, m_astype):
@@ -747,7 +747,7 @@ def mnp_concatenate_type_promotion(mnp_array1, mnp_array2, mnp_array3, mnp_array
@pytest.mark.env_onecard
def test_concatenate_type_promotion():
onp_array = onp.random.random((5, 1)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_array = to_tensor(onp_array)
onp_array1 = onp_array.astype(onp.float16)
onp_array2 = onp_array.astype(onp.bool_)
onp_array3 = onp_array.astype(onp.float32)
@@ -1049,7 +1049,7 @@ def test_split():
onp_arrs = [
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32')
]
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
mnp_arrs = [to_tensor(arr) for arr in onp_arrs]
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
o_split = onp_split(onp_arr)
m_split = mnp_split(mnp_arr)
@@ -1058,6 +1058,36 @@
match_array(expect, actual.asnumpy())




def mnp_array_split(input_tensor):
a = mnp.array_split(input_tensor, indices_or_sections=4, axis=2)
b = mnp.array_split(input_tensor, indices_or_sections=3, axis=1)
c = mnp.array_split(input_tensor, indices_or_sections=6)
return a, b, c


def onp_array_split(input_array):
a = onp.array_split(input_array, indices_or_sections=4, axis=2)
b = onp.array_split(input_array, indices_or_sections=3, axis=1)
c = onp.array_split(input_array, indices_or_sections=6)
return a, b, c


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_array_split():
onp_arr = onp.random.randint(1, 5, size=(9, 7, 13)).astype('float32')
mnp_arr = to_tensor(onp_arr)
o_split = onp_array_split(onp_arr)
m_split = mnp_array_split(mnp_arr)
for expect_lst, actual_lst in zip(o_split, m_split):
for expect, actual in zip(expect_lst, actual_lst):
match_array(expect, actual.asnumpy())


def mnp_vsplit(input_tensor):
a = mnp.vsplit(input_tensor, indices_or_sections=3)
b = mnp.vsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10))
@@ -1082,7 +1112,7 @@ def test_vsplit():
onp_arrs = [
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32')
]
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
mnp_arrs = [to_tensor(arr) for arr in onp_arrs]
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
o_vsplit = onp_vsplit(onp_arr)
m_vsplit = mnp_vsplit(mnp_arr)
@@ -1115,7 +1145,7 @@ def test_hsplit():
onp_arrs = [
onp.random.randint(1, 5, size=(4, 9, 5)).astype('float32')
]
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
mnp_arrs = [to_tensor(arr) for arr in onp_arrs]
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
o_hsplit = onp_hsplit(onp_arr)
m_hsplit = mnp_hsplit(mnp_arr)
@@ -1148,7 +1178,7 @@ def test_dsplit():
onp_arrs = [
onp.random.randint(1, 5, size=(5, 4, 9)).astype('float32')
]
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
mnp_arrs = [to_tensor(arr) for arr in onp_arrs]
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
o_dsplit = onp_dsplit(onp_arr)
m_dsplit = mnp_dsplit(mnp_arr)
@@ -1248,6 +1278,29 @@ def test_repeat():
run_multi_test(mnp_repeat, onp_repeat, (x,))




@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_select():
choicelist = rand_int(2, 3, 4, 5)
condlist = choicelist > 2
match_res(mnp.select, onp.select, condlist, choicelist)
match_res(mnp.select, onp.select, condlist, choicelist, default=10)

condlist = rand_bool(5, 4, 1, 3)
choicelist = rand_int(5, 3)
match_res(mnp.select, onp.select, condlist, choicelist)
match_res(mnp.select, onp.select, condlist, choicelist, default=10)

condlist = rand_bool(3, 1, 7)
choicelist = rand_int(3, 5, 2, 1)
match_res(mnp.select, onp.select, condlist, choicelist)
match_res(mnp.select, onp.select, condlist, choicelist, default=10)


class ReshapeExpandSqueeze(Cell):
def __init__(self):
super(ReshapeExpandSqueeze, self).__init__()
@@ -1333,7 +1386,7 @@ def test_swapaxes_exception():
@pytest.mark.env_onecard
def test_tensor_flatten():
lst = [[1.0, 2.0], [3.0, 4.0]]
tensor_list = mnp.asarray(lst)
tensor_list = to_tensor(lst)
assert tensor_list.flatten().asnumpy().tolist() == [1.0, 2.0, 3.0, 4.0]
assert tensor_list.flatten(order='F').asnumpy().tolist() == [
1.0, 3.0, 2.0, 4.0]
@@ -1347,7 +1400,7 @@ def test_tensor_flatten():
@pytest.mark.env_onecard
def test_tensor_reshape():
lst = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
tensor_list = mnp.asarray(lst)
tensor_list = to_tensor(lst)
with pytest.raises(TypeError):
tensor_list = tensor_list.reshape({0, 1, 2})
with pytest.raises(ValueError):
@@ -1364,7 +1417,7 @@ def test_tensor_reshape():
@pytest.mark.env_onecard
def test_tensor_squeeze():
lst = [[[1.0], [2.0], [3.0]]]
tensor_list = mnp.asarray(lst)
tensor_list = to_tensor(lst)
with pytest.raises(TypeError):
tensor_list = tensor_list.squeeze(1.2)
with pytest.raises(ValueError):
@@ -1381,7 +1434,7 @@ def test_tensor_squeeze():
@pytest.mark.env_onecard
def test_tensor_ravel():
lst = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
tensor_list = mnp.asarray(lst)
tensor_list = to_tensor(lst)
assert tensor_list.ravel().shape == (8,)
assert tensor_list.ravel().asnumpy().tolist() == [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
@@ -1395,9 +1448,47 @@ def test_tensor_ravel():
@pytest.mark.env_onecard
def test_tensor_swapaxes():
lst = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
tensor_list = mnp.asarray(lst)
tensor_list = to_tensor(lst)
with pytest.raises(TypeError):
tensor_list = tensor_list.swapaxes(0, (1,))
with pytest.raises(ValueError):
tensor_list = tensor_list.swapaxes(0, 3)
assert tensor_list.swapaxes(0, 1).shape == (3, 2)


def mnp_rot90(input_tensor):
a = mnp.rot90(input_tensor)
b = mnp.rot90(input_tensor, 2)
c = mnp.rot90(input_tensor, 3)
d = mnp.rot90(input_tensor, 4)
e = mnp.rot90(input_tensor, 5, (0, -1))
f = mnp.rot90(input_tensor, 1, (2, 0))
g = mnp.rot90(input_tensor, -3, (-1, -2))
h = mnp.rot90(input_tensor, 3, (2, 1))
return a, b, c, d, e, f, g, h


def onp_rot90(input_array):
a = onp.rot90(input_array)
b = onp.rot90(input_array, 2)
c = onp.rot90(input_array, 3)
d = onp.rot90(input_array, 4)
e = onp.rot90(input_array, 5, (0, -1))
f = onp.rot90(input_array, 1, (2, 0))
g = onp.rot90(input_array, -3, (-1, -2))
h = onp.rot90(input_array, 3, (2, 1))
return a, b, c, d, e, f, g, h


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_rot90():
onp_array = rand_int(3, 4, 5).astype('float32')
mnp_array = to_tensor(onp_array)
o_rot = onp_rot90(onp_array)
m_rot = mnp_rot90(mnp_array)
check_all_results(o_rot, m_rot)

+ 149
- 2
tests/st/numpy_native/test_logic_ops.py View File

@@ -19,7 +19,8 @@ import numpy as onp


import mindspore.numpy as mnp


from .utils import rand_int, run_binop_test, match_res
from .utils import rand_int, rand_bool, run_binop_test, run_logical_test, match_res, \
match_all_arrays, to_tensor




class Cases():
@@ -55,6 +56,15 @@
rand_int(8, 1, 6, 1)
]


# Boolean arrays
self.boolean_arrs = [
rand_bool(),
rand_bool(5),
rand_bool(6, 1),
rand_bool(7, 1, 5),
rand_bool(8, 1, 6, 1)
]

# array which contains infs and nans
self.infs = onp.array([[1.0, onp.nan], [onp.inf, onp.NINF], [2.3, -4.5], [onp.nan, 0.0]])


@@ -246,10 +256,147 @@ def test_isneginf():
match_res(mnp_isneginf, onp_isneginf, test_case.infs)




@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_isscalar():
assert mnp.isscalar(1) == onp.isscalar(1)
assert mnp.isscalar(2.3) == onp.isscalar(2.3)
assert mnp.isscalar([4.5]) == onp.isscalar([4.5])
assert mnp.isscalar(False) == onp.isscalar(False)
assert mnp.isscalar(mnp.array(True)) == onp.isscalar(onp.array(True))
assert mnp.isscalar(to_tensor(True)) == onp.isscalar(onp.array(True))
assert mnp.isscalar('numpy') == onp.isscalar('numpy')


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_isclose():
a = [0, 1, 2, float('inf'), float('inf'), float('nan')]
b = [0, 1, -2, float('-inf'), float('inf'), float('nan')]
match_all_arrays(mnp.isclose(a, b), onp.isclose(a, b))
match_all_arrays(mnp.isclose(a, b, equal_nan=True), onp.isclose(a, b, equal_nan=True))

a = rand_int(2, 3, 4, 5)
diff = (onp.random.random((2, 3, 4, 5)).astype("float32") - 0.5) / 1000
b = a + diff
match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b), atol=1e-3), onp.isclose(a, b, atol=1e-3))
match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b), atol=1e-3, rtol=1e-4),
onp.isclose(a, b, atol=1e-3, rtol=1e-4))
match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b), atol=1e-2, rtol=1e-6),
onp.isclose(a, b, atol=1e-2, rtol=1e-6))

a = rand_int(2, 3, 4, 5)
b = rand_int(4, 5)
match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b)), onp.isclose(a, b))


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_in1d():
xi = [rand_int(), rand_int(1), rand_int(10)]
yi = [rand_int(), rand_int(1), rand_int(10)]
for x in xi:
for y in yi:
match_res(mnp.in1d, onp.in1d, x, y)
match_res(mnp.in1d, onp.in1d, x, y, invert=True)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_isin():
xi = [rand_int(), rand_int(1), rand_int(10), rand_int(2, 3)]
yi = [rand_int(), rand_int(1), rand_int(10), rand_int(2, 3)]
for x in xi:
for y in yi:
match_res(mnp.isin, onp.isin, x, y)
match_res(mnp.isin, onp.isin, x, y, invert=True)


def mnp_logical_or(x1, x2):
return mnp.logical_or(x1, x2)


def onp_logical_or(x1, x2):
return onp.logical_or(x1, x2)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_logical_or():
run_logical_test(mnp_logical_or, onp_logical_or, test_case)


def mnp_logical_xor(x1, x2):
return mnp.logical_xor(x1, x2)


def onp_logical_xor(x1, x2):
return onp.logical_xor(x1, x2)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_logical_xor():
run_logical_test(mnp_logical_xor, onp_logical_xor, test_case)


def mnp_logical_and(x1, x2):
return mnp.logical_and(x1, x2)


def onp_logical_and(x1, x2):
return onp.logical_and(x1, x2)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_logical_and():
run_logical_test(mnp_logical_and, onp_logical_and, test_case)


def mnp_logical_not(x):
return mnp.logical_not(x)


def onp_logical_not(x):
return onp.logical_not(x)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_logical_not():
for arr in test_case.boolean_arrs:
expected = onp_logical_not(arr)
actual = mnp_logical_not(to_tensor(arr))
onp.testing.assert_equal(actual.asnumpy().tolist(), expected.tolist())
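
# All four logical ops above are elementwise and broadcast like their numpy
# counterparts, e.g. onp.logical_or([True, False], False) -> [True, False].
# run_logical_test (defined in utils.py below) feeds every pair of boolean
# test arrays through match_res with dtype=mnp.bool_ so inputs stay boolean.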

+ 850 - 96   tests/st/numpy_native/test_math_ops.py
File diff suppressed because it is too large. View File


+ 24 - 4   tests/st/numpy_native/utils.py View File

@@ -15,6 +15,7 @@
"""utility functions for mindspore.numpy st tests""" """utility functions for mindspore.numpy st tests"""
import functools import functools
import numpy as onp import numpy as onp
from mindspore import Tensor
import mindspore.numpy as mnp import mindspore.numpy as mnp




@@ -90,7 +91,9 @@ def rand_bool(*shape):


def match_res(mnp_fn, onp_fn, *arrs, **kwargs):
    """Checks results from applying mnp_fn and onp_fn on arrs respectively"""
-    mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs)
+    # Tests may override the conversion dtype (e.g. bool_ for the logical ops).
+    dtype = kwargs.pop('dtype', mnp.float32)
+    mnp_arrs = map(functools.partial(Tensor, dtype=dtype), arrs)
    error = kwargs.get('error', 0)
    kwargs.pop('error', None)
    mnp_res = mnp_fn(*mnp_arrs, **kwargs)
@@ -151,15 +154,32 @@ def run_unary_test(mnp_fn, onp_fn, test_case, error=0):




def run_multi_test(mnp_fn, onp_fn, arrs, error=0):
-    mnp_arrs = map(mnp.asarray, arrs)
+    mnp_arrs = map(Tensor, arrs)
    for actual, expected in zip(mnp_fn(*mnp_arrs), onp_fn(*arrs)):
-        match_array(actual.asnumpy(), expected, error)
+        match_all_arrays(actual, expected, error)




def run_single_test(mnp_fn, onp_fn, arr, error=0):
-    mnp_arr = mnp.asarray(arr)
+    mnp_arr = Tensor(arr)
    for actual, expected in zip(mnp_fn(mnp_arr), onp_fn(arr)):
        if isinstance(expected, tuple):
            for actual_arr, expected_arr in zip(actual, expected):
                match_array(actual_arr.asnumpy(), expected_arr, error)
        else:
            match_array(actual.asnumpy(), expected, error)


def run_logical_test(mnp_fn, onp_fn, test_case):
    """Checks mnp_fn against onp_fn on every pair of boolean test arrays."""
    for x1 in test_case.boolean_arrs:
        for x2 in test_case.boolean_arrs:
            match_res(mnp_fn, onp_fn, x1, x2, dtype=mnp.bool_)


def to_tensor(obj, dtype=None):
    """Converts obj to a Tensor; without an explicit dtype, 64-bit inputs are
    narrowed to 32 bits so the result runs on backends without 64-bit kernels."""
    if dtype is None:
        res = Tensor(obj)
        if res.dtype == mnp.float64:
            res = res.astype(mnp.float32)
        if res.dtype == mnp.int64:
            res = res.astype(mnp.int32)
    else:
        res = Tensor(obj, dtype)
    return res
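
# Illustrative calls (the inputs here are assumptions, not from the test suite):
# to_tensor(onp.ones((2, 2)))        # float64 input is narrowed to float32
# to_tensor([1, 2, 3])               # int64 input is narrowed to int32
# to_tensor([1, 2, 3], mnp.float16)  # an explicit dtype is kept as requested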
