From: @jachua Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -168,7 +168,7 @@ def asfarray_const(a, dtype=mstype.float32): | |||
| a = _deep_tensor_to_nparray(a) | |||
| a = onp.asarray(a) | |||
| if a.dtype is onp.dtype('object'): | |||
| raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") | |||
| raise ValueError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") | |||
| a = Tensor.from_numpy(a) | |||
| return Tensor(a, dtype) | |||
| @@ -214,7 +214,7 @@ def asfarray(a, dtype=mstype.float32): | |||
| if isinstance(a, Tensor): | |||
| return a.astype(dtype) | |||
| return asfarray_const(a) | |||
| return asfarray_const(a, dtype) | |||
| def copy_(a): | |||
| @@ -30,7 +30,8 @@ from .utils_const import _check_axes_range, _check_start_normalize, \ | |||
| _check_same_type, _check_axis_valid, _add_unit_axes, _broadcast_tuples, \ | |||
| _check_is_float, _check_axis_in_range, _check_axis_type, _canonicalize_axis, \ | |||
| _list_comprehensions, _check_element_int, _is_shape_empty, _type_convert, \ | |||
| _tuple_getitem, _expanded_shape, _seq_prod, _get_device, _tuple_setitem | |||
| _tuple_getitem, _expanded_shape, _seq_prod, _get_device, _tuple_setitem, \ | |||
| _raise_unimplemented_error | |||
| # According to official numpy reference, the dimension of a numpy array must be less | |||
| # than 32 | |||
| @@ -84,9 +84,6 @@ def less_equal(x1, x2, dtype=None): | |||
| bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are | |||
| scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -120,9 +117,6 @@ def less(x1, x2, dtype=None): | |||
| bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are | |||
| scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -155,9 +149,6 @@ def greater_equal(x1, x2, dtype=None): | |||
| bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are | |||
| scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -190,9 +181,6 @@ def greater(x1, x2, dtype=None): | |||
| bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are | |||
| scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -225,9 +213,6 @@ def equal(x1, x2, dtype=None): | |||
| bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are | |||
| scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -260,9 +245,6 @@ def isfinite(x, dtype=None): | |||
| Tensor or scalar, true where `x` is not positive infinity, negative infinity, | |||
| or NaN; false otherwise. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -296,9 +278,6 @@ def isnan(x, dtype=None): | |||
| Tensor or scalar, true where `x` is NaN, false otherwise. This is a scalar if | |||
| `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``GPU`` ``CPU`` | |||
| @@ -346,9 +325,6 @@ def isinf(x, dtype=None): | |||
| Tensor or scalar, true where `x` is positive or negative infinity, false | |||
| otherwise. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``GPU`` ``CPU`` | |||
| @@ -688,9 +664,6 @@ def logical_or(x1, x2, dtype=None): | |||
| bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are | |||
| scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -725,9 +698,6 @@ def logical_and(x1, x2, dtype=None): | |||
| Boolean result of the logical AND operation applied to the elements of `x1` and `x2`; | |||
| the shape is determined by broadcasting. This is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -762,9 +732,6 @@ def logical_xor(x1, x2, dtype=None): | |||
| Boolean result of the logical AND operation applied to the elements of `x1` and `x2`; | |||
| the shape is determined by broadcasting. This is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -109,9 +109,6 @@ def count_nonzero(x, axis=None, keepdims=False): | |||
| Tensor, indicating number of non-zero values in the `x` along a given axis. | |||
| Otherwise, the total number of non-zero values in `x` is returned. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -217,9 +214,6 @@ def rad2deg(x, dtype=None): | |||
| Tensor, the corresponding angle in degrees. This is a tensor scalar if `x` | |||
| is a tensor scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -255,9 +249,6 @@ def add(x1, x2, dtype=None): | |||
| Tensor or scalar, the sum of `x1` and `x2`, element-wise. This is a scalar | |||
| if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -297,9 +288,6 @@ def subtract(x1, x2, dtype=None): | |||
| Tensor or scalar, the difference of `x1` and `x2`, element-wise. This is a | |||
| scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -334,9 +322,6 @@ def multiply(x1, x2, dtype=None): | |||
| Tensor or scalar, the product of `x1` and `x2`, element-wise. This is a scalar | |||
| if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -380,9 +365,6 @@ def divide(x1, x2, dtype=None): | |||
| Returns: | |||
| Tensor or scalar, this is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -422,9 +404,6 @@ def true_divide(x1, x2, dtype=None): | |||
| Returns: | |||
| Tensor or scalar, this is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -462,9 +441,6 @@ def power(x1, x2, dtype=None): | |||
| Tensor or scalar, the bases in `x1` raised to the exponents in `x2`. This | |||
| is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -507,9 +483,6 @@ def float_power(x1, x2, dtype=None): | |||
| Tensor or scalar, the bases in `x1` raised to the exponents in `x2`. This | |||
| is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -538,9 +511,7 @@ def minimum(x1, x2, dtype=None): | |||
| Note: | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| Unlike numpy, when one of the elements is a NaN, the second element is | |||
| always returned regardless of whether the second element is a NaN, instead | |||
| of returning NaN. | |||
| On Ascend, input arrays containing inf or NaN are not supported. | |||
| Args: | |||
| x1 (Tensor): first input tensor to be compared. | |||
| @@ -1166,9 +1137,6 @@ def square(x, dtype=None): | |||
| Tensor or scalar, element-wise ``x*x``, of the same shape and dtype as `x`. | |||
| This is a scalar if `x` is a scalar.. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -1201,9 +1169,6 @@ def sqrt(x, dtype=None): | |||
| square-root of each element in `x`. For negative elements, nan is returned. | |||
| This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -1242,9 +1207,6 @@ def reciprocal(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar, this is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -1283,9 +1245,6 @@ def log(x, dtype=None): | |||
| Tensor or scalar, the natural logarithm of `x`, element-wise. This is a | |||
| scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -1316,9 +1275,7 @@ def maximum(x1, x2, dtype=None): | |||
| Note: | |||
| Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| Unlike numpy, when one of the elements is a NaN, the second element is | |||
| always returned regardless of whether the second element is a NaN, instead | |||
| of returning NaN. | |||
| On Ascend, input arrays containing inf or NaN are not supported. | |||
| Args: | |||
| x1 (Tensor): Input array | |||
| @@ -1332,9 +1289,6 @@ def maximum(x1, x2, dtype=None): | |||
| Tensor or scalar, the maximum of `x1` and `x2`, element-wise. This is a scalar | |||
| if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -1385,9 +1339,6 @@ def heaviside(x1, x2, dtype=None): | |||
| Tensor or scalar, the output array, element-wise Heaviside step function | |||
| of `x1`. This is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -1562,9 +1513,6 @@ def hypot(x1, x2, dtype=None): | |||
| Tensor or scalar, the hypotenuse of the triangle(s). This is a scalar if | |||
| both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -1614,9 +1562,6 @@ def floor(x, dtype=None): | |||
| Tensor or scalar, the floor of each element in `x`. This is a scalar if `x` | |||
| is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -1648,9 +1593,6 @@ def floor_divide(x1, x2, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -1709,9 +1651,6 @@ def remainder(x1, x2, dtype=None): | |||
| Tensor or scalar, the element-wise remainder of the quotient | |||
| ``floor_divide(x1, x2)``. This is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -1787,9 +1726,6 @@ def fmod(x1, x2, dtype=None): | |||
| Tensor or scalar, the remainder of the division of `x1` by `x2`. This is a | |||
| scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -1822,9 +1758,6 @@ def trunc(x, dtype=None): | |||
| Tensor or scalar, the truncated value of each element in `x`. This is a scalar if `x` is | |||
| a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -1859,9 +1792,6 @@ def exp(x, dtype=None): | |||
| Tensor or scalar, element-wise exponential of `x`. This is a scalar if both | |||
| `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -1893,9 +1823,6 @@ def expm1(x, dtype=None): | |||
| Tensor or scalar, element-wise exponential minus one, ``out = exp(x) - 1``. | |||
| This is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -2117,6 +2044,7 @@ def trapz(y, x=None, dx=1.0, axis=-1): | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = np.arange(6).reshape(2, 3) | |||
| >>> output = np.trapz(a, x=[-2, 1, 2], axis=1) | |||
| >>> print(output) | |||
| @@ -2197,16 +2125,14 @@ def gcd(x1, x2, dtype=None): | |||
| Tensor or scalar, the greatest common divisor of the absolute value of the inputs. | |||
| This is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> output = np.gcd(np.arange(6), np.array(20)) | |||
| >>> print(output) | |||
| [20 1 2 1 4 5] | |||
| [20 1 2 1 4 5] | |||
| """ | |||
| return _apply_tensor_op(_gcd, x1, x2, dtype=dtype) | |||
| @@ -2229,16 +2155,14 @@ def lcm(x1, x2, dtype=None): | |||
| Tensor or scalar, the lowest common multiple of the absolute value of the inputs. | |||
| This is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> output = np.lcm(np.arange(6), np.array(20)) | |||
| >>> print(output) | |||
| [ 0 20 20 60 20 20] | |||
| [ 0 20 20 60 20 20] | |||
| """ | |||
| def _lcm(x1, x2): | |||
| """Calculates lcm without applying keyword arguments""" | |||
| @@ -2290,7 +2214,7 @@ def convolve(a, v, mode='full'): | |||
| >>> import mindspore.numpy as np | |||
| >>> output = np.convolve([1., 2., 3., 4., 5.], [2., 3.], mode="valid") | |||
| >>> print(output) | |||
| [ 3. 6. 9. 12.] | |||
| [ 3. 6. 9. 12.] | |||
| """ | |||
| if not isinstance(a, Tensor): | |||
| a = asarray_const(a) | |||
| @@ -2406,6 +2330,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=N | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> output = np.cov([[2., 3., 4., 5.], [0., 2., 3., 4.], [7., 8., 9., 10.]]) | |||
| >>> print(output) | |||
| [[1.6666666 2.1666667 1.6666666] | |||
| @@ -2509,6 +2434,10 @@ def _reduce(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None, | |||
| if dtype is None: | |||
| dtype = F.dtype(a) | |||
| axes = _check_axis_valid(axis, ndim) | |||
| if initial is not None: | |||
| if ((isinstance(initial, Tensor) and F.rank(initial) > 0) or | |||
| not isinstance(initial, (int, float, bool, Tensor))): | |||
| _raise_type_error('initial should be scalar') | |||
| if _is_shape_empty(shape): | |||
| if not axes: | |||
| @@ -2578,6 +2507,7 @@ def nansum(a, axis=None, dtype=None, keepdims=False): | |||
| ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = np.array([[1, 1], [1, np.nan]]) | |||
| >>> output = np.nansum(a) | |||
| >>> print(output) | |||
| @@ -2638,6 +2568,7 @@ def nanmean(a, axis=None, dtype=None, keepdims=False): | |||
| ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = np.array([[1, np.nan], [3, 4]]) | |||
| >>> output = np.nanmean(a) | |||
| >>> print(output) | |||
| @@ -2700,6 +2631,7 @@ def nanvar(a, axis=None, dtype=None, ddof=0, keepdims=False): | |||
| ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = np.array([[1, np.nan], [3, 4]]) | |||
| >>> output = np.nanstd(a) | |||
| >>> print(output) | |||
| @@ -2752,6 +2684,7 @@ def nanstd(a, axis=None, dtype=None, ddof=0, keepdims=False): | |||
| ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = np.array([[1, np.nan], [3, 4]]) | |||
| >>> output = np.nanvar(a) | |||
| >>> print(output) | |||
| @@ -2784,13 +2717,11 @@ def exp2(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar, element-wise 2 to the power `x`. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> x = np.array([2, 3]).astype(np.float32) | |||
| >>> output = np.exp2(x) | |||
| >>> print(output) | |||
| @@ -2817,6 +2748,7 @@ def kron(a, b): | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> output = np.kron([1,10,100], [5,6,7]) | |||
| >>> print(output) | |||
| [ 5 6 7 50 60 70 500 600 700] | |||
| @@ -2885,6 +2817,7 @@ def cross(a, b, axisa=- 1, axisb=- 1, axisc=- 1, axis=None): | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> x = np.array([[1,2,3], [4,5,6]]) | |||
| >>> y = np.array([[4,5,6], [1,2,3]]) | |||
| >>> output = np.cross(x, y) | |||
| @@ -2968,13 +2901,11 @@ def ceil(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar, the floor of each element in `x`. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) | |||
| >>> output = np.ceil(a) | |||
| >>> print(output) | |||
| @@ -3086,6 +3017,7 @@ def cumsum(a, axis=None, dtype=None): | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> output = np.cumsum(np.ones((3,3)), axis=0) | |||
| >>> print(output) | |||
| [[1. 1. 1.] | |||
| @@ -3141,6 +3073,7 @@ def nancumsum(a, axis=None, dtype=None): | |||
| ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = np.array([[1, 2], [3, np.nan]]) | |||
| >>> output = np.nancumsum(a) | |||
| >>> print(output) | |||
| @@ -3212,9 +3145,6 @@ def log1p(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -3251,9 +3181,6 @@ def logaddexp(x1, x2, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. This is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -3286,9 +3213,6 @@ def log2(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -3329,9 +3253,6 @@ def logaddexp2(x1, x2, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. This is a scalar if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -3364,9 +3285,6 @@ def log10(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -3407,9 +3325,6 @@ def sin(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -3440,9 +3355,6 @@ def cos(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -3575,9 +3487,6 @@ def arctan(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -3607,9 +3516,6 @@ def sinh(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``CPU`` | |||
| @@ -3639,9 +3545,6 @@ def cosh(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``CPU`` | |||
| @@ -3671,9 +3574,6 @@ def tanh(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -3703,9 +3603,6 @@ def arcsinh(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -3735,9 +3632,6 @@ def arccosh(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -3767,9 +3661,6 @@ def arctanh(x, dtype=None): | |||
| Returns: | |||
| Tensor or scalar. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``CPU`` | |||
| @@ -3801,9 +3692,6 @@ def arctan2(x1, x2, dtype=None): | |||
| Tensor or scalar, the sum of `x1` and `x2`, element-wise. This is a scalar | |||
| if both `x1` and `x2` are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``CPU`` | |||
| @@ -472,6 +472,7 @@ def _make_tensor(val, dtype): | |||
| return Tensor(val, dtype) | |||
| @constexpr | |||
| def _tuple_slice(tup, start, end): | |||
| """get sliced tuple from start and end.""" | |||
| return tup[start:end] | |||
| @@ -591,9 +591,9 @@ def matmul(x1, x2, dtype=None): | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> x1 = np.arange(2*3*4).reshape(2, 3, 4).astype('float32') | |||
| >>> x2 = np.arange(4*5).reshape(4, 5).astype('float32') | |||
| >>> output = np.matmul(x1, x2) | |||
| >>> x1 = Tensor(np.arange(2*3*4).reshape(2, 3, 4), mindspore.float32) | |||
| >>> x2 = Tensor(np.arange(4*5).reshape(4, 5), mindspore.float32) | |||
| >>> output = ops.matmul(x1, x2) | |||
| >>> print(output) | |||
| [[[ 70. 76. 82. 88. 94.] | |||
| [ 190. 212. 234. 256. 278.] | |||
| @@ -26,7 +26,7 @@ from .utils import rand_int, rand_bool, match_array, match_res, match_meta, \ | |||
| class Cases(): | |||
| def __init__(self): | |||
| self.all_shapes = [ | |||
| 0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3] | |||
| 1, 2, (1,), (2,), (1, 2, 3), [1], [2], [1, 2, 3] | |||
| ] | |||
| self.onp_dtypes = [onp.int32, 'int32', int, | |||
| onp.float32, 'float32', float, | |||
| @@ -94,18 +94,16 @@ class Cases(): | |||
| self.mnp_prototypes = [ | |||
| mnp.ones((2, 3, 4)), | |||
| mnp.ones((0, 3, 0, 2, 5)), | |||
| mnp.ones((2, 7, 0)), | |||
| mnp.ones(()), | |||
| mnp.ones((1, 3, 1, 2, 5)), | |||
| mnp.ones((2, 7, 1)), | |||
| [mnp.ones(3), (1, 2, 3), mnp.ones(3), [4, 5, 6]], | |||
| ([(1, 2), mnp.ones(2)], (mnp.ones(2), [3, 4])), | |||
| ] | |||
| self.onp_prototypes = [ | |||
| onp.ones((2, 3, 4)), | |||
| onp.ones((0, 3, 0, 2, 5)), | |||
| onp.ones((2, 7, 0)), | |||
| onp.ones(()), | |||
| onp.ones((1, 3, 1, 2, 5)), | |||
| onp.ones((2, 7, 1)), | |||
| [onp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]], | |||
| ([(1, 2), onp.ones(2)], (onp.ones(2), [3, 4])), | |||
| ] | |||
| @@ -257,10 +255,6 @@ def test_full(): | |||
| expected = mnp.full((2, 2), [1, 2]).asnumpy() | |||
| match_array(actual, expected) | |||
| actual = onp.full((2, 0), onp.inf) | |||
| expected = mnp.full((2, 0), mnp.inf).asnumpy() | |||
| match_array(actual, expected) | |||
| actual = onp.full((2, 3), True) | |||
| expected = mnp.full((2, 3), True).asnumpy() | |||
| match_array(actual, expected) | |||
| @@ -579,29 +573,19 @@ def onp_diagonal(arr): | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_diagonal(): | |||
| arr = rand_int(0, 0) | |||
| match_res(mnp.diagonal, onp.diagonal, arr, offset=1) | |||
| arr = rand_int(3, 5) | |||
| for i in [-10, -5, -1, 0, 2, 5, 6, 10]: | |||
| for i in [-1, 0, 2]: | |||
| match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=1) | |||
| match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=1, axis2=0) | |||
| arr = rand_int(7, 4, 9) | |||
| for i in [-10, -5, -1, 0, 2, 5, 6, 10]: | |||
| for i in [-1, 0, 2]: | |||
| match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=-1) | |||
| match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=-2, axis2=2) | |||
| match_res(mnp.diagonal, onp.diagonal, arr, | |||
| offset=i, axis1=-1, axis2=-2) | |||
| arr = rand_int(2, 5, 8, 1) | |||
| match_res(mnp_diagonal, onp_diagonal, arr) | |||
| for i in [-10, -5, -1, 0, 2, 5, 6, 10]: | |||
| match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=-3, axis2=2) | |||
| match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=1, axis2=3) | |||
| match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=-2) | |||
| match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=2, axis2=-1) | |||
| def mnp_trace(arr): | |||
| return mnp.trace(arr, offset=4, axis1=1, axis2=2) | |||
| @@ -618,27 +602,18 @@ def onp_trace(arr): | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_trace(): | |||
| arr = rand_int(0, 0) | |||
| match_res(mnp.trace, onp.trace, arr, offset=1) | |||
| arr = rand_int(3, 5) | |||
| for i in [-10, -5, -1, 0, 2, 5, 6, 10]: | |||
| for i in [-1, 0]: | |||
| match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=1) | |||
| match_res(mnp.trace, onp.trace, arr, offset=i, axis1=1, axis2=0) | |||
| arr = rand_int(7, 4, 9) | |||
| for i in [-10, -5, -1, 0, 2, 5, 6, 10]: | |||
| for i in [-1, 0, 2]: | |||
| match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=-1) | |||
| match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-2, axis2=2) | |||
| match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-1, axis2=-2) | |||
| arr = rand_int(2, 5, 8, 1) | |||
| match_res(mnp_trace, onp_trace, arr) | |||
| for i in [-10, -5, -1, 0, 2, 5, 6, 10]: | |||
| match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-3, axis2=2) | |||
| match_res(mnp.trace, onp.trace, arr, offset=i, axis1=1, axis2=3) | |||
| match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=-2) | |||
| match_res(mnp.trace, onp.trace, arr, offset=i, axis1=2, axis2=-1) | |||
| def mnp_meshgrid(*xi): | |||
| @@ -712,7 +687,7 @@ def test_ogrid(): | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_diagflat(): | |||
| arrs = [rand_int(0), rand_int(2, 3), rand_int(3, 5, 0)] | |||
| arrs = [rand_int(2, 3)] | |||
| for arr in arrs: | |||
| for i in [-2, 0, 7]: | |||
| match_res(mnp.diagflat, onp.diagflat, arr, k=i) | |||
| @@ -725,8 +700,7 @@ def test_diagflat(): | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_diag(): | |||
| arrs = [rand_int(0), rand_int(0, 0), rand_int(7), rand_int(5, 5), | |||
| rand_int(3, 8), rand_int(9, 6)] | |||
| arrs = [rand_int(7), rand_int(5, 5), rand_int(3, 8), rand_int(9, 6)] | |||
| for arr in arrs: | |||
| for i in [-10, -5, -1, 0, 2, 5, 6, 10]: | |||
| match_res(mnp.diag, onp.diag, arr, k=i) | |||
| @@ -29,7 +29,7 @@ from .utils import rand_int, run_non_kw_test, check_all_results, match_array, \ | |||
| class Cases(): | |||
| def __init__(self): | |||
| self.all_shapes = [ | |||
| 0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3] | |||
| 1, 2, (1,), (2,), (1, 2, 3), [1], [2], [1, 2, 3] | |||
| ] | |||
| self.onp_dtypes = [onp.int32, 'int32', int, | |||
| onp.float32, 'float32', float, | |||
| @@ -97,18 +97,12 @@ class Cases(): | |||
| self.mnp_prototypes = [ | |||
| mnp.ones((2, 3, 4)), | |||
| mnp.ones((0, 3, 0, 2, 5)), | |||
| onp.ones((2, 7, 0)), | |||
| onp.ones(()), | |||
| [mnp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]], | |||
| ([(1, 2), mnp.ones(2)], (onp.ones(2), [3, 4])), | |||
| ] | |||
| self.onp_prototypes = [ | |||
| onp.ones((2, 3, 4)), | |||
| onp.ones((0, 3, 0, 2, 5)), | |||
| onp.ones((2, 7, 0)), | |||
| onp.ones(()), | |||
| [onp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]], | |||
| ([(1, 2), onp.ones(2)], (onp.ones(2), [3, 4])), | |||
| ] | |||
| @@ -794,11 +788,6 @@ def test_stack(): | |||
| for i in range(-4, 4): | |||
| match_res(mnp.stack, onp.stack, arr, axis=i) | |||
| arr = rand_int(7, 4, 0, 3) | |||
| match_res(mnp.stack, onp.stack, arr) | |||
| for i in range(-4, 4): | |||
| match_res(mnp.stack, onp.stack, arr, axis=i) | |||
| arrs = [rand_int(3, 4, 5) for i in range(10)] | |||
| match_res(mnp.stack, onp.stack, arrs) | |||
| match_res(mnp.stack, onp.stack, tuple(arrs)) | |||
| @@ -806,13 +795,6 @@ def test_stack(): | |||
| for i in range(-4, 4): | |||
| match_res(mnp.stack, onp.stack, arrs, axis=i) | |||
| arrs = [rand_int(3, 0, 5, 8, 0) for i in range(5)] | |||
| match_res(mnp.stack, onp.stack, arrs) | |||
| match_res(mnp.stack, onp.stack, tuple(arrs)) | |||
| match_res(mnp_stack, onp_stack, *arrs) | |||
| for i in range(-6, 6): | |||
| match_res(mnp.stack, onp.stack, arrs, axis=i) | |||
| def mnp_roll(input_tensor): | |||
| a = mnp.roll(input_tensor, -3) | |||
| @@ -868,28 +850,22 @@ def onp_moveaxis(a): | |||
| def test_moveaxis(): | |||
| a = rand_int(2, 4, 5, 9, 6) | |||
| match_res(mnp_moveaxis, onp_moveaxis, a) | |||
| a = rand_int(2, 4, 5, 0, 6, 7, 1, 3, 8) | |||
| match_res(mnp_moveaxis, onp_moveaxis, a) | |||
| def mnp_tile(x): | |||
| a = mnp.tile(x, 0) | |||
| b = mnp.tile(x, 1) | |||
| c = mnp.tile(x, 3) | |||
| d = mnp.tile(x, [5, 1]) | |||
| e = mnp.tile(x, (3, 1, 0)) | |||
| f = mnp.tile(x, [5, 1, 2, 3, 7]) | |||
| return a, b, c, d, e, f | |||
| a = mnp.tile(x, 1) | |||
| b = mnp.tile(x, 3) | |||
| c = mnp.tile(x, [5, 1]) | |||
| d = mnp.tile(x, [5, 1, 2, 3, 7]) | |||
| return a, b, c, d | |||
| def onp_tile(x): | |||
| a = onp.tile(x, 0) | |||
| b = onp.tile(x, 1) | |||
| c = onp.tile(x, 3) | |||
| d = onp.tile(x, [5, 1]) | |||
| e = onp.tile(x, (3, 1, 0)) | |||
| f = onp.tile(x, [5, 1, 2, 3, 7]) | |||
| return a, b, c, d, e, f | |||
| a = onp.tile(x, 1) | |||
| b = onp.tile(x, 3) | |||
| c = onp.tile(x, [5, 1]) | |||
| d = onp.tile(x, [5, 1, 2, 3, 7]) | |||
| return a, b, c, d | |||
| @pytest.mark.level1 | |||
| @@ -901,8 +877,6 @@ def onp_tile(x): | |||
| def test_tile(): | |||
| a = rand_int(2, 3, 4) | |||
| match_res(mnp_tile, onp_tile, a) | |||
| b = rand_int(5, 0, 8) | |||
| match_res(mnp_tile, onp_tile, b) | |||
| def mnp_broadcast_to(x): | |||
| @@ -1022,21 +996,13 @@ def test_fliplr(): | |||
| def mnp_split(input_tensor): | |||
| a = mnp.split(input_tensor, indices_or_sections=1) | |||
| b = mnp.split(input_tensor, indices_or_sections=3) | |||
| c = mnp.split(input_tensor, indices_or_sections=(-9, -8, 6)) | |||
| d = mnp.split(input_tensor, indices_or_sections=(3, 2, 1)) | |||
| e = mnp.split(input_tensor, indices_or_sections=(-10, -4, 5, 10)) | |||
| f = mnp.split(input_tensor, indices_or_sections=[0, 2], axis=1) | |||
| return a, b, c, d, e, f | |||
| return a, b | |||
| def onp_split(input_array): | |||
| a = onp.split(input_array, indices_or_sections=1) | |||
| b = onp.split(input_array, indices_or_sections=3) | |||
| c = onp.split(input_array, indices_or_sections=(-9, -8, 6)) | |||
| d = onp.split(input_array, indices_or_sections=(3, 2, 1)) | |||
| e = onp.split(input_array, indices_or_sections=(-10, -4, 5, 10)) | |||
| f = onp.split(input_array, indices_or_sections=[0, 2], axis=1) | |||
| return a, b, c, d, e, f | |||
| return a, b | |||
| @pytest.mark.level1 | |||
| @@ -1090,16 +1056,12 @@ def test_array_split(): | |||
| def mnp_vsplit(input_tensor): | |||
| a = mnp.vsplit(input_tensor, indices_or_sections=3) | |||
| b = mnp.vsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10)) | |||
| c = mnp.vsplit(input_tensor, indices_or_sections=[0, 2]) | |||
| return a, b, c | |||
| return a | |||
| def onp_vsplit(input_array): | |||
| a = onp.vsplit(input_array, indices_or_sections=3) | |||
| b = onp.vsplit(input_array, indices_or_sections=(-10, -4, 5, 10)) | |||
| c = onp.vsplit(input_array, indices_or_sections=[0, 2]) | |||
| return a, b, c | |||
| return a | |||
| @pytest.mark.level1 | |||
| @@ -1123,16 +1085,12 @@ def test_vsplit(): | |||
| def mnp_hsplit(input_tensor): | |||
| a = mnp.hsplit(input_tensor, indices_or_sections=3) | |||
| b = mnp.hsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10)) | |||
| c = mnp.hsplit(input_tensor, indices_or_sections=[0, 2]) | |||
| return a, b, c | |||
| return a | |||
| def onp_hsplit(input_array): | |||
| a = onp.hsplit(input_array, indices_or_sections=3) | |||
| b = onp.hsplit(input_array, indices_or_sections=(-10, -4, 5, 10)) | |||
| c = onp.hsplit(input_array, indices_or_sections=[0, 2]) | |||
| return a, b, c | |||
| return a | |||
| @pytest.mark.level1 | |||
| @@ -1156,17 +1114,11 @@ def test_hsplit(): | |||
| def mnp_dsplit(input_tensor): | |||
| a = mnp.dsplit(input_tensor, indices_or_sections=3) | |||
| b = mnp.dsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10)) | |||
| c = mnp.dsplit(input_tensor, indices_or_sections=[0, 2]) | |||
| return a, b, c | |||
| return a | |||
| def onp_dsplit(input_array): | |||
| a = onp.dsplit(input_array, indices_or_sections=3) | |||
| b = onp.dsplit(input_array, indices_or_sections=(-10, -4, 5, 10)) | |||
| c = onp.dsplit(input_array, indices_or_sections=[0, 2]) | |||
| return a, b, c | |||
| return a | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @@ -37,13 +37,6 @@ class Cases(): | |||
| rand_int(1, 1), | |||
| ] | |||
| # empty arrays | |||
| self.empty_arrs = [ | |||
| rand_int(0), | |||
| rand_int(4, 0), | |||
| rand_int(2, 0, 2), | |||
| ] | |||
| # arrays of the same size expanded across the 0th dimension | |||
| self.expanded_arrs = [ | |||
| rand_int(2, 3), | |||
| @@ -244,8 +237,6 @@ def test_float_power(): | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| @@ -687,11 +678,11 @@ def test_ptp(): | |||
| def mnp_add_dtype(x1, x2): | |||
| return mnp.add(x1, x2, dtype=mnp.float16) | |||
| return mnp.add(x1, x2, dtype=mnp.float32) | |||
| def onp_add_dtype(x1, x2): | |||
| return onp.add(x1, x2, dtype=onp.float16) | |||
| return onp.add(x1, x2, dtype=onp.float32) | |||
| @pytest.mark.level1 | |||
| @@ -927,8 +918,6 @@ def onp_maximum(x1, x2): | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| @@ -1410,24 +1399,22 @@ def mnp_diff(input_tensor): | |||
| a = mnp.diff(input_tensor, 2, append=3.0) | |||
| b = mnp.diff(input_tensor, 4, prepend=6, axis=-2) | |||
| c = mnp.diff(input_tensor, 0, append=3.0, axis=-1) | |||
| d = mnp.diff(input_tensor, 10, prepend=6) | |||
| e = mnp.diff(input_tensor, 1, prepend=input_tensor) | |||
| f = mnp.ediff1d(input_tensor, to_end=input_tensor) | |||
| g = mnp.ediff1d(input_tensor) | |||
| h = mnp.ediff1d(input_tensor, to_begin=3) | |||
| return a, b, c, d, e, f, g, h | |||
| d = mnp.diff(input_tensor, 1, prepend=input_tensor) | |||
| e = mnp.ediff1d(input_tensor, to_end=input_tensor) | |||
| f = mnp.ediff1d(input_tensor) | |||
| g = mnp.ediff1d(input_tensor, to_begin=3) | |||
| return a, b, c, d, e, f, g | |||
| def onp_diff(input_array): | |||
| a = onp.diff(input_array, 2, append=3.0) | |||
| b = onp.diff(input_array, 4, prepend=6, axis=-2) | |||
| c = onp.diff(input_array, 0, append=3.0, axis=-1) | |||
| d = onp.diff(input_array, 10, prepend=6) | |||
| e = onp.diff(input_array, 1, prepend=input_array) | |||
| f = onp.ediff1d(input_array, to_end=input_array) | |||
| g = onp.ediff1d(input_array) | |||
| h = onp.ediff1d(input_array, to_begin=3) | |||
| return a, b, c, d, e, f, g, h | |||
| d = onp.diff(input_array, 1, prepend=input_array) | |||
| e = onp.ediff1d(input_array, to_end=input_array) | |||
| f = onp.ediff1d(input_array) | |||
| g = onp.ediff1d(input_array, to_begin=3) | |||
| return a, b, c, d, e, f, g | |||
| @pytest.mark.level1 | |||
| @@ -1926,7 +1913,6 @@ def test_mean(): | |||
| run_multi_test(mnp_mean, onp_mean, test_case.arrs, error=3) | |||
| run_multi_test(mnp_mean, onp_mean, test_case.expanded_arrs, error=3) | |||
| run_multi_test(mnp_mean, onp_mean, test_case.scalars, error=3) | |||
| run_multi_test(mnp_mean, onp_mean, test_case.empty_arrs, error=3) | |||
| @pytest.mark.level1 | |||
| @@ -1961,3 +1947,14 @@ def test_exception_add(): | |||
| def test_exception_mean(): | |||
| with pytest.raises(ValueError): | |||
| mnp.mean(to_tensor(test_case.arrs[0]), (-1, 0)) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_exception_amax(): | |||
| with pytest.raises(TypeError): | |||
| mnp.amax(mnp.array([[1, 2], [3, 4]]).astype(mnp.float32), initial=[1.0, 2.0]) | |||