Browse Source

!13452 numpy-native remove maximum minimum from ascend ci

From: @jachua
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
4e1e16c335
9 changed files with 84 additions and 304 deletions
  1. +2
    -2
      mindspore/numpy/array_creations.py
  2. +2
    -1
      mindspore/numpy/array_ops.py
  3. +0
    -33
      mindspore/numpy/logic_ops.py
  4. +23
    -135
      mindspore/numpy/math_ops.py
  5. +1
    -0
      mindspore/numpy/utils_const.py
  6. +3
    -3
      mindspore/ops/composite/math_ops.py
  7. +11
    -37
      tests/st/numpy_native/test_array_creations.py
  8. +19
    -67
      tests/st/numpy_native/test_array_ops.py
  9. +23
    -26
      tests/st/numpy_native/test_math_ops.py

+ 2
- 2
mindspore/numpy/array_creations.py View File

@@ -168,7 +168,7 @@ def asfarray_const(a, dtype=mstype.float32):
a = _deep_tensor_to_nparray(a)
a = onp.asarray(a)
if a.dtype is onp.dtype('object'):
raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
raise ValueError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
a = Tensor.from_numpy(a)

return Tensor(a, dtype)
@@ -214,7 +214,7 @@ def asfarray(a, dtype=mstype.float32):
if isinstance(a, Tensor):
return a.astype(dtype)

return asfarray_const(a)
return asfarray_const(a, dtype)


def copy_(a):


+ 2
- 1
mindspore/numpy/array_ops.py View File

@@ -30,7 +30,8 @@ from .utils_const import _check_axes_range, _check_start_normalize, \
_check_same_type, _check_axis_valid, _add_unit_axes, _broadcast_tuples, \
_check_is_float, _check_axis_in_range, _check_axis_type, _canonicalize_axis, \
_list_comprehensions, _check_element_int, _is_shape_empty, _type_convert, \
_tuple_getitem, _expanded_shape, _seq_prod, _get_device, _tuple_setitem
_tuple_getitem, _expanded_shape, _seq_prod, _get_device, _tuple_setitem, \
_raise_unimplemented_error

# According to official numpy reference, the dimension of a numpy array must be less
# than 32


+ 0
- 33
mindspore/numpy/logic_ops.py View File

@@ -84,9 +84,6 @@ def less_equal(x1, x2, dtype=None):
bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are
scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -120,9 +117,6 @@ def less(x1, x2, dtype=None):
bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are
scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -155,9 +149,6 @@ def greater_equal(x1, x2, dtype=None):
bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are
scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -190,9 +181,6 @@ def greater(x1, x2, dtype=None):
bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are
scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -225,9 +213,6 @@ def equal(x1, x2, dtype=None):
bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are
scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -260,9 +245,6 @@ def isfinite(x, dtype=None):
Tensor or scalar, true where `x` is not positive infinity, negative infinity,
or NaN; false otherwise. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -296,9 +278,6 @@ def isnan(x, dtype=None):
Tensor or scalar, true where `x` is NaN, false otherwise. This is a scalar if
`x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``GPU`` ``CPU``

@@ -346,9 +325,6 @@ def isinf(x, dtype=None):
Tensor or scalar, true where `x` is positive or negative infinity, false
otherwise. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``GPU`` ``CPU``

@@ -688,9 +664,6 @@ def logical_or(x1, x2, dtype=None):
bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are
scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -725,9 +698,6 @@ def logical_and(x1, x2, dtype=None):
Boolean result of the logical AND operation applied to the elements of `x1` and `x2`;
the shape is determined by broadcasting. This is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -762,9 +732,6 @@ def logical_xor(x1, x2, dtype=None):
Boolean result of the logical AND operation applied to the elements of `x1` and `x2`;
the shape is determined by broadcasting. This is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``



+ 23
- 135
mindspore/numpy/math_ops.py View File

@@ -109,9 +109,6 @@ def count_nonzero(x, axis=None, keepdims=False):
Tensor, indicating number of non-zero values in the `x` along a given axis.
Otherwise, the total number of non-zero values in `x` is returned.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -217,9 +214,6 @@ def rad2deg(x, dtype=None):
Tensor, the corresponding angle in degrees. This is a tensor scalar if `x`
is a tensor scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -255,9 +249,6 @@ def add(x1, x2, dtype=None):
Tensor or scalar, the sum of `x1` and `x2`, element-wise. This is a scalar
if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -297,9 +288,6 @@ def subtract(x1, x2, dtype=None):
Tensor or scalar, the difference of `x1` and `x2`, element-wise. This is a
scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -334,9 +322,6 @@ def multiply(x1, x2, dtype=None):
Tensor or scalar, the product of `x1` and `x2`, element-wise. This is a scalar
if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -380,9 +365,6 @@ def divide(x1, x2, dtype=None):
Returns:
Tensor or scalar, this is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -422,9 +404,6 @@ def true_divide(x1, x2, dtype=None):
Returns:
Tensor or scalar, this is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -462,9 +441,6 @@ def power(x1, x2, dtype=None):
Tensor or scalar, the bases in `x1` raised to the exponents in `x2`. This
is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -507,9 +483,6 @@ def float_power(x1, x2, dtype=None):
Tensor or scalar, the bases in `x1` raised to the exponents in `x2`. This
is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -538,9 +511,7 @@ def minimum(x1, x2, dtype=None):
Note:
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are
not supported.
Unlike numpy, when one of the elements is a NaN, the second element is
always returned regardless of whether the second element is a NaN, instead
of returning NaN.
On Ascend, input arrays containing inf or NaN are not supported.

Args:
x1 (Tensor): first input tensor to be compared.
@@ -1166,9 +1137,6 @@ def square(x, dtype=None):
Tensor or scalar, element-wise ``x*x``, of the same shape and dtype as `x`.
This is a scalar if `x` is a scalar..

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -1201,9 +1169,6 @@ def sqrt(x, dtype=None):
square-root of each element in `x`. For negative elements, nan is returned.
This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -1242,9 +1207,6 @@ def reciprocal(x, dtype=None):
Returns:
Tensor or scalar, this is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -1283,9 +1245,6 @@ def log(x, dtype=None):
Tensor or scalar, the natural logarithm of `x`, element-wise. This is a
scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -1316,9 +1275,7 @@ def maximum(x1, x2, dtype=None):
Note:
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are
not supported.
Unlike numpy, when one of the elements is a NaN, the second element is
always returned regardless of whether the second element is a NaN, instead
of returning NaN.
On Ascend, input arrays containing inf or NaN are not supported.

Args:
x1 (Tensor): Input array
@@ -1332,9 +1289,6 @@ def maximum(x1, x2, dtype=None):
Tensor or scalar, the maximum of `x1` and `x2`, element-wise. This is a scalar
if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -1385,9 +1339,6 @@ def heaviside(x1, x2, dtype=None):
Tensor or scalar, the output array, element-wise Heaviside step function
of `x1`. This is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -1562,9 +1513,6 @@ def hypot(x1, x2, dtype=None):
Tensor or scalar, the hypotenuse of the triangle(s). This is a scalar if
both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -1614,9 +1562,6 @@ def floor(x, dtype=None):
Tensor or scalar, the floor of each element in `x`. This is a scalar if `x`
is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -1648,9 +1593,6 @@ def floor_divide(x1, x2, dtype=None):
Returns:
Tensor or scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -1709,9 +1651,6 @@ def remainder(x1, x2, dtype=None):
Tensor or scalar, the element-wise remainder of the quotient
``floor_divide(x1, x2)``. This is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -1787,9 +1726,6 @@ def fmod(x1, x2, dtype=None):
Tensor or scalar, the remainder of the division of `x1` by `x2`. This is a
scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -1822,9 +1758,6 @@ def trunc(x, dtype=None):
Tensor or scalar, the truncated value of each element in `x`. This is a scalar if `x` is
a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -1859,9 +1792,6 @@ def exp(x, dtype=None):
Tensor or scalar, element-wise exponential of `x`. This is a scalar if both
`x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -1893,9 +1823,6 @@ def expm1(x, dtype=None):
Tensor or scalar, element-wise exponential minus one, ``out = exp(x) - 1``.
This is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -2117,6 +2044,7 @@ def trapz(y, x=None, dx=1.0, axis=-1):
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> a = np.arange(6).reshape(2, 3)
>>> output = np.trapz(a, x=[-2, 1, 2], axis=1)
>>> print(output)
@@ -2197,16 +2125,14 @@ def gcd(x1, x2, dtype=None):
Tensor or scalar, the greatest common divisor of the absolute value of the inputs.
This is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> output = np.gcd(np.arange(6), np.array(20))
>>> print(output)
[20 1 2 1 4 5]
[20 1 2 1 4 5]
"""
return _apply_tensor_op(_gcd, x1, x2, dtype=dtype)

@@ -2229,16 +2155,14 @@ def lcm(x1, x2, dtype=None):
Tensor or scalar, the lowest common multiple of the absolute value of the inputs.
This is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> output = np.lcm(np.arange(6), np.array(20))
>>> print(output)
[ 0 20 20 60 20 20]
[ 0 20 20 60 20 20]
"""
def _lcm(x1, x2):
"""Calculates lcm without applying keyword arguments"""
@@ -2290,7 +2214,7 @@ def convolve(a, v, mode='full'):
>>> import mindspore.numpy as np
>>> output = np.convolve([1., 2., 3., 4., 5.], [2., 3.], mode="valid")
>>> print(output)
[ 3. 6. 9. 12.]
[ 3. 6. 9. 12.]
"""
if not isinstance(a, Tensor):
a = asarray_const(a)
@@ -2406,6 +2330,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=N
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> output = np.cov([[2., 3., 4., 5.], [0., 2., 3., 4.], [7., 8., 9., 10.]])
>>> print(output)
[[1.6666666 2.1666667 1.6666666]
@@ -2509,6 +2434,10 @@ def _reduce(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
if dtype is None:
dtype = F.dtype(a)
axes = _check_axis_valid(axis, ndim)
if initial is not None:
if ((isinstance(initial, Tensor) and F.rank(initial) > 0) or
not isinstance(initial, (int, float, bool, Tensor))):
_raise_type_error('initial should be scalar')

if _is_shape_empty(shape):
if not axes:
@@ -2578,6 +2507,7 @@ def nansum(a, axis=None, dtype=None, keepdims=False):
``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> a = np.array([[1, 1], [1, np.nan]])
>>> output = np.nansum(a)
>>> print(output)
@@ -2638,6 +2568,7 @@ def nanmean(a, axis=None, dtype=None, keepdims=False):
``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> a = np.array([[1, np.nan], [3, 4]])
>>> output = np.nanmean(a)
>>> print(output)
@@ -2700,6 +2631,7 @@ def nanvar(a, axis=None, dtype=None, ddof=0, keepdims=False):
``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> a = np.array([[1, np.nan], [3, 4]])
>>> output = np.nanstd(a)
>>> print(output)
@@ -2752,6 +2684,7 @@ def nanstd(a, axis=None, dtype=None, ddof=0, keepdims=False):
``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> a = np.array([[1, np.nan], [3, 4]])
>>> output = np.nanvar(a)
>>> print(output)
@@ -2784,13 +2717,11 @@ def exp2(x, dtype=None):
Returns:
Tensor or scalar, element-wise 2 to the power `x`.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> x = np.array([2, 3]).astype(np.float32)
>>> output = np.exp2(x)
>>> print(output)
@@ -2817,6 +2748,7 @@ def kron(a, b):
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> output = np.kron([1,10,100], [5,6,7])
>>> print(output)
[ 5 6 7 50 60 70 500 600 700]
@@ -2885,6 +2817,7 @@ def cross(a, b, axisa=- 1, axisb=- 1, axisc=- 1, axis=None):
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> x = np.array([[1,2,3], [4,5,6]])
>>> y = np.array([[4,5,6], [1,2,3]])
>>> output = np.cross(x, y)
@@ -2968,13 +2901,11 @@ def ceil(x, dtype=None):
Returns:
Tensor or scalar, the floor of each element in `x`. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> a = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0])
>>> output = np.ceil(a)
>>> print(output)
@@ -3086,6 +3017,7 @@ def cumsum(a, axis=None, dtype=None):
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> output = np.cumsum(np.ones((3,3)), axis=0)
>>> print(output)
[[1. 1. 1.]
@@ -3141,6 +3073,7 @@ def nancumsum(a, axis=None, dtype=None):
``GPU`` ``CPU``

Examples:
>>> import mindspore.numpy as np
>>> a = np.array([[1, 2], [3, np.nan]])
>>> output = np.nancumsum(a)
>>> print(output)
@@ -3212,9 +3145,6 @@ def log1p(x, dtype=None):
Returns:
Tensor or scalar. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -3251,9 +3181,6 @@ def logaddexp(x1, x2, dtype=None):
Returns:
Tensor or scalar. This is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -3286,9 +3213,6 @@ def log2(x, dtype=None):
Returns:
Tensor or scalar. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -3329,9 +3253,6 @@ def logaddexp2(x1, x2, dtype=None):
Returns:
Tensor or scalar. This is a scalar if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -3364,9 +3285,6 @@ def log10(x, dtype=None):
Returns:
Tensor or scalar. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -3407,9 +3325,6 @@ def sin(x, dtype=None):
Returns:
Tensor or scalar. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -3440,9 +3355,6 @@ def cos(x, dtype=None):
Returns:
Tensor or scalar. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -3575,9 +3487,6 @@ def arctan(x, dtype=None):
Returns:
Tensor or scalar. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -3607,9 +3516,6 @@ def sinh(x, dtype=None):
Returns:
Tensor or scalar. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``CPU``

@@ -3639,9 +3545,6 @@ def cosh(x, dtype=None):
Returns:
Tensor or scalar. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``CPU``

@@ -3671,9 +3574,6 @@ def tanh(x, dtype=None):
Returns:
Tensor or scalar. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -3703,9 +3603,6 @@ def arcsinh(x, dtype=None):
Returns:
Tensor or scalar. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -3735,9 +3632,6 @@ def arccosh(x, dtype=None):
Returns:
Tensor or scalar. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -3767,9 +3661,6 @@ def arctanh(x, dtype=None):
Returns:
Tensor or scalar. This is a scalar if `x` is a scalar.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``CPU``

@@ -3801,9 +3692,6 @@ def arctan2(x1, x2, dtype=None):
Tensor or scalar, the sum of `x1` and `x2`, element-wise. This is a scalar
if both `x1` and `x2` are scalars.

Raises:
TypeError: if the input is not a tensor.

Supported Platforms:
``Ascend`` ``CPU``



+ 1
- 0
mindspore/numpy/utils_const.py View File

@@ -472,6 +472,7 @@ def _make_tensor(val, dtype):
return Tensor(val, dtype)


@constexpr
def _tuple_slice(tup, start, end):
"""get sliced tuple from start and end."""
return tup[start:end]

+ 3
- 3
mindspore/ops/composite/math_ops.py View File

@@ -591,9 +591,9 @@ def matmul(x1, x2, dtype=None):
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> x1 = np.arange(2*3*4).reshape(2, 3, 4).astype('float32')
>>> x2 = np.arange(4*5).reshape(4, 5).astype('float32')
>>> output = np.matmul(x1, x2)
>>> x1 = Tensor(np.arange(2*3*4).reshape(2, 3, 4), mindspore.float32)
>>> x2 = Tensor(np.arange(4*5).reshape(4, 5), mindspore.float32)
>>> output = ops.matmul(x1, x2)
>>> print(output)
[[[ 70. 76. 82. 88. 94.]
[ 190. 212. 234. 256. 278.]


+ 11
- 37
tests/st/numpy_native/test_array_creations.py View File

@@ -26,7 +26,7 @@ from .utils import rand_int, rand_bool, match_array, match_res, match_meta, \
class Cases():
def __init__(self):
self.all_shapes = [
0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3]
1, 2, (1,), (2,), (1, 2, 3), [1], [2], [1, 2, 3]
]
self.onp_dtypes = [onp.int32, 'int32', int,
onp.float32, 'float32', float,
@@ -94,18 +94,16 @@ class Cases():

self.mnp_prototypes = [
mnp.ones((2, 3, 4)),
mnp.ones((0, 3, 0, 2, 5)),
mnp.ones((2, 7, 0)),
mnp.ones(()),
mnp.ones((1, 3, 1, 2, 5)),
mnp.ones((2, 7, 1)),
[mnp.ones(3), (1, 2, 3), mnp.ones(3), [4, 5, 6]],
([(1, 2), mnp.ones(2)], (mnp.ones(2), [3, 4])),
]

self.onp_prototypes = [
onp.ones((2, 3, 4)),
onp.ones((0, 3, 0, 2, 5)),
onp.ones((2, 7, 0)),
onp.ones(()),
onp.ones((1, 3, 1, 2, 5)),
onp.ones((2, 7, 1)),
[onp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]],
([(1, 2), onp.ones(2)], (onp.ones(2), [3, 4])),
]
@@ -257,10 +255,6 @@ def test_full():
expected = mnp.full((2, 2), [1, 2]).asnumpy()
match_array(actual, expected)

actual = onp.full((2, 0), onp.inf)
expected = mnp.full((2, 0), mnp.inf).asnumpy()
match_array(actual, expected)

actual = onp.full((2, 3), True)
expected = mnp.full((2, 3), True).asnumpy()
match_array(actual, expected)
@@ -579,29 +573,19 @@ def onp_diagonal(arr):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_diagonal():
arr = rand_int(0, 0)
match_res(mnp.diagonal, onp.diagonal, arr, offset=1)

arr = rand_int(3, 5)
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
for i in [-1, 0, 2]:
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=1)
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=1, axis2=0)

arr = rand_int(7, 4, 9)
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
for i in [-1, 0, 2]:
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=-1)
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=-2, axis2=2)
match_res(mnp.diagonal, onp.diagonal, arr,
offset=i, axis1=-1, axis2=-2)

arr = rand_int(2, 5, 8, 1)
match_res(mnp_diagonal, onp_diagonal, arr)
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=-3, axis2=2)
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=1, axis2=3)
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=-2)
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=2, axis2=-1)


def mnp_trace(arr):
return mnp.trace(arr, offset=4, axis1=1, axis2=2)
@@ -618,27 +602,18 @@ def onp_trace(arr):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_trace():
arr = rand_int(0, 0)
match_res(mnp.trace, onp.trace, arr, offset=1)

arr = rand_int(3, 5)
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
for i in [-1, 0]:
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=1)
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=1, axis2=0)

arr = rand_int(7, 4, 9)
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
for i in [-1, 0, 2]:
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=-1)
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-2, axis2=2)
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-1, axis2=-2)

arr = rand_int(2, 5, 8, 1)
match_res(mnp_trace, onp_trace, arr)
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-3, axis2=2)
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=1, axis2=3)
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=-2)
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=2, axis2=-1)


def mnp_meshgrid(*xi):
@@ -712,7 +687,7 @@ def test_ogrid():
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_diagflat():
arrs = [rand_int(0), rand_int(2, 3), rand_int(3, 5, 0)]
arrs = [rand_int(2, 3)]
for arr in arrs:
for i in [-2, 0, 7]:
match_res(mnp.diagflat, onp.diagflat, arr, k=i)
@@ -725,8 +700,7 @@ def test_diagflat():
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_diag():
arrs = [rand_int(0), rand_int(0, 0), rand_int(7), rand_int(5, 5),
rand_int(3, 8), rand_int(9, 6)]
arrs = [rand_int(7), rand_int(5, 5), rand_int(3, 8), rand_int(9, 6)]
for arr in arrs:
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
match_res(mnp.diag, onp.diag, arr, k=i)


+ 19
- 67
tests/st/numpy_native/test_array_ops.py View File

@@ -29,7 +29,7 @@ from .utils import rand_int, run_non_kw_test, check_all_results, match_array, \
class Cases():
def __init__(self):
self.all_shapes = [
0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3]
1, 2, (1,), (2,), (1, 2, 3), [1], [2], [1, 2, 3]
]
self.onp_dtypes = [onp.int32, 'int32', int,
onp.float32, 'float32', float,
@@ -97,18 +97,12 @@ class Cases():

self.mnp_prototypes = [
mnp.ones((2, 3, 4)),
mnp.ones((0, 3, 0, 2, 5)),
onp.ones((2, 7, 0)),
onp.ones(()),
[mnp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]],
([(1, 2), mnp.ones(2)], (onp.ones(2), [3, 4])),
]

self.onp_prototypes = [
onp.ones((2, 3, 4)),
onp.ones((0, 3, 0, 2, 5)),
onp.ones((2, 7, 0)),
onp.ones(()),
[onp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]],
([(1, 2), onp.ones(2)], (onp.ones(2), [3, 4])),
]
@@ -794,11 +788,6 @@ def test_stack():
for i in range(-4, 4):
match_res(mnp.stack, onp.stack, arr, axis=i)

arr = rand_int(7, 4, 0, 3)
match_res(mnp.stack, onp.stack, arr)
for i in range(-4, 4):
match_res(mnp.stack, onp.stack, arr, axis=i)

arrs = [rand_int(3, 4, 5) for i in range(10)]
match_res(mnp.stack, onp.stack, arrs)
match_res(mnp.stack, onp.stack, tuple(arrs))
@@ -806,13 +795,6 @@ def test_stack():
for i in range(-4, 4):
match_res(mnp.stack, onp.stack, arrs, axis=i)

arrs = [rand_int(3, 0, 5, 8, 0) for i in range(5)]
match_res(mnp.stack, onp.stack, arrs)
match_res(mnp.stack, onp.stack, tuple(arrs))
match_res(mnp_stack, onp_stack, *arrs)
for i in range(-6, 6):
match_res(mnp.stack, onp.stack, arrs, axis=i)


def mnp_roll(input_tensor):
a = mnp.roll(input_tensor, -3)
@@ -868,28 +850,22 @@ def onp_moveaxis(a):
def test_moveaxis():
a = rand_int(2, 4, 5, 9, 6)
match_res(mnp_moveaxis, onp_moveaxis, a)
a = rand_int(2, 4, 5, 0, 6, 7, 1, 3, 8)
match_res(mnp_moveaxis, onp_moveaxis, a)


def mnp_tile(x):
a = mnp.tile(x, 0)
b = mnp.tile(x, 1)
c = mnp.tile(x, 3)
d = mnp.tile(x, [5, 1])
e = mnp.tile(x, (3, 1, 0))
f = mnp.tile(x, [5, 1, 2, 3, 7])
return a, b, c, d, e, f
a = mnp.tile(x, 1)
b = mnp.tile(x, 3)
c = mnp.tile(x, [5, 1])
d = mnp.tile(x, [5, 1, 2, 3, 7])
return a, b, c, d


def onp_tile(x):
a = onp.tile(x, 0)
b = onp.tile(x, 1)
c = onp.tile(x, 3)
d = onp.tile(x, [5, 1])
e = onp.tile(x, (3, 1, 0))
f = onp.tile(x, [5, 1, 2, 3, 7])
return a, b, c, d, e, f
a = onp.tile(x, 1)
b = onp.tile(x, 3)
c = onp.tile(x, [5, 1])
d = onp.tile(x, [5, 1, 2, 3, 7])
return a, b, c, d


@pytest.mark.level1
@@ -901,8 +877,6 @@ def onp_tile(x):
def test_tile():
a = rand_int(2, 3, 4)
match_res(mnp_tile, onp_tile, a)
b = rand_int(5, 0, 8)
match_res(mnp_tile, onp_tile, b)


def mnp_broadcast_to(x):
@@ -1022,21 +996,13 @@ def test_fliplr():
def mnp_split(input_tensor):
a = mnp.split(input_tensor, indices_or_sections=1)
b = mnp.split(input_tensor, indices_or_sections=3)
c = mnp.split(input_tensor, indices_or_sections=(-9, -8, 6))
d = mnp.split(input_tensor, indices_or_sections=(3, 2, 1))
e = mnp.split(input_tensor, indices_or_sections=(-10, -4, 5, 10))
f = mnp.split(input_tensor, indices_or_sections=[0, 2], axis=1)
return a, b, c, d, e, f
return a, b


def onp_split(input_array):
a = onp.split(input_array, indices_or_sections=1)
b = onp.split(input_array, indices_or_sections=3)
c = onp.split(input_array, indices_or_sections=(-9, -8, 6))
d = onp.split(input_array, indices_or_sections=(3, 2, 1))
e = onp.split(input_array, indices_or_sections=(-10, -4, 5, 10))
f = onp.split(input_array, indices_or_sections=[0, 2], axis=1)
return a, b, c, d, e, f
return a, b


@pytest.mark.level1
@@ -1090,16 +1056,12 @@ def test_array_split():

def mnp_vsplit(input_tensor):
a = mnp.vsplit(input_tensor, indices_or_sections=3)
b = mnp.vsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10))
c = mnp.vsplit(input_tensor, indices_or_sections=[0, 2])
return a, b, c
return a


def onp_vsplit(input_array):
a = onp.vsplit(input_array, indices_or_sections=3)
b = onp.vsplit(input_array, indices_or_sections=(-10, -4, 5, 10))
c = onp.vsplit(input_array, indices_or_sections=[0, 2])
return a, b, c
return a


@pytest.mark.level1
@@ -1123,16 +1085,12 @@ def test_vsplit():

def mnp_hsplit(input_tensor):
a = mnp.hsplit(input_tensor, indices_or_sections=3)
b = mnp.hsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10))
c = mnp.hsplit(input_tensor, indices_or_sections=[0, 2])
return a, b, c
return a


def onp_hsplit(input_array):
a = onp.hsplit(input_array, indices_or_sections=3)
b = onp.hsplit(input_array, indices_or_sections=(-10, -4, 5, 10))
c = onp.hsplit(input_array, indices_or_sections=[0, 2])
return a, b, c
return a


@pytest.mark.level1
@@ -1156,17 +1114,11 @@ def test_hsplit():

def mnp_dsplit(input_tensor):
a = mnp.dsplit(input_tensor, indices_or_sections=3)
b = mnp.dsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10))
c = mnp.dsplit(input_tensor, indices_or_sections=[0, 2])
return a, b, c

return a

def onp_dsplit(input_array):
a = onp.dsplit(input_array, indices_or_sections=3)
b = onp.dsplit(input_array, indices_or_sections=(-10, -4, 5, 10))
c = onp.dsplit(input_array, indices_or_sections=[0, 2])
return a, b, c

return a

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training


+ 23
- 26
tests/st/numpy_native/test_math_ops.py View File

@@ -37,13 +37,6 @@ class Cases():
rand_int(1, 1),
]

# empty arrays
self.empty_arrs = [
rand_int(0),
rand_int(4, 0),
rand_int(2, 0, 2),
]

# arrays of the same size expanded across the 0th dimension
self.expanded_arrs = [
rand_int(2, 3),
@@ -244,8 +237,6 @@ def test_float_power():


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -687,11 +678,11 @@ def test_ptp():


def mnp_add_dtype(x1, x2):
return mnp.add(x1, x2, dtype=mnp.float16)
return mnp.add(x1, x2, dtype=mnp.float32)


def onp_add_dtype(x1, x2):
return onp.add(x1, x2, dtype=onp.float16)
return onp.add(x1, x2, dtype=onp.float32)


@pytest.mark.level1
@@ -927,8 +918,6 @@ def onp_maximum(x1, x2):


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -1410,24 +1399,22 @@ def mnp_diff(input_tensor):
a = mnp.diff(input_tensor, 2, append=3.0)
b = mnp.diff(input_tensor, 4, prepend=6, axis=-2)
c = mnp.diff(input_tensor, 0, append=3.0, axis=-1)
d = mnp.diff(input_tensor, 10, prepend=6)
e = mnp.diff(input_tensor, 1, prepend=input_tensor)
f = mnp.ediff1d(input_tensor, to_end=input_tensor)
g = mnp.ediff1d(input_tensor)
h = mnp.ediff1d(input_tensor, to_begin=3)
return a, b, c, d, e, f, g, h
d = mnp.diff(input_tensor, 1, prepend=input_tensor)
e = mnp.ediff1d(input_tensor, to_end=input_tensor)
f = mnp.ediff1d(input_tensor)
g = mnp.ediff1d(input_tensor, to_begin=3)
return a, b, c, d, e, f, g


def onp_diff(input_array):
a = onp.diff(input_array, 2, append=3.0)
b = onp.diff(input_array, 4, prepend=6, axis=-2)
c = onp.diff(input_array, 0, append=3.0, axis=-1)
d = onp.diff(input_array, 10, prepend=6)
e = onp.diff(input_array, 1, prepend=input_array)
f = onp.ediff1d(input_array, to_end=input_array)
g = onp.ediff1d(input_array)
h = onp.ediff1d(input_array, to_begin=3)
return a, b, c, d, e, f, g, h
d = onp.diff(input_array, 1, prepend=input_array)
e = onp.ediff1d(input_array, to_end=input_array)
f = onp.ediff1d(input_array)
g = onp.ediff1d(input_array, to_begin=3)
return a, b, c, d, e, f, g


@pytest.mark.level1
@@ -1926,7 +1913,6 @@ def test_mean():
run_multi_test(mnp_mean, onp_mean, test_case.arrs, error=3)
run_multi_test(mnp_mean, onp_mean, test_case.expanded_arrs, error=3)
run_multi_test(mnp_mean, onp_mean, test_case.scalars, error=3)
run_multi_test(mnp_mean, onp_mean, test_case.empty_arrs, error=3)


@pytest.mark.level1
@@ -1961,3 +1947,14 @@ def test_exception_add():
def test_exception_mean():
with pytest.raises(ValueError):
mnp.mean(to_tensor(test_case.arrs[0]), (-1, 0))


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_exception_amax():
with pytest.raises(TypeError):
mnp.amax(mnp.array([[1, 2], [3, 4]]).astype(mnp.float32), initial=[1.0, 2.0])

Loading…
Cancel
Save