| @@ -22,36 +22,59 @@ Note: | |||
| - array_ops.py defines all the array operation interfaces. | |||
| - array_creations.py defines all the array generation interfaces. | |||
| - math_ops.py defines all the math operations on tensors. | |||
| - logic_ops.py defines all the logical operations on tensors. | |||
| - dtypes.py defines all the mindspore.numpy dtypes (mainly redirected from mindspore) | |||
| """ | |||
| from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, reshape, | |||
| ravel, concatenate, where, atleast_1d, atleast_2d, atleast_3d, | |||
| column_stack, hstack, dstack, vstack, stack, unique) | |||
| column_stack, hstack, dstack, vstack, stack, unique, moveaxis, | |||
| tile, broadcast_to, broadcast_arrays, roll, append, split, vsplit, | |||
| flip, flipud, fliplr, hsplit, dsplit, take_along_axis, take, repeat) | |||
| from .array_creations import copy_ as copy | |||
| from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange, | |||
| linspace, logspace, eye, identity, empty, empty_like, | |||
| ones_like, zeros_like, full_like, diagonal, tril, triu, | |||
| tri, trace) | |||
| tri, trace, cumsum, meshgrid, mgrid, ogrid, diagflat, | |||
| diag, diag_indices, ix_) | |||
| from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16, | |||
| uint32, uint64, float_, float16, float32, float64, bool_, inf, | |||
| numeric_types) | |||
| from .math_ops import (mean, inner, add, subtract, multiply, divide, power, | |||
| dot, outer, tensordot, absolute) | |||
| uint32, uint64, float_, float16, float32, float64, bool_, inf, nan, | |||
| numeric_types, PINF, NINF) | |||
| from .math_ops import (mean, inner, add, subtract, multiply, divide, true_divide, power, | |||
| dot, outer, tensordot, absolute, std, var, average, minimum, | |||
| matmul, square, sqrt, reciprocal, log, maximum, heaviside, amax, amin, | |||
| hypot, float_power, floor, ptp, deg2rad, rad2deg, count_nonzero, | |||
| positive, negative, clip, floor_divide, remainder, fix, fmod, trunc, | |||
| exp, expm1) | |||
| from .logic_ops import (not_equal, less_equal, less, greater_equal, greater, equal, isfinite, | |||
| isnan, isinf, isposinf, isneginf, isscalar) | |||
| mod = remainder | |||
| fabs = absolute | |||
| array_ops_module = ['transpose', 'expand_dims', 'squeeze', 'rollaxis', 'swapaxes', 'reshape', | |||
| 'ravel', 'concatenate', 'where', 'atleast_1d', 'atleast_2d', 'atleast_3d', | |||
| 'column_stack', 'hstack', 'dstack', 'vstack', 'stack', 'unique'] | |||
| 'column_stack', 'hstack', 'dstack', 'vstack', 'stack', 'unique', 'moveaxis', | |||
| 'tile', 'broadcast_to', 'broadcast_arrays', 'append', 'roll', 'split', 'vsplit', | |||
| 'flip', 'flipud', 'fliplr', 'hsplit', 'dsplit', 'take_along_axis', 'take', | |||
| 'repeat'] | |||
| array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'arange', | |||
| 'linspace', 'logspace', 'eye', 'identity', 'empty', 'empty_like', | |||
| 'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril', 'triu', | |||
| 'tri', 'trace'] | |||
| 'tri', 'trace', 'meshgrid', 'mgrid', 'ogrid', 'diagflat', 'diag', | |||
| 'diag_indices', 'ix_', 'cumsum'] | |||
| math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'power', | |||
| 'dot', 'outer', 'tensordot', 'absolute'] | |||
| math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'true_divide', 'power', | |||
| 'dot', 'outer', 'tensordot', 'absolute', 'std', 'var', 'average', 'not_equal', | |||
| 'minimum', 'matmul', 'square', 'sqrt', 'reciprocal', 'log', 'maximum', | |||
| 'heaviside', 'amax', 'amin', 'hypot', 'float_power', 'floor', 'ptp', 'deg2rad', | |||
| 'rad2deg', 'count_nonzero', 'positive', 'negative', 'clip', 'floor_divide', | |||
| 'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs'] | |||
| __all__ = array_ops_module + array_creations_module + math_module + numeric_types | |||
| logic_module = ['not_equal', 'less_equal', 'less', 'greater_equal', 'greater', 'equal', 'isfinite', | |||
| 'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar'] | |||
| __all__ = array_ops_module + array_creations_module + math_module + logic_module + numeric_types | |||
| __all__.sort() | |||
| @@ -22,7 +22,12 @@ from ..common.dtype import (int8, int16, int32, int64, uint8, uint16, uint32, ui | |||
| # backend for now. | |||
| inf = float('inf') | |||
| PINF = float('inf') | |||
| NINF = float('-inf') | |||
| nan = float('nan') | |||
| # all three of inf, PINF, and NINF are defined in the original numpy, and as we aim for | |||
| # consistency same thing is done here | |||
| pi = 3.141592653589793 | |||
| int_ = int32 | |||
| uint = uint32 | |||
| @@ -0,0 +1,576 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """logical operations, the function docs are adapted from Numpy API.""" | |||
| from .math_ops import _apply_tensor_op | |||
| from ..ops import functional as F | |||
| from ..common import dtype as mstype | |||
| from .._c_expression import typing | |||
| from .array_creations import zeros, ones | |||
| from .utils import _check_input_tensor | |||
| def not_equal(x1, x2, out=None, where=True, dtype=None): | |||
| """ | |||
| Returns (x1 != x2) element-wise. | |||
| Args: | |||
| x1 (Tensor): First input tensor to be compared. | |||
| x2 (Tensor): Second input tensor to be compared. | |||
| out (Tensor or None, optional), default is None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| Returns: | |||
| Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type | |||
| bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are | |||
| scalars. | |||
| Raises: | |||
| TypeError: If the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = np.asarray([1, 2]) | |||
| >>> b = np.asarray([[1, 3],[1, 4]]) | |||
| >>> print(np.not_equal(a, b)) | |||
| >>> [[False True] | |||
| [False True]] | |||
| """ | |||
| _check_input_tensor(x1, x2) | |||
| return _apply_tensor_op(F.not_equal, x1, x2, out=out, where=where, dtype=dtype) | |||
| def less_equal(x1, x2, out=None, where=True, dtype=None): | |||
| """ | |||
| Returns the truth value of ``(x1 <= x2)`` element-wise. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| Args: | |||
| x1 (Tensor): Input array. | |||
| x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be | |||
| broadcastable to a common shape (which becomes the shape of the output). | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| Returns: | |||
| Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type | |||
| bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are | |||
| scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> output = np.less_equal(np.array([4, 2, 1]), np.array([2, 2, 2])) | |||
| >>> print(output) | |||
| [False True True] | |||
| """ | |||
| _check_input_tensor(x1, x2) | |||
| return _apply_tensor_op(F.tensor_le, x1, x2, out=out, where=where, dtype=dtype) | |||
| def less(x1, x2, out=None, where=True, dtype=None): | |||
| """ | |||
| Returns the truth value of ``(x1 < x2)`` element-wise. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| Args: | |||
| x1 (Tensor): input array. | |||
| x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be | |||
| broadcastable to a common shape (which becomes the shape of the output). | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| Returns: | |||
| Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type | |||
| bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are | |||
| scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> output = np.less(np.array([1, 2]), np.array([2, 2])) | |||
| >>> print(output) | |||
| [ True False] | |||
| """ | |||
| return _apply_tensor_op(F.tensor_lt, x1, x2, out=out, where=where, dtype=dtype) | |||
| def greater_equal(x1, x2, out=None, where=True, dtype=None): | |||
| """ | |||
| Returns the truth value of ``(x1 >= x2)`` element-wise. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| Args: | |||
| x1 (Tensor): Input array. | |||
| x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be | |||
| broadcastable to a common shape (which becomes the shape of the output). | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| Returns: | |||
| Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type | |||
| bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are | |||
| scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> output = np.greater_equal(np.array([4, 2, 1]), np.array([2, 2, 2])) | |||
| >>> print(output) | |||
| [ True True False] | |||
| """ | |||
| return _apply_tensor_op(F.tensor_ge, x1, x2, out=out, where=where, dtype=dtype) | |||
| def greater(x1, x2, out=None, where=True, dtype=None): | |||
| """ | |||
| Returns the truth value of ``(x1 > x2)`` element-wise. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| Args: | |||
| x1 (Tensor): Input array. | |||
| x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be | |||
| broadcastable to a common shape (which becomes the shape of the output). | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| Returns: | |||
| Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type | |||
| bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are | |||
| scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> output = np.greater(np.array([4, 2]), np.array([2, 2])) | |||
| >>> print(output) | |||
| [ True False] | |||
| """ | |||
| return _apply_tensor_op(F.tensor_gt, x1, x2, out=out, where=where, dtype=dtype) | |||
| def equal(x1, x2, out=None, where=True, dtype=None): | |||
| """ | |||
| Returns the truth value of ``(x1 == x2)`` element-wise. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| Args: | |||
| x1 (Tensor): Input array. | |||
| x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be | |||
| broadcastable to a common shape (which becomes the shape of the output). | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| Returns: | |||
| Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type | |||
| bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are | |||
| scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> output = np.equal(np.array([0, 1, 3]), np.arange(3)) | |||
| >>> print(output) | |||
| [ True True False] | |||
| """ | |||
| return _apply_tensor_op(F.equal, x1, x2, out=out, where=where, dtype=dtype) | |||
| def isfinite(x, out=None, where=True, dtype=None): | |||
| """ | |||
| Tests element-wise for finiteness (not infinity or not Not a Number). | |||
| The result is returned as a boolean array. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| On GPU, the supported dtypes are np.float16, and np.float32. | |||
| Args: | |||
| x (Tensor): Input values. | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| Returns: | |||
| Tensor or scalar, true where `x` is not positive infinity, negative infinity, | |||
| or NaN; false otherwise. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> output = np.isfinite(np.array([np.inf, 1., np.nan]).astype('float32')) | |||
| >>> print(output) | |||
| [False True False] | |||
| >>> output = np.isfinite(np.log(np.array(-1.).astype('float32'))) | |||
| >>> print(output) | |||
| False | |||
| """ | |||
| return _apply_tensor_op(F.isfinite, x, out=out, where=where, dtype=dtype) | |||
| def _isnan(x): | |||
| """Compures isnan without applying keyword arguments.""" | |||
| shape = F.shape(x) | |||
| zeros_tensor = zeros(shape, mstype.float32) | |||
| ones_tensor = ones(shape, mstype.float32) | |||
| non_neg = F.tensor_ge(x, zeros_tensor) | |||
| non_pos = F.tensor_le(x, zeros_tensor) | |||
| res = F.select(non_neg, zeros_tensor, ones_tensor) | |||
| res = F.select(non_pos, zeros_tensor, res) | |||
| return F.cast(res, mstype.bool_) | |||
| def isnan(x, out=None, where=True, dtype=None): | |||
| """ | |||
| Tests element-wise for NaN and return result as a boolean array. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| On GPU, the supported dtypes are np.float16, and np.float32. | |||
| Args: | |||
| x (Tensor): Input values. | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| Returns: | |||
| Tensor or scalar, true where `x` is NaN, false otherwise. This is a scalar if | |||
| `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> output = np.isnan(np.array(np.nan, np.float32)) | |||
| >>> print(output) | |||
| True | |||
| >>> output = np.isnan(np.array(np.inf, np.float32)) | |||
| >>> print(output) | |||
| False | |||
| """ | |||
| return _apply_tensor_op(_isnan, x, out=out, where=where, dtype=dtype) | |||
| def _isinf(x): | |||
| """Computes isinf without applying keyword arguments.""" | |||
| shape = F.shape(x) | |||
| zeros_tensor = zeros(shape, mstype.float32) | |||
| ones_tensor = ones(shape, mstype.float32) | |||
| not_inf = F.isfinite(x) | |||
| is_nan = _isnan(x) | |||
| res = F.select(not_inf, zeros_tensor, ones_tensor) | |||
| res = F.select(is_nan, zeros_tensor, res) | |||
| return F.cast(res, mstype.bool_) | |||
| def isinf(x, out=None, where=True, dtype=None): | |||
| """ | |||
| Tests element-wise for positive or negative infinity. | |||
| Returns a boolean array of the same shape as `x`, True where ``x == +/-inf``, otherwise False. | |||
| Note: | |||
| Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are | |||
| not supported. | |||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | |||
| for storing the result, however it can be used in combination with `where` to set | |||
| the value at indices for which `where` is set to False. | |||
| On GPU, the supported dtypes are np.float16, and np.float32. | |||
| Args: | |||
| x (Tensor): Input values. | |||
| out (Tensor or None, optional): defaults to None. | |||
| where (Tensor or None, optional): For any non-default value of type other | |||
| than :class:`Tensor` or :class:`None`, the output retains its original value. | |||
| This condition is broadcasted over the input. At locations where the | |||
| condition is `True`, the out array will be set to the ufunc result. | |||
| Elsewhere, the out array will retain its original value. Note that | |||
| if an uninitialized out array is created via the default ``out=None``, | |||
| locations within it where the condition is `False` will remain | |||
| uninitialized. | |||
| dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the | |||
| output Tensor. | |||
| Returns: | |||
| Tensor or scalar, true where `x` is positive or negative infinity, false | |||
| otherwise. This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> output = np.isinf(np.array(np.inf, np.float32)) | |||
| >>> print(output) | |||
| True | |||
| >>> output = np.isinf(np.array([np.inf, -np.inf, 1.0, np.nan], np.float32)) | |||
| >>> print(output) | |||
| [ True True False False] | |||
| """ | |||
| return _apply_tensor_op(_isinf, x, out=out, where=where, dtype=dtype) | |||
| def _is_sign_inf(x, fn): | |||
| """Tests element-wise for inifinity with sign.""" | |||
| shape = F.shape(x) | |||
| zeros_tensor = zeros(shape, mstype.float32) | |||
| ones_tensor = ones(shape, mstype.float32) | |||
| not_inf = F.isfinite(x) | |||
| is_sign = fn(x, zeros_tensor) | |||
| res = F.select(not_inf, zeros_tensor, ones_tensor) | |||
| res = F.select(is_sign, res, zeros_tensor) | |||
| return F.cast(res, mstype.bool_) | |||
| def isposinf(x): | |||
| """ | |||
| Tests element-wise for positive infinity, returns result as bool array. | |||
| Note: | |||
| Numpy argument `out` is not supported. | |||
| On GPU, the supported dtypes are np.float16, and np.float32. | |||
| Args: | |||
| x (Tensor): Input values. | |||
| Returns: | |||
| Tensor or scalar, true where `x` is positive infinity, false otherwise. | |||
| This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> output = np.isposinf(np.array([-np.inf, 0., np.inf], np.float32)) | |||
| >>> print(output) | |||
| [False False True] | |||
| """ | |||
| _check_input_tensor(x) | |||
| return _is_sign_inf(x, F.tensor_gt) | |||
| def isneginf(x): | |||
| """ | |||
| Tests element-wise for negative infinity, returns result as bool array. | |||
| Note: | |||
| Numpy argument `out` is not supported. | |||
| On GPU, the supported dtypes are np.float16, and np.float32. | |||
| Args: | |||
| x (Tensor): Input values. | |||
| Returns: | |||
| Tensor or scalar, true where `x` is negative infinity, false otherwise. | |||
| This is a scalar if `x` is a scalar. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> output = np.isneginf(np.array([-np.inf, 0., np.inf], np.float32)) | |||
| >>> print(output) | |||
| [ True False False] | |||
| """ | |||
| return _is_sign_inf(x, F.tensor_lt) | |||
| def isscalar(element): | |||
| """ | |||
| Returns True if the type of element is a scalar type. | |||
| Note: | |||
| Only object types recognized by the mindspore parser are supported, | |||
| which includes objects, types, methods and functions defined within | |||
| the scope of mindspore. Other built-in types are not supported. | |||
| Args: | |||
| element (any): Input argument, can be of any type and shape. | |||
| Returns: | |||
| Boolean, True if `element` is a scalar type, False if it is not. | |||
| Raises: | |||
| TypeError: if the type of `element` is not supported by mindspore parser. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> output = np.isscalar(3.1) | |||
| >>> print(output) | |||
| True | |||
| >>> output = np.isscalar(np.array(3.1)) | |||
| >>> print(output) | |||
| False | |||
| >>> output = np.isscalar(False) | |||
| >>> print(output) | |||
| True | |||
| >>> output = np.isscalar('numpy') | |||
| >>> print(output) | |||
| True | |||
| """ | |||
| return isinstance(F.typeof(element), (typing.Number, typing.Int, typing.UInt, | |||
| typing.Float, typing.Bool, typing.String)) | |||
| @@ -16,11 +16,11 @@ | |||
| import numpy as onp | |||
| import mindspore.context as context | |||
| from ..common import Tensor | |||
| from ..ops import functional as F | |||
| from ..common import dtype as mstype | |||
| from .utils_const import _tile_size | |||
| from .utils_const import _tile_size, _add_unit_axes, _raise_type_error | |||
| def _deep_list(array_like): | |||
| @@ -56,10 +56,9 @@ def _deep_tensor_to_nparray(array_like): | |||
| def _check_input_for_asarray(array_like): | |||
| """check whether array_like argument is a valid type for np.asarray conversion""" | |||
| if isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)): | |||
| return True | |||
| raise TypeError("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \ | |||
| f"or numpy.ndarray, but got {type(array_like)}") | |||
| if not isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)): | |||
| _raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \ | |||
| "or numpy.ndarray, but got ", array_like) | |||
| def _is_scalar(shape): | |||
| @@ -67,16 +66,6 @@ def _is_scalar(shape): | |||
| return F.shape_mul(shape) == 1 | |||
| def _is_empty(shape): | |||
| """Checks if the shape is empty""" | |||
| return F.shape_mul(shape) == 0 | |||
| def _get_device(): | |||
| """Get the current device (`GPU`, `CPU`, `Ascend`)""" | |||
| return context.get_context('device_target') | |||
| def _convert_list_tensor_to_tuple_tensor(list_of_tensor): | |||
| """Convert a list of tensor to a tuple of tensor""" | |||
| if isinstance(list_of_tensor, list): | |||
| @@ -87,19 +76,66 @@ def _convert_list_tensor_to_tuple_tensor(list_of_tensor): | |||
| return list_of_tensor | |||
| def _get_mode(): | |||
| """Get the current mode (0 is Graph mode, 1 is PyNative mode)""" | |||
| return context.get_context('mode') | |||
| def _expand(x, ndim, axis=0): | |||
| """Expand x to ndim.""" | |||
| while F.rank(x) < ndim: | |||
| x = F.expand_dims(x, axis) | |||
| return x | |||
| """Expand x to ndim from axis, which can be 0 or -1.""" | |||
| shape = _add_unit_axes(F.shape(x), ndim, axis == -1) | |||
| return F.reshape(x, shape) | |||
| def _broadcast_to(x, shape_cur, shape_to, ndim_to): | |||
| """Broadcasts x from shape_cur to shape_to.""" | |||
| size = _tile_size(shape_cur, shape_to, ndim_to) | |||
| return F.tile(x, size) | |||
| def _broadcast_to_shape(x, shape): | |||
| """Broadcasts x from current shape to shape""" | |||
| ndim_to = len(shape) | |||
| x = _expand(x, ndim_to) | |||
| return _broadcast_to(x, F.shape(x), shape, ndim_to) | |||
| def _get_size(x, axis=None): | |||
| """Get the number of elements along the given axis of tensor x.""" | |||
| if axis is None or F.tuple_len(axis) == 0: | |||
| axis = F.make_range(x.ndim) | |||
| nums = 1 | |||
| for ax in axis: | |||
| nums *= x.shape[ax] | |||
| return nums | |||
| def _check_input_tensor(*tensors): | |||
| for tensor in tensors: | |||
| if not isinstance(tensor, Tensor): | |||
| _raise_type_error('expect Tensor, but got ', F.typeof(tensor)) | |||
| return True | |||
| def _convert_64_to_32(tensor): | |||
| """Convert tensor with float64/int64 types to float32/int32.""" | |||
| if tensor.dtype == mstype.float64: | |||
| return tensor.astype("float32") | |||
| if tensor.dtype == mstype.int64: | |||
| return tensor.astype("int32") | |||
| return tensor | |||
| def _get_dtype_from_scalar(*input_numbers): | |||
| """ | |||
| Get the final dtype from series of input numbers, compared with F.typeof, we | |||
| return int32/float32 for python int/float instead. | |||
| """ | |||
| bool_flag = True | |||
| int_flag = True | |||
| for number in input_numbers: | |||
| if number is not None: | |||
| if not isinstance(number, bool): | |||
| bool_flag = False | |||
| if not isinstance(number, int): | |||
| int_flag = False | |||
| if bool_flag: | |||
| return mstype.bool_ | |||
| if int_flag: | |||
| return mstype.int32 | |||
| return mstype.float32 | |||
| @@ -13,14 +13,16 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """internal graph-compatible utility functions""" | |||
| import math | |||
| from functools import partial | |||
| import mindspore.context as context | |||
| from ..ops import functional as F | |||
| from ..ops.primitive import constexpr | |||
| from ..common import dtype as mstype | |||
| from ..common import Tensor | |||
| from .._c_expression import Tensor as Tensor_ | |||
| from .._c_expression.typing import Tuple, List | |||
| from .._c_expression import typing | |||
| from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map | |||
| @@ -28,12 +30,17 @@ from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map | |||
| @constexpr | |||
| def _check_shape(shape): | |||
| """check the shape param to match the numpy style""" | |||
| if not isinstance(shape, (int, tuple, list, Tuple, List)): | |||
| if not isinstance(shape, (int, tuple, list, typing.Tuple, typing.List)): | |||
| raise TypeError(f"only int, tuple and list are allowed for shape, but got {type(shape)}") | |||
| if isinstance(shape, int): | |||
| shape = (shape,) | |||
| if isinstance(shape, (list, List)): | |||
| if isinstance(shape, (list, typing.List)): | |||
| shape = tuple(shape) | |||
| for s in shape: | |||
| if not isinstance(s, int): | |||
| raise TypeError("each entry in shape should be int.") | |||
| if s < 0: | |||
| raise ValueError("each entry in shape should no less than 0.") | |||
| return shape | |||
| @@ -57,7 +64,7 @@ def _check_dtype(dtype): | |||
| @constexpr | |||
| def _check_shape_contain_zero(shp): | |||
| def _is_shape_empty(shp): | |||
| """Check whether shape contains zero""" | |||
| if isinstance(shp, int): | |||
| return shp == 0 | |||
| @@ -77,35 +84,28 @@ def _check_start_normalize(start, ndim): | |||
| @constexpr | |||
| def _check_axes_range(axes, ndim): | |||
| """ | |||
| Check axes are within the number of dimensions of tensor x and normalize the negative axes. | |||
| Check axes type and normalize the negative axes. | |||
| Args: | |||
| axes (Union[int, tuple(int), list(int)]): Axes of the tensor. | |||
| axes: Axes of the tensor. | |||
| ndim (int): The number of dimensions of the tensor. | |||
| Return: | |||
| Axes (Union[int, tuple(int)]). If input is integer, return integer, else tuple. | |||
| Raises: | |||
| TypeError: If the axes are not integer, tuple(int) or list(int). | |||
| ValueError: If duplicate axes exists or some axis is out of bounds. | |||
| """ | |||
| if not isinstance(axes, int) and not isinstance(axes, tuple) and not isinstance(axes, list): | |||
| raise TypeError(f"int, tuple(int) or list(int) expected, but got {type(axes)}.") | |||
| low = -ndim | |||
| up = ndim - 1 | |||
| if low > up: | |||
| raise ValueError(f"Lower bound {low} and upper bound {up} of axes are not allowed.") | |||
| if isinstance(axes, int): | |||
| if axes < low or axes > up: | |||
| raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {ndim}.") | |||
| return axes if axes >= 0 else axes + ndim | |||
| new_axes = [] | |||
| for item in axes: | |||
| if not isinstance(item, int): | |||
| raise TypeError(f"int in tuple or list expected, but got {type(item)}.") | |||
| if item < low or item > up: | |||
| raise ValueError(f"axis {item} in {axes} is out of bounds for tensor of dimension {ndim}.") | |||
| new_axes.append(item if item >= 0 else item + ndim) | |||
| return tuple(new_axes) | |||
| _check_axis_type(axes, True, True, True) | |||
| if isinstance(axes, (list, tuple)): | |||
| _check_element_int(axes) | |||
| axes = _canonicalize_axis(axes, ndim) | |||
| return axes | |||
| @constexpr | |||
| def _get_device_compile(): | |||
| def _get_device(): | |||
| """Get the current device (`GPU`, `CPU`, `Ascend`)""" | |||
| return context.get_context('device_target') | |||
| @@ -153,9 +153,10 @@ def _infer_out_shape(*shapes): | |||
| @constexpr | |||
| def _check_axis_in_range(axis, ndim): | |||
| """Checks axes are with the bounds of ndim""" | |||
| if -ndim <= axis < ndim: | |||
| return True | |||
| raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}') | |||
| if not isinstance(axis, int): | |||
| raise TypeError(f'axes should be integers, not {type(axis)}') | |||
| if not -ndim <= axis < ndim: | |||
| raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}') | |||
| @constexpr | |||
| @@ -165,26 +166,25 @@ def _check_axis_valid(axes, ndim): | |||
| to the built-in operator (non-negative, int or tuple) | |||
| """ | |||
| if isinstance(axes, int): | |||
| _ = _check_axis_in_range(axes, ndim) | |||
| _check_axis_in_range(axes, ndim) | |||
| return (axes % ndim,) | |||
| if isinstance(axes, tuple): | |||
| if isinstance(axes, (tuple, list)): | |||
| for axis in axes: | |||
| _ = _check_axis_in_range(axis, ndim) | |||
| _check_axis_in_range(axis, ndim) | |||
| axes = tuple(map(lambda x: x % ndim, axes)) | |||
| if all(axes.count(el) <= 1 for el in axes): | |||
| return axes | |||
| if axes is None: | |||
| axes = F.make_range(ndim) | |||
| return axes | |||
| raise ValueError('duplicate value in \'axis\'') | |||
| raise ValueError('duplicate value in "axis"') | |||
| @constexpr | |||
| def _check_shape_aligned(shape1, shape2): | |||
| """Checks shape1 and shape2 are valid shapes to perform inner product""" | |||
| if shape1[-1] == shape2[-1]: | |||
| return True | |||
| raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)') | |||
| if shape1[-1] != shape2[-1]: | |||
| raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)') | |||
| @constexpr | |||
| @@ -197,30 +197,6 @@ def _tile_size(shape, out_shape, ndim): | |||
| return tuple(size) | |||
| @constexpr | |||
| def _check_is_int(obj): | |||
| """Check whether obj is an integer.""" | |||
| return isinstance(obj, int) | |||
| @constexpr | |||
| def _check_is_tuple(obj): | |||
| """Check whether obj is a tuple""" | |||
| return isinstance(obj, (tuple, Tuple)) | |||
| @constexpr | |||
| def _check_is_list(obj): | |||
| """Check whether obj is a list""" | |||
| return isinstance(obj, (list, List)) | |||
| @constexpr | |||
| def _check_is_tensor(obj): | |||
| """Check whether obj is a tensor""" | |||
| return isinstance(obj, mstype.tensor_type) | |||
| @constexpr | |||
| def _raise_type_error(info, param=None): | |||
| """ | |||
| @@ -298,6 +274,177 @@ def _check_is_float(dtype): | |||
| @constexpr | |||
| def _check_input_tensor(input_type): | |||
| if not _check_is_tensor(input_type): | |||
| raise TypeError(f'expect Tensor, but got {input_type}') | |||
| def _check_is_int(dtype): | |||
| return isinstance(dtype, typing.Int) | |||
| @constexpr | |||
| def _check_matmul_shapes(shape1, shape2): | |||
| """Checks shape1 and shape2 are valid shapes to perform matmul""" | |||
| ndim1, ndim2 = len(shape1), len(shape2) | |||
| if ndim1 < 1 or ndim2 < 1: | |||
| raise ValueError('input operands must have at least 1 dimension') | |||
| if ndim2 >= 2 and shape1[-1] != shape2[-2]: | |||
| raise ValueError(f'mismatch in core dimension of input operands (size ' | |||
| f'{shape1[-1]} is different from {shape2[-2]})') | |||
| @constexpr | |||
| def _check_axis_type(axis, type_int=True, type_tuple=True, type_list=True): | |||
| """Check axis argument type.""" | |||
| if type_int and isinstance(axis, int): | |||
| return True | |||
| if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)): | |||
| for ax in axis: | |||
| if not isinstance(ax, int): | |||
| raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axis}.") | |||
| return True | |||
| type_str = "" | |||
| if type_int: type_str += "int, " | |||
| if type_tuple: type_str += "tuple, " | |||
| if type_list: type_str += "list, " | |||
| raise TypeError(f"Axis should be {type_str}but got {type(axis)}.") | |||
| @constexpr | |||
| def _canonicalize_axis(axis, ndim): | |||
| """ | |||
| Check axes are within the number of dimensions of tensor x and normalize the negative axes. | |||
| Args: | |||
| axis (Union[int, tuple(int), list(int)]): Axes of the tensor. | |||
| ndim (int): The number of dimensions of the tensor. | |||
| Return: | |||
| Axis (Union[int, tuple(int)]). If input is integer, return integer, else tuple. | |||
| """ | |||
| if isinstance(axis, int): | |||
| axis = [axis] | |||
| for ax in axis: | |||
| _check_axis_in_range(ax, ndim) | |||
| def canonicalizer(ax): | |||
| return ax + ndim if ax < 0 else ax | |||
| axis = tuple([canonicalizer(axis) for axis in axis]) | |||
| if all(axis.count(el) <= 1 for el in axis): | |||
| return axis if len(axis) > 1 else axis[0] | |||
| raise ValueError(f"duplicate axes in {axis}.") | |||
| @constexpr | |||
| def _broadcast_tuples(tup1, tup2): | |||
| """ | |||
| Broadcast two 1D tuples to the same length, if inputs are ints, convert to | |||
| tuples first. | |||
| """ | |||
| tup1 = (tup1,) if isinstance(tup1, int) else tup1 | |||
| tup2 = (tup2,) if isinstance(tup2, int) else tup2 | |||
| if not isinstance(tup1, (tuple, list)) or not isinstance(tup2, (tuple, list)): | |||
| raise TypeError("input shift and axis must be tuple or list or int.") | |||
| if len(tup1) == len(tup2): | |||
| return tup1, tup2 | |||
| if len(tup1) == 1: | |||
| tup1 *= len(tup2) | |||
| elif len(tup2) == 1: | |||
| tup2 *= len(tup1) | |||
| else: | |||
| raise ValueError("shape mismatch: objects cannot be broadcast to a single shape") | |||
| return tup1, tup2 | |||
| @constexpr | |||
| def _expanded_shape(ndim, axis_size, axis): | |||
| """ | |||
| Returns a shape with size = 1 for all dimensions | |||
| except at axis. | |||
| """ | |||
| return tuple([axis_size if i == axis else 1 for i in range(ndim)]) | |||
| @constexpr | |||
| def _add_unit_axes(shape, ndim, append=False): | |||
| """ | |||
| Prepends shape with 1s so that it has the number of dimensions ndim. | |||
| If append is set to True, returns shape appended with 1s instead. | |||
| """ | |||
| if isinstance(shape, int): | |||
| shape = (shape,) | |||
| ndim_diff = ndim - len(shape) | |||
| if ndim_diff > 0: | |||
| if append: | |||
| shape = [i for i in shape] + [1]*ndim_diff | |||
| else: | |||
| shape = [1]*ndim_diff + [i for i in shape] | |||
| return tuple(shape) | |||
| @constexpr | |||
| def _check_element_int(lst): | |||
| """ | |||
| Check whether each element in `lst` is an integer. | |||
| """ | |||
| for item in lst: | |||
| if not isinstance(item, int): | |||
| raise TypeError(f"Each element in {lst} should be integer, but got {type(item)}.") | |||
| return True | |||
| @constexpr | |||
| def _type_convert(force, obj): | |||
| """ | |||
| Convert type of `obj` to `force`. | |||
| """ | |||
| return force(obj) | |||
| @constexpr | |||
| def _list_comprehensions(obj, item=None, return_tuple=False): | |||
| """ | |||
| Generates a new list/tuple by list comprehension. | |||
| Args: | |||
| obj (Union[int, list, tuple]): | |||
| If integer, it will be the length of the returned tuple/list. | |||
| item: The value to be filled. Default: None. | |||
| If None, the values in the new list/tuple are the same as obj | |||
| or range(obj) when obj is integer. | |||
| return_tuple(bool): If true, returns tuple, else returns list. | |||
| Returns: | |||
| List or tuple. | |||
| """ | |||
| res = [] | |||
| lst = obj | |||
| if isinstance(obj, int): | |||
| lst = range(obj) | |||
| if item is None: | |||
| res = [i for i in lst] | |||
| else: | |||
| res = [item for i in lst] | |||
| if return_tuple: | |||
| return tuple(res) | |||
| return res | |||
| @constexpr | |||
| def _tuple_getitem(tup, idx, startswith=True): | |||
| """ | |||
| Returns a slice from tup starting with idx. If startswith is False, | |||
| returns a lice from tup ending with idx instead. | |||
| """ | |||
| if startswith: | |||
| return tup[idx:] | |||
| return tup[:idx] | |||
| @constexpr | |||
| def _iota(dtype, num): | |||
| """Creates a 1-D tensor with value: [0,1,...num-1] and dtype.""" | |||
| # TODO: Change to P.Linspace when the kernel is implemented on CPU. | |||
| return Tensor(list(range(int(num))), dtype) | |||
| @constexpr | |||
| def _ceil(number): | |||
| """Ceils the number in graph mode.""" | |||
| return math.ceil(number) | |||
| @@ -59,18 +59,25 @@ tensor_div = P.RealDiv() | |||
| tensor_floordiv = P.FloorDiv() | |||
| tensor_pow = P.Pow() | |||
| tensor_mod = P.FloorMod() | |||
| tensor_exp = P.Exp() | |||
| tensor_expm1 = P.Expm1() | |||
| strided_slice = P.StridedSlice() | |||
| same_type_shape = P.SameTypeShape() | |||
| check_bprop = P.CheckBprop() | |||
| equal = P.Equal() | |||
| not_equal = P.NotEqual() | |||
| isfinite = P.IsFinite() | |||
| assign_sub = P.AssignSub() | |||
| assign_add = P.AssignAdd() | |||
| assign = P.Assign() | |||
| square = P.Square() | |||
| sqrt = P.Sqrt() | |||
| log = P.Log() | |||
| reduce_sum = P.ReduceSum() | |||
| tensor_slice = P.Slice() | |||
| maximum = P.Maximum() | |||
| minimum = P.Minimum() | |||
| floor = P.Floor() | |||
| scalar_to_array = P.ScalarToArray() | |||
| scalar_to_tensor = P.ScalarToTensor() | |||
| @@ -82,6 +89,7 @@ transpose = P.Transpose() | |||
| squeeze = P.Squeeze() | |||
| scatter_nd = P.ScatterNd() | |||
| gather = P.Gather() | |||
| gather_d = P.GatherD() | |||
| gather_nd = P.GatherNd() | |||
| scatter_update = P.ScatterUpdate() | |||
| scatter_nd_update = P.ScatterNdUpdate() | |||
| @@ -14,15 +14,13 @@ | |||
| # ============================================================================ | |||
| """unit tests for numpy array operations""" | |||
| import functools | |||
| import pytest | |||
| import numpy as onp | |||
| import mindspore.context as context | |||
| import mindspore.numpy as mnp | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| from .utils import rand_int, rand_bool, match_array, match_res, match_meta, \ | |||
| match_all_arrays | |||
| class Cases(): | |||
| @@ -97,10 +95,10 @@ class Cases(): | |||
| self.mnp_prototypes = [ | |||
| mnp.ones((2, 3, 4)), | |||
| mnp.ones((0, 3, 0, 2, 5)), | |||
| onp.ones((2, 7, 0)), | |||
| onp.ones(()), | |||
| [mnp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]], | |||
| ([(1, 2), mnp.ones(2)], (onp.ones(2), [3, 4])), | |||
| mnp.ones((2, 7, 0)), | |||
| mnp.ones(()), | |||
| [mnp.ones(3), (1, 2, 3), mnp.ones(3), [4, 5, 6]], | |||
| ([(1, 2), mnp.ones(2)], (mnp.ones(2), [3, 4])), | |||
| ] | |||
| self.onp_prototypes = [ | |||
| @@ -113,97 +111,6 @@ class Cases(): | |||
| ] | |||
| def match_array(actual, expected, error=0): | |||
| if error > 0: | |||
| onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(), | |||
| decimal=error) | |||
| else: | |||
| onp.testing.assert_equal(actual.tolist(), expected.tolist()) | |||
| def check_all_results(onp_results, mnp_results, error=0): | |||
| """Check all results from numpy and mindspore.numpy""" | |||
| for i, _ in enumerate(onp_results): | |||
| match_array(onp_results[i], mnp_results[i].asnumpy()) | |||
| def check_all_unique_results(onp_results, mnp_results): | |||
| """ | |||
| Check all results from numpy and mindspore.numpy. | |||
| Args: | |||
| onp_results (Union[tuple of numpy.arrays, numpy.array]) | |||
| mnp_results (Union[tuple of Tensors, Tensor]) | |||
| """ | |||
| for i, _ in enumerate(onp_results): | |||
| if isinstance(onp_results[i], tuple): | |||
| for j in range(len(onp_results[i])): | |||
| match_array(onp_results[i][j], | |||
| mnp_results[i][j].asnumpy(), error=7) | |||
| else: | |||
| match_array(onp_results[i], mnp_results[i].asnumpy(), error=7) | |||
| def run_non_kw_test(mnp_fn, onp_fn): | |||
| """Run tests on functions with non keyword arguments""" | |||
| test_case = Cases() | |||
| for i in range(len(test_case.arrs)): | |||
| arrs = test_case.arrs[:i] | |||
| match_res(mnp_fn, onp_fn, *arrs) | |||
| for i in range(len(test_case.scalars)): | |||
| arrs = test_case.scalars[:i] | |||
| match_res(mnp_fn, onp_fn, *arrs) | |||
| for i in range(len(test_case.expanded_arrs)): | |||
| arrs = test_case.expanded_arrs[:i] | |||
| match_res(mnp_fn, onp_fn, *arrs) | |||
| for i in range(len(test_case.nested_arrs)): | |||
| arrs = test_case.nested_arrs[:i] | |||
| match_res(mnp_fn, onp_fn, *arrs) | |||
| def rand_int(*shape): | |||
| """return an random integer array with parameter shape""" | |||
| res = onp.random.randint(low=1, high=5, size=shape) | |||
| if isinstance(res, onp.ndarray): | |||
| return res.astype(onp.float32) | |||
| return float(res) | |||
| # return an random boolean array | |||
| def rand_bool(*shape): | |||
| return onp.random.rand(*shape) > 0.5 | |||
| def match_res(mnp_fn, onp_fn, *arrs, **kwargs): | |||
| """Checks results from applying mnp_fn and onp_fn on arrs respectively""" | |||
| mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs) | |||
| mnp_res = mnp_fn(*mnp_arrs, **kwargs) | |||
| onp_res = onp_fn(*arrs, **kwargs) | |||
| match_all_arrays(mnp_res, onp_res) | |||
| def match_all_arrays(mnp_res, onp_res, error=0): | |||
| if isinstance(mnp_res, (tuple, list)): | |||
| for actual, expected in zip(mnp_res, onp_res): | |||
| match_array(actual.asnumpy(), expected, error) | |||
| else: | |||
| match_array(mnp_res.asnumpy(), onp_res, error) | |||
| def match_meta(actual, expected): | |||
| # float64 and int64 are not supported, and the default type for | |||
| # float and int are float32 and int32, respectively | |||
| if expected.dtype == onp.float64: | |||
| expected = expected.astype(onp.float32) | |||
| elif expected.dtype == onp.int64: | |||
| expected = expected.astype(onp.int32) | |||
| assert actual.shape == expected.shape | |||
| assert actual.dtype == expected.dtype | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @@ -440,27 +347,50 @@ def test_arange(): | |||
| def test_linspace(): | |||
| actual = onp.linspace(2.0, 3.0, dtype=onp.float32) | |||
| expected = mnp.linspace(2.0, 3.0).asnumpy() | |||
| match_array(actual, expected, error=7) | |||
| match_array(actual, expected, error=6) | |||
| actual = onp.linspace(2.0, 3.0, num=5, dtype=onp.float32) | |||
| expected = mnp.linspace(2.0, 3.0, num=5).asnumpy() | |||
| match_array(actual, expected, error=7) | |||
| match_array(actual, expected, error=6) | |||
| actual = onp.linspace( | |||
| 2.0, 3.0, num=5, endpoint=False, dtype=onp.float32) | |||
| expected = mnp.linspace(2.0, 3.0, num=5, endpoint=False).asnumpy() | |||
| match_array(actual, expected, error=7) | |||
| match_array(actual, expected, error=6) | |||
| actual = onp.linspace(2.0, 3.0, num=5, retstep=True, dtype=onp.float32) | |||
| expected = mnp.linspace(2.0, 3.0, num=5, retstep=True) | |||
| match_array(actual[0], expected[0].asnumpy()) | |||
| assert actual[1] == expected[1] | |||
| assert actual[1] == expected[1].asnumpy() | |||
| actual = onp.linspace(2.0, [3, 4, 5], num=5, | |||
| endpoint=False, dtype=onp.float32) | |||
| expected = mnp.linspace( | |||
| 2.0, [3, 4, 5], num=5, endpoint=False).asnumpy() | |||
| match_array(actual, expected) | |||
| match_array(actual, expected, error=6) | |||
| start = onp.random.random([2, 1, 4]) | |||
| stop = onp.random.random([1, 5, 1]) | |||
| actual = onp.linspace(start, stop, num=20, retstep=True, | |||
| endpoint=False, dtype=onp.float32) | |||
| expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20, | |||
| retstep=True, endpoint=False) | |||
| match_array(actual[0], expected[0].asnumpy(), error=6) | |||
| match_array(actual[1], expected[1].asnumpy(), error=6) | |||
| actual = onp.linspace(start, stop, num=20, retstep=True, | |||
| endpoint=False, dtype=onp.int16) | |||
| expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20, | |||
| retstep=True, endpoint=False, dtype=mnp.int16) | |||
| match_array(actual[0], expected[0].asnumpy(), error=6) | |||
| match_array(actual[1], expected[1].asnumpy(), error=6) | |||
| for axis in range(2): | |||
| actual = onp.linspace(start, stop, num=20, retstep=False, | |||
| endpoint=False, dtype=onp.float32, axis=axis) | |||
| expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20, | |||
| retstep=False, endpoint=False, dtype=mnp.float32, axis=axis) | |||
| match_array(actual, expected.asnumpy(), error=6) | |||
| @pytest.mark.level1 | |||
| @@ -472,22 +402,22 @@ def test_linspace(): | |||
| def test_logspace(): | |||
| actual = onp.logspace(2.0, 3.0, dtype=onp.float32) | |||
| expected = mnp.logspace(2.0, 3.0).asnumpy() | |||
| match_array(actual, expected) | |||
| match_array(actual, expected, error=3) | |||
| actual = onp.logspace(2.0, 3.0, num=5, dtype=onp.float32) | |||
| expected = mnp.logspace(2.0, 3.0, num=5).asnumpy() | |||
| match_array(actual, expected) | |||
| match_array(actual, expected, error=3) | |||
| actual = onp.logspace( | |||
| 2.0, 3.0, num=5, endpoint=False, dtype=onp.float32) | |||
| expected = mnp.logspace(2.0, 3.0, num=5, endpoint=False).asnumpy() | |||
| match_array(actual, expected) | |||
| match_array(actual, expected, error=3) | |||
| actual = onp.logspace(2.0, [3, 4, 5], num=5, | |||
| actual = onp.logspace(2.0, [3, 4, 5], num=5, base=2, | |||
| endpoint=False, dtype=onp.float32) | |||
| expected = mnp.logspace( | |||
| 2.0, [3, 4, 5], num=5, endpoint=False).asnumpy() | |||
| match_array(actual, expected) | |||
| 2.0, [3, 4, 5], num=5, base=2, endpoint=False).asnumpy() | |||
| match_array(actual, expected, error=3) | |||
| @pytest.mark.level1 | |||
| @@ -537,7 +467,6 @@ def run_x_like(mnp_fn, onp_fn): | |||
| actual = mnp_fn(mnp_proto, shape=shape).asnumpy() | |||
| expected = onp_fn(onp_proto, shape=shape) | |||
| match_array(actual, expected) | |||
| for mnp_dtype, onp_dtype in zip(test_case.mnp_dtypes, | |||
| test_case.onp_dtypes): | |||
| actual = mnp_fn(mnp_proto, dtype=mnp_dtype).asnumpy() | |||
| @@ -581,18 +510,18 @@ def test_full_like(): | |||
| for mnp_proto, onp_proto in zip(test_case.mnp_prototypes, test_case.onp_prototypes): | |||
| shape = onp.zeros_like(onp_proto).shape | |||
| fill_value = rand_int() | |||
| actual = mnp.full_like(mnp_proto, fill_value).asnumpy() | |||
| actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy() | |||
| expected = onp.full_like(onp_proto, fill_value) | |||
| match_array(actual, expected) | |||
| for i in range(len(shape) - 1, 0, -1): | |||
| fill_value = rand_int(*shape[i:]) | |||
| actual = mnp.full_like(mnp_proto, fill_value).asnumpy() | |||
| actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy() | |||
| expected = onp.full_like(onp_proto, fill_value) | |||
| match_array(actual, expected) | |||
| fill_value = rand_int(1, *shape[i + 1:]) | |||
| actual = mnp.full_like(mnp_proto, fill_value).asnumpy() | |||
| actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy() | |||
| expected = onp.full_like(onp_proto, fill_value) | |||
| match_array(actual, expected) | |||
| @@ -620,6 +549,26 @@ def test_tri_triu_tril(): | |||
| match_array(mnp.tri(64, 64, -10).asnumpy(), onp.tri(64, 64, -10)) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_cumsum(): | |||
| x = mnp.ones((16, 16), dtype="bool") | |||
| match_array(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy())) | |||
| match_array(mnp.cumsum(x, axis=0).asnumpy(), | |||
| onp.cumsum(x.asnumpy(), axis=0)) | |||
| match_meta(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy())) | |||
| x = rand_int(3, 4, 5) | |||
| match_array(mnp.cumsum(mnp.asarray(x), dtype="bool").asnumpy(), | |||
| onp.cumsum(x, dtype="bool")) | |||
| match_array(mnp.cumsum(mnp.asarray(x), axis=-1).asnumpy(), | |||
| onp.cumsum(x, axis=-1)) | |||
| def mnp_diagonal(arr): | |||
| return mnp.diagonal(arr, offset=2, axis1=-1, axis2=0) | |||
| @@ -697,6 +646,138 @@ def test_trace(): | |||
| match_res(mnp.trace, onp.trace, arr, offset=i, axis1=2, axis2=-1) | |||
| def mnp_meshgrid(*xi): | |||
| a = mnp.meshgrid(*xi) | |||
| b = mnp.meshgrid(*xi, sparse=True) | |||
| c = mnp.meshgrid(*xi, indexing='ij') | |||
| d = mnp.meshgrid(*xi, sparse=False, indexing='ij') | |||
| return a, b, c, d | |||
| def onp_meshgrid(*xi): | |||
| a = onp.meshgrid(*xi) | |||
| b = onp.meshgrid(*xi, sparse=True) | |||
| c = onp.meshgrid(*xi, indexing='ij') | |||
| d = onp.meshgrid(*xi, sparse=False, indexing='ij') | |||
| return a, b, c, d | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_meshgrid(): | |||
| xi = (onp.full(3, 2), onp.full(1, 5), onp.full( | |||
| (2, 3), 9), onp.full((4, 5, 6), 7)) | |||
| for i in range(len(xi)): | |||
| arrs = xi[i:] | |||
| mnp_arrs = map(mnp.asarray, arrs) | |||
| for mnp_res, onp_res in zip(mnp_meshgrid(*mnp_arrs), onp_meshgrid(*arrs)): | |||
| match_all_arrays(mnp_res, onp_res) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_mgrid(): | |||
| mnp_res = mnp.mgrid[0:5] | |||
| onp_res = onp.mgrid[0:5] | |||
| match_all_arrays(mnp_res, onp_res, error=5) | |||
| mnp_res = mnp.mgrid[2:30:4j, -10:20:7, 2:5:0.5] | |||
| onp_res = onp.mgrid[2:30:4j, -10:20:7, 2:5:0.5] | |||
| match_all_arrays(mnp_res, onp_res, error=5) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_ogrid(): | |||
| mnp_res = mnp.ogrid[0:5] | |||
| onp_res = onp.ogrid[0:5] | |||
| match_all_arrays(mnp_res, onp_res, error=5) | |||
| mnp_res = mnp.ogrid[2:30:4j, -10:20:7, 2:5:0.5] | |||
| onp_res = onp.ogrid[2:30:4j, -10:20:7, 2:5:0.5] | |||
| match_all_arrays(mnp_res, onp_res, error=5) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_diagflat(): | |||
| arrs = [rand_int(0), rand_int(2, 3), rand_int(3, 5, 0)] | |||
| for arr in arrs: | |||
| for i in [-2, 0, 7]: | |||
| match_res(mnp.diagflat, onp.diagflat, arr, k=i) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_diag(): | |||
| arrs = [rand_int(0), rand_int(0, 0), rand_int(7), rand_int(5, 5), | |||
| rand_int(3, 8), rand_int(9, 6)] | |||
| for arr in arrs: | |||
| for i in [-10, -5, -1, 0, 2, 5, 6, 10]: | |||
| match_res(mnp.diag, onp.diag, arr, k=i) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_diag_indices(): | |||
| mnp_res = mnp.diag_indices(0) | |||
| onp_res = onp.diag_indices(0) | |||
| match_all_arrays(mnp_res, onp_res) | |||
| mnp_res = mnp.diag_indices(3, 0) | |||
| onp_res = onp.diag_indices(3, 0) | |||
| match_all_arrays(mnp_res, onp_res) | |||
| mnp_res = mnp.diag_indices(5, 7) | |||
| onp_res = onp.diag_indices(5, 7) | |||
| match_all_arrays(mnp_res, onp_res) | |||
| def mnp_ix_(*args): | |||
| return mnp.ix_(*args) | |||
| def onp_ix_(*args): | |||
| return onp.ix_(*args) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_ix_(): | |||
| arrs = [rand_int(i + 1) for i in range(10)] | |||
| for i in range(10): | |||
| test_arrs = arrs[:i + 1] | |||
| match_res(mnp_ix_, onp_ix_, *test_arrs) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @@ -22,6 +22,9 @@ import numpy as onp | |||
| import mindspore.numpy as mnp | |||
| from mindspore.nn import Cell | |||
| from .utils import rand_int, run_non_kw_test, check_all_results, match_array, \ | |||
| rand_bool, match_res, run_multi_test | |||
| class Cases(): | |||
| def __init__(self): | |||
| @@ -111,81 +114,6 @@ class Cases(): | |||
| ] | |||
| def match_array(actual, expected, error=0): | |||
| if error > 0: | |||
| onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(), | |||
| decimal=error) | |||
| else: | |||
| onp.testing.assert_equal(actual.tolist(), expected.tolist()) | |||
| def check_all_results(onp_results, mnp_results, error=0): | |||
| """Check all results from numpy and mindspore.numpy""" | |||
| for i, _ in enumerate(onp_results): | |||
| match_array(onp_results[i], mnp_results[i].asnumpy()) | |||
| def run_non_kw_test(mnp_fn, onp_fn): | |||
| """Run tests on functions with non keyword arguments""" | |||
| test_case = Cases() | |||
| for i in range(len(test_case.arrs)): | |||
| arrs = test_case.arrs[:i] | |||
| match_res(mnp_fn, onp_fn, *arrs) | |||
| for i in range(len(test_case.scalars)): | |||
| arrs = test_case.scalars[:i] | |||
| match_res(mnp_fn, onp_fn, *arrs) | |||
| for i in range(len(test_case.expanded_arrs)): | |||
| arrs = test_case.expanded_arrs[:i] | |||
| match_res(mnp_fn, onp_fn, *arrs) | |||
| for i in range(len(test_case.nested_arrs)): | |||
| arrs = test_case.nested_arrs[:i] | |||
| match_res(mnp_fn, onp_fn, *arrs) | |||
| def rand_int(*shape): | |||
| """return an random integer array with parameter shape""" | |||
| res = onp.random.randint(low=1, high=5, size=shape) | |||
| if isinstance(res, onp.ndarray): | |||
| return res.astype(onp.float32) | |||
| return float(res) | |||
| # return an random boolean array | |||
| def rand_bool(*shape): | |||
| return onp.random.rand(*shape) > 0.5 | |||
| def match_res(mnp_fn, onp_fn, *arrs, **kwargs): | |||
| """Checks results from applying mnp_fn and onp_fn on arrs respectively""" | |||
| mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs) | |||
| mnp_res = mnp_fn(*mnp_arrs, **kwargs) | |||
| onp_res = onp_fn(*arrs, **kwargs) | |||
| match_all_arrays(mnp_res, onp_res) | |||
| def match_all_arrays(mnp_res, onp_res, error=0): | |||
| if isinstance(mnp_res, (tuple, list)): | |||
| assert len(mnp_res) == len(onp_res) | |||
| for actual, expected in zip(mnp_res, onp_res): | |||
| match_array(actual.asnumpy(), expected, error) | |||
| else: | |||
| match_array(mnp_res.asnumpy(), onp_res, error) | |||
| def match_meta(actual, expected): | |||
| # float64 and int64 are not supported, and the default type for | |||
| # float and int are float32 and int32, respectively | |||
| if expected.dtype == onp.float64: | |||
| expected = expected.astype(onp.float32) | |||
| elif expected.dtype == onp.int64: | |||
| expected = expected.astype(onp.int32) | |||
| assert actual.shape == expected.shape | |||
| assert actual.dtype == expected.dtype | |||
| # Test np.transpose and np.ndarray.transpose | |||
| def mnp_transpose(input_tensor): | |||
| a = mnp.transpose(input_tensor, (0, 2, 1)) | |||
| @@ -458,6 +386,34 @@ def test_concatenate(): | |||
| check_all_results(o_concatenate, m_concatenate) | |||
| def mnp_append(arr1, arr2): | |||
| a = mnp.append(arr1, arr2) | |||
| b = mnp.append(arr1, arr2, axis=0) | |||
| c = mnp.append(arr1, arr2, axis=-1) | |||
| return a, b, c | |||
| def onp_append(arr1, arr2): | |||
| a = onp.append(arr1, arr2) | |||
| b = onp.append(arr1, arr2, axis=0) | |||
| c = onp.append(arr1, arr2, axis=-1) | |||
| return a, b, c | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_append(): | |||
| onp_array = onp.random.random((4, 3, 2)).astype('float32') | |||
| onp_value = onp.random.random((4, 3, 2)).astype('float32') | |||
| mnp_array = mnp.asarray(onp_array) | |||
| mnp_value = mnp.asarray(onp_value) | |||
| onp_res = onp_append(onp_array, onp_value) | |||
| mnp_res = mnp_append(mnp_array, mnp_value) | |||
| check_all_results(onp_res, mnp_res) | |||
| def construct_arrays(n=1, ndim=1, axis=None, low=1, high=5): | |||
| onp_array_lst = [] | |||
| mnp_array_lst = [] | |||
| @@ -629,7 +585,7 @@ def onp_atleast3d(*arys): | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_atleast1d(): | |||
| run_non_kw_test(mnp_atleast1d, onp_atleast1d) | |||
| run_non_kw_test(mnp_atleast1d, onp_atleast1d, Cases()) | |||
| @pytest.mark.level1 | |||
| @@ -639,7 +595,7 @@ def test_atleast1d(): | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_atleast2d(): | |||
| run_non_kw_test(mnp_atleast2d, onp_atleast2d) | |||
| run_non_kw_test(mnp_atleast2d, onp_atleast2d, Cases()) | |||
| @pytest.mark.level1 | |||
| @@ -649,7 +605,7 @@ def test_atleast2d(): | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_atleast3d(): | |||
| run_non_kw_test(mnp_atleast3d, onp_atleast3d) | |||
| run_non_kw_test(mnp_atleast3d, onp_atleast3d, Cases()) | |||
| # Test np.where | |||
| @@ -858,6 +814,444 @@ def test_stack(): | |||
| match_res(mnp.stack, onp.stack, arrs, axis=i) | |||
| def mnp_roll(input_tensor): | |||
| a = mnp.roll(input_tensor, -3) | |||
| b = mnp.roll(input_tensor, [-2, -3], 1) | |||
| c = mnp.roll(input_tensor, (3, 0, -5), (-1, -2, 0)) | |||
| d = mnp.roll(input_tensor, (4,), [0, 0, 1]) | |||
| return a, b, c, d | |||
| def onp_roll(input_array): | |||
| a = onp.roll(input_array, -3) | |||
| b = onp.roll(input_array, [-2, -3], 1) | |||
| c = onp.roll(input_array, (3, 0, -5), (-1, -2, 0)) | |||
| d = onp.roll(input_array, (4,), [0, 0, 1]) | |||
| return a, b, c, d | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_roll(): | |||
| arr = rand_int(3, 4, 5) | |||
| match_res(mnp_roll, onp_roll, arr) | |||
| arr = rand_int(1, 4, 6).astype("int64") | |||
| match_res(mnp_roll, onp_roll, arr) | |||
| def mnp_moveaxis(a): | |||
| a = mnp.moveaxis(a, 3, 3) | |||
| b = mnp.moveaxis(a, -1, 4) | |||
| c = mnp.moveaxis(a, (2, 1, 4), (0, 3, 2)) | |||
| d = mnp.moveaxis(a, [-2, -5], [2, -4]) | |||
| return a, b, c, d | |||
| def onp_moveaxis(a): | |||
| a = onp.moveaxis(a, 3, 3) | |||
| b = onp.moveaxis(a, -1, 4) | |||
| c = onp.moveaxis(a, (2, 1, 4), (0, 3, 2)) | |||
| d = onp.moveaxis(a, [-2, -5], [2, -4]) | |||
| return a, b, c, d | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_moveaxis(): | |||
| a = rand_int(2, 4, 5, 9, 6) | |||
| match_res(mnp_moveaxis, onp_moveaxis, a) | |||
| a = rand_int(2, 4, 5, 0, 6, 7, 1, 3, 8) | |||
| match_res(mnp_moveaxis, onp_moveaxis, a) | |||
| def mnp_tile(x): | |||
| a = mnp.tile(x, 0) | |||
| b = mnp.tile(x, 1) | |||
| c = mnp.tile(x, 3) | |||
| d = mnp.tile(x, [5, 1]) | |||
| e = mnp.tile(x, (3, 1, 0)) | |||
| f = mnp.tile(x, [5, 1, 2, 3, 7]) | |||
| return a, b, c, d, e, f | |||
| def onp_tile(x): | |||
| a = onp.tile(x, 0) | |||
| b = onp.tile(x, 1) | |||
| c = onp.tile(x, 3) | |||
| d = onp.tile(x, [5, 1]) | |||
| e = onp.tile(x, (3, 1, 0)) | |||
| f = onp.tile(x, [5, 1, 2, 3, 7]) | |||
| return a, b, c, d, e, f | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_tile(): | |||
| a = rand_int(2, 3, 4) | |||
| match_res(mnp_tile, onp_tile, a) | |||
| b = rand_int(5, 0, 8) | |||
| match_res(mnp_tile, onp_tile, b) | |||
| def mnp_broadcast_to(x): | |||
| a = mnp.broadcast_to(x, (2, 3)) | |||
| b = mnp.broadcast_to(x, (8, 1, 3)) | |||
| return a, b | |||
| def onp_broadcast_to(x): | |||
| a = onp.broadcast_to(x, (2, 3)) | |||
| b = onp.broadcast_to(x, (8, 1, 3)) | |||
| return a, b | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_broadcast_to(): | |||
| x = rand_int() | |||
| match_res(mnp_broadcast_to, onp_broadcast_to, x) | |||
| x = rand_int(3) | |||
| match_res(mnp_broadcast_to, onp_broadcast_to, x) | |||
| x = rand_int(1, 3) | |||
| match_res(mnp_broadcast_to, onp_broadcast_to, x) | |||
| def mnp_broadcast_arrays(*args): | |||
| return mnp.broadcast_arrays(*args) | |||
| def onp_broadcast_arrays(*args): | |||
| return onp.broadcast_arrays(*args) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_broadcast_arrays(): | |||
| test_case = Cases() | |||
| broadcastables = test_case.broadcastables | |||
| for i in range(len(broadcastables)): | |||
| arrs = broadcastables[i:] | |||
| match_res(mnp_broadcast_arrays, onp_broadcast_arrays, *arrs) | |||
| def mnp_flip(x): | |||
| a = mnp.flip(x) | |||
| b = mnp.flip(x, 0) | |||
| c = mnp.flip(x, 1) | |||
| d = mnp.flip(x, (-3, -1)) | |||
| return a, b, c, d | |||
| def onp_flip(x): | |||
| a = onp.flip(x) | |||
| b = onp.flip(x, 0) | |||
| c = onp.flip(x, 1) | |||
| d = onp.flip(x, (-3, -1)) | |||
| return a, b, c, d | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_flip(): | |||
| x = rand_int(2, 3, 4) | |||
| run_multi_test(mnp_flip, onp_flip, (x,)) | |||
| def mnp_flipud(x): | |||
| return mnp.flipud(x) | |||
| def onp_flipud(x): | |||
| return onp.flipud(x) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_flipud(): | |||
| x = rand_int(2, 3, 4) | |||
| run_multi_test(mnp_flipud, onp_flipud, (x,)) | |||
| def mnp_fliplr(x): | |||
| return mnp.fliplr(x) | |||
| def onp_fliplr(x): | |||
| return onp.fliplr(x) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_fliplr(): | |||
| x = rand_int(2, 3, 4) | |||
| run_multi_test(mnp_fliplr, onp_fliplr, (x,)) | |||
| def mnp_split(input_tensor): | |||
| a = mnp.split(input_tensor, indices_or_sections=1) | |||
| b = mnp.split(input_tensor, indices_or_sections=3) | |||
| c = mnp.split(input_tensor, indices_or_sections=(-9, -8, 6)) | |||
| d = mnp.split(input_tensor, indices_or_sections=(3, 2, 1)) | |||
| e = mnp.split(input_tensor, indices_or_sections=(-10, -4, 5, 10)) | |||
| f = mnp.split(input_tensor, indices_or_sections=[0, 2], axis=1) | |||
| return a, b, c, d, e, f | |||
| def onp_split(input_array): | |||
| a = onp.split(input_array, indices_or_sections=1) | |||
| b = onp.split(input_array, indices_or_sections=3) | |||
| c = onp.split(input_array, indices_or_sections=(-9, -8, 6)) | |||
| d = onp.split(input_array, indices_or_sections=(3, 2, 1)) | |||
| e = onp.split(input_array, indices_or_sections=(-10, -4, 5, 10)) | |||
| f = onp.split(input_array, indices_or_sections=[0, 2], axis=1) | |||
| return a, b, c, d, e, f | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_split(): | |||
| onp_arrs = [ | |||
| onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32'), | |||
| onp.random.randint(1, 5, size=(9, 4, 5)).astype('float64') | |||
| ] | |||
| mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] | |||
| for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): | |||
| o_split = onp_split(onp_arr) | |||
| m_split = mnp_split(mnp_arr) | |||
| for expect_lst, actual_lst in zip(o_split, m_split): | |||
| for expect, actual in zip(expect_lst, actual_lst): | |||
| match_array(expect, actual.asnumpy()) | |||
| def mnp_vsplit(input_tensor): | |||
| a = mnp.vsplit(input_tensor, indices_or_sections=3) | |||
| b = mnp.vsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10)) | |||
| c = mnp.vsplit(input_tensor, indices_or_sections=[0, 2]) | |||
| return a, b, c | |||
| def onp_vsplit(input_array): | |||
| a = onp.vsplit(input_array, indices_or_sections=3) | |||
| b = onp.vsplit(input_array, indices_or_sections=(-10, -4, 5, 10)) | |||
| c = onp.vsplit(input_array, indices_or_sections=[0, 2]) | |||
| return a, b, c | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_vsplit(): | |||
| onp_arrs = [ | |||
| onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32'), | |||
| onp.random.randint(1, 5, size=(9, 4, 5)).astype('float64') | |||
| ] | |||
| mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] | |||
| for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): | |||
| o_vsplit = onp_vsplit(onp_arr) | |||
| m_vsplit = mnp_vsplit(mnp_arr) | |||
| for expect_lst, actual_lst in zip(o_vsplit, m_vsplit): | |||
| for expect, actual in zip(expect_lst, actual_lst): | |||
| match_array(expect, actual.asnumpy()) | |||
| def mnp_hsplit(input_tensor): | |||
| a = mnp.hsplit(input_tensor, indices_or_sections=3) | |||
| b = mnp.hsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10)) | |||
| c = mnp.hsplit(input_tensor, indices_or_sections=[0, 2]) | |||
| return a, b, c | |||
| def onp_hsplit(input_array): | |||
| a = onp.hsplit(input_array, indices_or_sections=3) | |||
| b = onp.hsplit(input_array, indices_or_sections=(-10, -4, 5, 10)) | |||
| c = onp.hsplit(input_array, indices_or_sections=[0, 2]) | |||
| return a, b, c | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_hsplit(): | |||
| onp_arrs = [ | |||
| onp.random.randint(1, 5, size=(4, 9, 5)).astype('float32'), | |||
| onp.random.randint(1, 5, size=(4, 9, 5)).astype('float64') | |||
| ] | |||
| mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] | |||
| for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): | |||
| o_hsplit = onp_hsplit(onp_arr) | |||
| m_hsplit = mnp_hsplit(mnp_arr) | |||
| for expect_lst, actual_lst in zip(o_hsplit, m_hsplit): | |||
| for expect, actual in zip(expect_lst, actual_lst): | |||
| match_array(expect, actual.asnumpy()) | |||
| def mnp_dsplit(input_tensor): | |||
| a = mnp.dsplit(input_tensor, indices_or_sections=3) | |||
| b = mnp.dsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10)) | |||
| c = mnp.dsplit(input_tensor, indices_or_sections=[0, 2]) | |||
| return a, b, c | |||
| def onp_dsplit(input_array): | |||
| a = onp.dsplit(input_array, indices_or_sections=3) | |||
| b = onp.dsplit(input_array, indices_or_sections=(-10, -4, 5, 10)) | |||
| c = onp.dsplit(input_array, indices_or_sections=[0, 2]) | |||
| return a, b, c | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_dsplit(): | |||
| onp_arrs = [ | |||
| onp.random.randint(1, 5, size=(5, 4, 9)).astype('float32'), | |||
| onp.random.randint(1, 5, size=(5, 4, 9)).astype('float64') | |||
| ] | |||
| mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] | |||
| for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): | |||
| o_dsplit = onp_dsplit(onp_arr) | |||
| m_dsplit = mnp_dsplit(mnp_arr) | |||
| for expect_lst, actual_lst in zip(o_dsplit, m_dsplit): | |||
| for expect, actual in zip(expect_lst, actual_lst): | |||
| match_array(expect, actual.asnumpy()) | |||
| def mnp_take_along_axis(*arrs): | |||
| x = arrs[0] | |||
| a = mnp.take_along_axis(x, arrs[1], axis=None) | |||
| b = mnp.take_along_axis(x, arrs[2], axis=1) | |||
| c = mnp.take_along_axis(x, arrs[3], axis=-1) | |||
| d = mnp.take_along_axis(x, arrs[4], axis=0) | |||
| return a, b, c, d | |||
| def onp_take_along_axis(*arrs): | |||
| x = arrs[0] | |||
| a = onp.take_along_axis(x, arrs[1], axis=None) | |||
| b = onp.take_along_axis(x, arrs[2], axis=1) | |||
| c = onp.take_along_axis(x, arrs[3], axis=-1) | |||
| d = onp.take_along_axis(x, arrs[4], axis=0) | |||
| return a, b, c, d | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_take_along_axis(): | |||
| x = rand_int(6, 7, 8, 9) | |||
| indices1 = rand_int(2).astype(onp.int32) | |||
| indices2 = rand_int(6, 3, 8, 1).astype(onp.int32) | |||
| indices3 = rand_int(6, 1, 8, 5).astype(onp.int32) | |||
| indices4 = rand_int(4, 1, 1, 1).astype(onp.int32) | |||
| run_multi_test(mnp_take_along_axis, onp_take_along_axis, | |||
| (x, indices1, indices2, indices3, indices4)) | |||
| def mnp_take(x, indices): | |||
| a = mnp.take(x, indices) | |||
| b = mnp.take(x, indices, axis=-1) | |||
| c = mnp.take(x, indices, axis=0, mode='wrap') | |||
| d = mnp.take(x, indices, axis=1, mode='clip') | |||
| return a, b, c, d | |||
| def onp_take(x, indices): | |||
| a = onp.take(x, indices) | |||
| b = onp.take(x, indices, axis=-1) | |||
| c = onp.take(x, indices, axis=0, mode='wrap') | |||
| d = onp.take(x, indices, axis=1, mode='clip') | |||
| return a, b, c, d | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_take(): | |||
| x = rand_int(2, 3, 4, 5) | |||
| indices = rand_int(2, 3).astype(onp.int32) | |||
| run_multi_test(mnp_take, onp_take, (x, indices)) | |||
| def mnp_repeat(x): | |||
| a = mnp.repeat(x, 2) | |||
| b = mnp.repeat(x, 3, axis=0) | |||
| c = mnp.repeat(x, (4, 1, 5), axis=1) | |||
| d = mnp.repeat(x, (3, 2, 1, 0, 4), axis=-1) | |||
| e = mnp.repeat(x, 0) | |||
| return a, b, c, d, e | |||
| def onp_repeat(x): | |||
| a = onp.repeat(x, 2) | |||
| b = onp.repeat(x, 3, axis=0) | |||
| c = onp.repeat(x, (4, 1, 5), axis=1) | |||
| d = onp.repeat(x, (3, 2, 1, 0, 4), axis=-1) | |||
| e = onp.repeat(x, 0) | |||
| return a, b, c, d, e | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_repeat(): | |||
| x = rand_int(2, 3, 4, 5) | |||
| run_multi_test(mnp_repeat, onp_repeat, (x,)) | |||
| class ReshapeExpandSqueeze(Cell): | |||
| def __init__(self): | |||
| super(ReshapeExpandSqueeze, self).__init__() | |||
| @@ -0,0 +1,263 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """unit tests for numpy logical operations""" | |||
| import pytest | |||
| import numpy as onp | |||
| import mindspore.numpy as mnp | |||
| from .utils import rand_int, run_binop_test, match_res | |||
| class Cases(): | |||
| def __init__(self): | |||
| self.arrs = [ | |||
| rand_int(2), | |||
| rand_int(2, 3), | |||
| rand_int(2, 3, 4), | |||
| rand_int(2, 3, 4, 5), | |||
| ] | |||
| # scalars expanded across the 0th dimension | |||
| self.scalars = [ | |||
| rand_int(), | |||
| rand_int(1), | |||
| rand_int(1, 1), | |||
| rand_int(1, 1, 1, 1), | |||
| ] | |||
| # arrays of the same size expanded across the 0th dimension | |||
| self.expanded_arrs = [ | |||
| rand_int(2, 3), | |||
| rand_int(1, 2, 3), | |||
| rand_int(1, 1, 2, 3), | |||
| rand_int(1, 1, 1, 2, 3), | |||
| ] | |||
| # arrays which can be broadcast | |||
| self.broadcastables = [ | |||
| rand_int(5), | |||
| rand_int(6, 1), | |||
| rand_int(7, 1, 5), | |||
| rand_int(8, 1, 6, 1) | |||
| ] | |||
| # array which contains infs and nans | |||
| self.infs = onp.array([[1.0, onp.nan], [onp.inf, onp.NINF], [2.3, -4.5], [onp.nan, 0.0]]) | |||
| test_case = Cases() | |||
| def mnp_not_equal(a, b): | |||
| return mnp.not_equal(a, b) | |||
| def onp_not_equal(a, b): | |||
| return onp.not_equal(a, b) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_not_equal(): | |||
| run_binop_test(mnp_not_equal, onp_not_equal, test_case) | |||
| def mnp_less_equal(a, b): | |||
| return mnp.less_equal(a, b) | |||
| def onp_less_equal(a, b): | |||
| return onp.less_equal(a, b) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_less_equal(): | |||
| run_binop_test(mnp_less_equal, onp_less_equal, test_case) | |||
| def mnp_less(a, b): | |||
| return mnp.less(a, b) | |||
| def onp_less(a, b): | |||
| return onp.less(a, b) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_less(): | |||
| run_binop_test(mnp_less, onp_less, test_case) | |||
| def mnp_greater_equal(a, b): | |||
| return mnp.greater_equal(a, b) | |||
| def onp_greater_equal(a, b): | |||
| return onp.greater_equal(a, b) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_greater_equal(): | |||
| run_binop_test(mnp_greater_equal, onp_greater_equal, test_case) | |||
| def mnp_greater(a, b): | |||
| return mnp.greater(a, b) | |||
| def onp_greater(a, b): | |||
| return onp.greater(a, b) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_greater(): | |||
| run_binop_test(mnp_greater, onp_greater, test_case) | |||
| def mnp_equal(a, b): | |||
| return mnp.equal(a, b) | |||
| def onp_equal(a, b): | |||
| return onp.equal(a, b) | |||
| def test_equal(): | |||
| run_binop_test(mnp_equal, onp_equal, test_case) | |||
| def mnp_isfinite(x): | |||
| return mnp.isfinite(x) | |||
| def onp_isfinite(x): | |||
| return onp.isfinite(x) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_isfinite(): | |||
| match_res(mnp_isfinite, onp_isfinite, test_case.infs) | |||
| def mnp_isnan(x): | |||
| return mnp.isnan(x) | |||
| def onp_isnan(x): | |||
| return onp.isnan(x) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_isnan(): | |||
| match_res(mnp_isnan, onp_isnan, test_case.infs) | |||
| def mnp_isinf(x): | |||
| return mnp.isinf(x) | |||
| def onp_isinf(x): | |||
| return onp.isinf(x) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_isinf(): | |||
| match_res(mnp_isinf, onp_isinf, test_case.infs) | |||
| def mnp_isposinf(x): | |||
| return mnp.isposinf(x) | |||
| def onp_isposinf(x): | |||
| return onp.isposinf(x) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_isposinf(): | |||
| match_res(mnp_isposinf, onp_isposinf, test_case.infs) | |||
| def mnp_isneginf(x): | |||
| return mnp.isneginf(x) | |||
| def onp_isneginf(x): | |||
| return onp.isneginf(x) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_isneginf(): | |||
| match_res(mnp_isneginf, onp_isneginf, test_case.infs) | |||
| def test_isscalar(): | |||
| assert mnp.isscalar(1) == onp.isscalar(1) | |||
| assert mnp.isscalar(2.3) == onp.isscalar(2.3) | |||
| assert mnp.isscalar([4.5]) == onp.isscalar([4.5]) | |||
| assert mnp.isscalar(False) == onp.isscalar(False) | |||
| assert mnp.isscalar(mnp.array(True)) == onp.isscalar(onp.array(True)) | |||
| assert mnp.isscalar('numpy') == onp.isscalar('numpy') | |||
| @@ -0,0 +1,165 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """utility functions for mindspore.numpy st tests""" | |||
| import functools | |||
| import numpy as onp | |||
| import mindspore.numpy as mnp | |||
| def match_array(actual, expected, error=0): | |||
| if isinstance(actual, int): | |||
| actual = onp.asarray(actual) | |||
| if isinstance(expected, int): | |||
| expected = onp.asarray(expected) | |||
| if error > 0: | |||
| onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(), | |||
| decimal=error) | |||
| else: | |||
| onp.testing.assert_equal(actual.tolist(), expected.tolist()) | |||
| def check_all_results(onp_results, mnp_results, error=0): | |||
| """Check all results from numpy and mindspore.numpy""" | |||
| for i, _ in enumerate(onp_results): | |||
| match_array(onp_results[i], mnp_results[i].asnumpy()) | |||
| def check_all_unique_results(onp_results, mnp_results): | |||
| """ | |||
| Check all results from numpy and mindspore.numpy. | |||
| Args: | |||
| onp_results (Union[tuple of numpy.arrays, numpy.array]) | |||
| mnp_results (Union[tuple of Tensors, Tensor]) | |||
| """ | |||
| for i, _ in enumerate(onp_results): | |||
| if isinstance(onp_results[i], tuple): | |||
| for j in range(len(onp_results[i])): | |||
| match_array(onp_results[i][j], | |||
| mnp_results[i][j].asnumpy(), error=7) | |||
| else: | |||
| match_array(onp_results[i], mnp_results[i].asnumpy(), error=7) | |||
| def run_non_kw_test(mnp_fn, onp_fn, test_case): | |||
| """Run tests on functions with non keyword arguments""" | |||
| for i in range(len(test_case.arrs)): | |||
| arrs = test_case.arrs[:i] | |||
| match_res(mnp_fn, onp_fn, *arrs) | |||
| for i in range(len(test_case.scalars)): | |||
| arrs = test_case.scalars[:i] | |||
| match_res(mnp_fn, onp_fn, *arrs) | |||
| for i in range(len(test_case.expanded_arrs)): | |||
| arrs = test_case.expanded_arrs[:i] | |||
| match_res(mnp_fn, onp_fn, *arrs) | |||
| for i in range(len(test_case.nested_arrs)): | |||
| arrs = test_case.nested_arrs[:i] | |||
| match_res(mnp_fn, onp_fn, *arrs) | |||
| def rand_int(*shape): | |||
| """return an random integer array with parameter shape""" | |||
| res = onp.random.randint(low=1, high=5, size=shape) | |||
| if isinstance(res, onp.ndarray): | |||
| return res.astype(onp.float32) | |||
| return float(res) | |||
| # return an random boolean array | |||
| def rand_bool(*shape): | |||
| return onp.random.rand(*shape) > 0.5 | |||
| def match_res(mnp_fn, onp_fn, *arrs, **kwargs): | |||
| """Checks results from applying mnp_fn and onp_fn on arrs respectively""" | |||
| mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs) | |||
| error = kwargs.get('error', 0) | |||
| kwargs.pop('error', None) | |||
| mnp_res = mnp_fn(*mnp_arrs, **kwargs) | |||
| onp_res = onp_fn(*arrs, **kwargs) | |||
| match_all_arrays(mnp_res, onp_res, error=error) | |||
| def match_all_arrays(mnp_res, onp_res, error=0): | |||
| if isinstance(mnp_res, (tuple, list)): | |||
| assert len(mnp_res) == len(onp_res) | |||
| for actual, expected in zip(mnp_res, onp_res): | |||
| match_array(actual.asnumpy(), expected, error) | |||
| else: | |||
| match_array(mnp_res.asnumpy(), onp_res, error) | |||
| def match_meta(actual, expected): | |||
| # float64 and int64 are not supported, and the default type for | |||
| # float and int are float32 and int32, respectively | |||
| if expected.dtype == onp.float64: | |||
| expected = expected.astype(onp.float32) | |||
| elif expected.dtype == onp.int64: | |||
| expected = expected.astype(onp.int32) | |||
| assert actual.shape == expected.shape | |||
| assert actual.dtype == expected.dtype | |||
| def run_binop_test(mnp_fn, onp_fn, test_case): | |||
| for arr in test_case.arrs: | |||
| match_res(mnp_fn, onp_fn, arr, arr) | |||
| for scalar in test_case.scalars: | |||
| match_res(mnp_fn, onp_fn, arr, scalar) | |||
| match_res(mnp_fn, onp_fn, scalar, arr) | |||
| for scalar1 in test_case.scalars: | |||
| for scalar2 in test_case.scalars: | |||
| match_res(mnp_fn, onp_fn, scalar1, scalar2) | |||
| for expanded_arr1 in test_case.expanded_arrs: | |||
| for expanded_arr2 in test_case.expanded_arrs: | |||
| match_res(mnp_fn, onp_fn, expanded_arr1, expanded_arr2) | |||
| for broadcastable1 in test_case.broadcastables: | |||
| for broadcastable2 in test_case.broadcastables: | |||
| match_res(mnp_fn, onp_fn, broadcastable1, broadcastable2) | |||
| def run_unary_test(mnp_fn, onp_fn, test_case, error=0): | |||
| for arr in test_case.arrs: | |||
| match_res(mnp_fn, onp_fn, arr, error=error) | |||
| for arr in test_case.scalars: | |||
| match_res(mnp_fn, onp_fn, arr, error=error) | |||
| for arr in test_case.expanded_arrs: | |||
| match_res(mnp_fn, onp_fn, arr, error=error) | |||
| def run_multi_test(mnp_fn, onp_fn, arrs, error=0): | |||
| mnp_arrs = map(mnp.asarray, arrs) | |||
| for actual, expected in zip(mnp_fn(*mnp_arrs), onp_fn(*arrs)): | |||
| match_array(actual.asnumpy(), expected, error) | |||
| def run_single_test(mnp_fn, onp_fn, arr, error=0): | |||
| mnp_arr = mnp.asarray(arr) | |||
| for actual, expected in zip(mnp_fn(mnp_arr), onp_fn(arr)): | |||
| if isinstance(expected, tuple): | |||
| for actual_arr, expected_arr in zip(actual, expected): | |||
| match_array(actual_arr.asnumpy(), expected_arr, error) | |||
| match_array(actual.asnumpy(), expected, error) | |||