From: @jachua Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -35,7 +35,7 @@ from .array_creations import copy_ as copy | |||||
| from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange, | from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange, | ||||
| linspace, logspace, eye, identity, empty, empty_like, | linspace, logspace, eye, identity, empty, empty_like, | ||||
| ones_like, zeros_like, full_like, diagonal, tril, triu, | ones_like, zeros_like, full_like, diagonal, tril, triu, | ||||
| tri, trace, cumsum, meshgrid, mgrid, ogrid, diagflat, | |||||
| tri, trace, meshgrid, mgrid, ogrid, diagflat, | |||||
| diag, diag_indices, ix_) | diag, diag_indices, ix_) | ||||
| from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16, | from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16, | ||||
| uint32, uint64, float_, float16, float32, float64, bool_, inf, nan, | uint32, uint64, float_, float16, float32, float64, bool_, inf, nan, | ||||
| @@ -45,7 +45,7 @@ from .math_ops import (mean, inner, add, subtract, multiply, divide, true_divide | |||||
| matmul, square, sqrt, reciprocal, log, maximum, heaviside, amax, amin, | matmul, square, sqrt, reciprocal, log, maximum, heaviside, amax, amin, | ||||
| hypot, float_power, floor, ptp, deg2rad, rad2deg, count_nonzero, | hypot, float_power, floor, ptp, deg2rad, rad2deg, count_nonzero, | ||||
| positive, negative, clip, floor_divide, remainder, fix, fmod, trunc, | positive, negative, clip, floor_divide, remainder, fix, fmod, trunc, | ||||
| exp, expm1) | |||||
| exp, expm1, cumsum) | |||||
| from .logic_ops import (not_equal, less_equal, less, greater_equal, greater, equal, isfinite, | from .logic_ops import (not_equal, less_equal, less, greater_equal, greater, equal, isfinite, | ||||
| isnan, isinf, isposinf, isneginf, isscalar) | isnan, isinf, isposinf, isneginf, isscalar) | ||||
| @@ -70,7 +70,7 @@ math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'true_d | |||||
| 'minimum', 'matmul', 'square', 'sqrt', 'reciprocal', 'log', 'maximum', | 'minimum', 'matmul', 'square', 'sqrt', 'reciprocal', 'log', 'maximum', | ||||
| 'heaviside', 'amax', 'amin', 'hypot', 'float_power', 'floor', 'ptp', 'deg2rad', | 'heaviside', 'amax', 'amin', 'hypot', 'float_power', 'floor', 'ptp', 'deg2rad', | ||||
| 'rad2deg', 'count_nonzero', 'positive', 'negative', 'clip', 'floor_divide', | 'rad2deg', 'count_nonzero', 'positive', 'negative', 'clip', 'floor_divide', | ||||
| 'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs'] | |||||
| 'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs', 'cumsum'] | |||||
| logic_module = ['not_equal', 'less_equal', 'less', 'greater_equal', 'greater', 'equal', 'isfinite', | logic_module = ['not_equal', 'less_equal', 'less', 'greater_equal', 'greater', 'equal', 'isfinite', | ||||
| 'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar'] | 'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar'] | ||||
| @@ -20,7 +20,6 @@ import numpy as onp | |||||
| from ..common import Tensor | from ..common import Tensor | ||||
| from ..common import dtype as mstype | from ..common import dtype as mstype | ||||
| from ..ops import functional as F | from ..ops import functional as F | ||||
| from ..ops import operations as P | |||||
| from ..ops.primitive import constexpr | from ..ops.primitive import constexpr | ||||
| from ..nn.layer.basic import tril as nn_tril | from ..nn.layer.basic import tril as nn_tril | ||||
| from ..nn.layer.basic import triu as nn_triu | from ..nn.layer.basic import triu as nn_triu | ||||
| @@ -31,7 +30,7 @@ from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray | |||||
| _expand, _broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar | _expand, _broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar | ||||
| from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \ | from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \ | ||||
| _check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \ | _check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \ | ||||
| _raise_type_error, _expanded_shape, _check_axis_in_range, _check_is_float, _iota, \ | |||||
| _raise_type_error, _expanded_shape, _check_is_float, _iota, \ | |||||
| _type_convert, _canonicalize_axis, _list_comprehensions, _ceil | _type_convert, _canonicalize_axis, _list_comprehensions, _ceil | ||||
| from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape | from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape | ||||
| from .dtypes import nan | from .dtypes import nan | ||||
| @@ -41,7 +40,6 @@ from .dtypes import nan | |||||
| MAX_NUMPY_DIMS = 32 | MAX_NUMPY_DIMS = 32 | ||||
| # All types that can be accepted as "array_like" parameters in graph mode. | # All types that can be accepted as "array_like" parameters in graph mode. | ||||
| ARRAY_TYPES = (int, float, bool, list, tuple, Tensor) | ARRAY_TYPES = (int, float, bool, list, tuple, Tensor) | ||||
| _cumsum_default = P.CumSum() | |||||
| def array(obj, dtype=None, copy=True, ndmin=0): | def array(obj, dtype=None, copy=True, ndmin=0): | ||||
| @@ -1172,53 +1170,6 @@ def trace(a, offset=0, axis1=0, axis2=1, dtype=None): | |||||
| return res | return res | ||||
| def cumsum(a, axis=None, dtype=None): | |||||
| """ | |||||
| Returns the cumulative sum of the elements along a given axis. | |||||
| Args: | |||||
| a (Tensor): Input tensor. | |||||
| axis (int, optional): Axis along which the cumulative sum is computed. The | |||||
| default (None) is to compute the cumsum over the flattened array. | |||||
| dtype (:class:`mindspore.dtype`, optional): If not specified, stay the same as `a`, | |||||
| unless `a` has an integer dtype with a precision less than that of the | |||||
| default platform integer. In that case, the default platform integer | |||||
| is used. | |||||
| Returns: | |||||
| Tensor. | |||||
| Raises: | |||||
| TypeError: If input arguments have types not specified above. | |||||
| ValueError: If axis is out of range. | |||||
| Supported Platforms: | |||||
| ``Ascend`` ``GPU`` ``CPU`` | |||||
| Examples: | |||||
| >>> output = np.cumsum(np.ones((3,3)), axis=0) | |||||
| >>> print(output) | |||||
| [[1. 1. 1.] | |||||
| [2. 2. 2.] | |||||
| [3. 3. 3.]] | |||||
| """ | |||||
| _check_input_tensor(a) | |||||
| original_dtype = F.dtype(a) | |||||
| # If original array is int, and has precision less then int32, convert to int32 | |||||
| if _check_same_type(original_dtype, mstype.bool_) or \ | |||||
| _check_same_type(original_dtype, mstype.int8) or \ | |||||
| _check_same_type(original_dtype, mstype.int16): | |||||
| original_dtype = mstype.int32 | |||||
| a = a.astype(mstype.float32) | |||||
| if axis is None: | |||||
| a = a.ravel() | |||||
| axis = 0 | |||||
| _check_axis_in_range(axis, a.ndim) | |||||
| if dtype is not None and not _check_same_type(original_dtype, dtype): | |||||
| return _cumsum_default(a, axis).astype(dtype, copy=False) | |||||
| return _cumsum_default(a, axis).astype(original_dtype, copy=False) | |||||
| def _index(i, size, Cartesian=True): | def _index(i, size, Cartesian=True): | ||||
| """If Cartesian=True, index 0 is swapped with index 1.""" | """If Cartesian=True, index 0 is swapped with index 1.""" | ||||
| if Cartesian: | if Cartesian: | ||||
| @@ -1905,8 +1905,11 @@ def repeat(a, repeats, axis=None): | |||||
| if repeats == 0: | if repeats == 0: | ||||
| return _empty(F.dtype(a), (0,)) | return _empty(F.dtype(a), (0,)) | ||||
| return C.repeat_elements(a, repeats, axis) | return C.repeat_elements(a, repeats, axis) | ||||
| shape = F.shape(a) | shape = F.shape(a) | ||||
| size = shape[axis] | size = shape[axis] | ||||
| if len(repeats) != size: | |||||
| _raise_value_error('operands could not be broadcast together') | |||||
| subs = split(a, size, axis) | subs = split(a, size, axis) | ||||
| repeated_subs = [] | repeated_subs = [] | ||||
| for sub, rep in zip(subs, repeats): | for sub, rep in zip(subs, repeats): | ||||
| @@ -361,7 +361,7 @@ def isnan(x, out=None, where=True, dtype=None): | |||||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | When `where` is provided, `out` must have a tensor value. `out` is not supported | ||||
| for storing the result, however it can be used in combination with `where` to set | for storing the result, however it can be used in combination with `where` to set | ||||
| the value at indices for which `where` is set to False. | the value at indices for which `where` is set to False. | ||||
| On GPU, the supported dtypes are np.float16, and np.float32. | |||||
| Only np.float32 is currently supported. | |||||
| Args: | Args: | ||||
| x (Tensor): Input values. | x (Tensor): Input values. | ||||
| @@ -422,7 +422,7 @@ def isinf(x, out=None, where=True, dtype=None): | |||||
| When `where` is provided, `out` must have a tensor value. `out` is not supported | When `where` is provided, `out` must have a tensor value. `out` is not supported | ||||
| for storing the result, however it can be used in combination with `where` to set | for storing the result, however it can be used in combination with `where` to set | ||||
| the value at indices for which `where` is set to False. | the value at indices for which `where` is set to False. | ||||
| On GPU, the supported dtypes are np.float16, and np.float32. | |||||
| Only np.float32 is currently supported. | |||||
| Args: | Args: | ||||
| x (Tensor): Input values. | x (Tensor): Input values. | ||||
| @@ -477,7 +477,7 @@ def isposinf(x): | |||||
| Note: | Note: | ||||
| Numpy argument `out` is not supported. | Numpy argument `out` is not supported. | ||||
| On GPU, the supported dtypes are np.float16, and np.float32. | |||||
| Only np.float32 is currently supported. | |||||
| Args: | Args: | ||||
| x (Tensor): Input values. | x (Tensor): Input values. | ||||
| @@ -507,7 +507,7 @@ def isneginf(x): | |||||
| Note: | Note: | ||||
| Numpy argument `out` is not supported. | Numpy argument `out` is not supported. | ||||
| On GPU, the supported dtypes are np.float16, and np.float32. | |||||
| Only np.float32 is currently supported. | |||||
| Args: | Args: | ||||
| x (Tensor): Input values. | x (Tensor): Input values. | ||||
| @@ -32,7 +32,7 @@ from .array_ops import ravel, expand_dims | |||||
| from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \ | from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \ | ||||
| _check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \ | _check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \ | ||||
| _raise_value_error, _check_matmul_shapes, _promote, _check_axis_type, _canonicalize_axis, \ | _raise_value_error, _check_matmul_shapes, _promote, _check_axis_type, _canonicalize_axis, \ | ||||
| _max, _is_shape_empty, _check_is_int, _expanded_shape | |||||
| _max, _is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range | |||||
| from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \ | from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \ | ||||
| _check_input_tensor | _check_input_tensor | ||||
| @@ -50,6 +50,7 @@ _reduce_min_default = P.ReduceMin() | |||||
| _reduce_min_keepdims = P.ReduceMin(True) | _reduce_min_keepdims = P.ReduceMin(True) | ||||
| _reduce_max_default = P.ReduceMax() | _reduce_max_default = P.ReduceMax() | ||||
| _reduce_max_keepdims = P.ReduceMax(True) | _reduce_max_keepdims = P.ReduceMax(True) | ||||
| _cumsum_default = P.CumSum() | |||||
| def absolute(x, out=None, where=True, dtype=None): | def absolute(x, out=None, where=True, dtype=None): | ||||
| """ | """ | ||||
| @@ -2385,6 +2386,53 @@ def negative(a, out=None, where=True, dtype=None): | |||||
| return _apply_tensor_op(F.neg_tensor, a, out=out, where=where, dtype=dtype) | return _apply_tensor_op(F.neg_tensor, a, out=out, where=where, dtype=dtype) | ||||
| def cumsum(a, axis=None, dtype=None): | |||||
| """ | |||||
| Returns the cumulative sum of the elements along a given axis. | |||||
| Args: | |||||
| a (Tensor): Input tensor. | |||||
| axis (int, optional): Axis along which the cumulative sum is computed. The | |||||
| default (None) is to compute the cumsum over the flattened array. | |||||
| dtype (:class:`mindspore.dtype`, optional): If not specified, stay the same as `a`, | |||||
| unless `a` has an integer dtype with a precision less than that of the | |||||
| default platform integer. In that case, the default platform integer | |||||
| is used. | |||||
| Returns: | |||||
| Tensor. | |||||
| Raises: | |||||
| TypeError: If input arguments have types not specified above. | |||||
| ValueError: If axis is out of range. | |||||
| Supported Platforms: | |||||
| ``Ascend`` ``GPU`` ``CPU`` | |||||
| Examples: | |||||
| >>> output = np.cumsum(np.ones((3,3)), axis=0) | |||||
| >>> print(output) | |||||
| [[1. 1. 1.] | |||||
| [2. 2. 2.] | |||||
| [3. 3. 3.]] | |||||
| """ | |||||
| _check_input_tensor(a) | |||||
| original_dtype = F.dtype(a) | |||||
| # If original array is int, and has precision less then int32, convert to int32 | |||||
| if _check_same_type(original_dtype, mstype.bool_) or \ | |||||
| _check_same_type(original_dtype, mstype.int8) or \ | |||||
| _check_same_type(original_dtype, mstype.int16): | |||||
| original_dtype = mstype.int32 | |||||
| a = a.astype(mstype.float32) | |||||
| if axis is None: | |||||
| a = a.ravel() | |||||
| axis = 0 | |||||
| _check_axis_in_range(axis, a.ndim) | |||||
| if dtype is not None and not _check_same_type(original_dtype, dtype): | |||||
| return _cumsum_default(a, axis).astype(dtype, copy=False) | |||||
| return _cumsum_default(a, axis).astype(original_dtype, copy=False) | |||||
| def _apply_tensor_op(fn, *args, out=None, where=True, dtype=None): | def _apply_tensor_op(fn, *args, out=None, where=True, dtype=None): | ||||
| """Applies tensor operations based on fn""" | """Applies tensor operations based on fn""" | ||||
| _check_input_tensor(*args) | _check_input_tensor(*args) | ||||
| @@ -549,26 +549,6 @@ def test_tri_triu_tril(): | |||||
| match_array(mnp.tri(64, 64, -10).asnumpy(), onp.tri(64, 64, -10)) | match_array(mnp.tri(64, 64, -10).asnumpy(), onp.tri(64, 64, -10)) | ||||
| @pytest.mark.level1 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_cumsum(): | |||||
| x = mnp.ones((16, 16), dtype="bool") | |||||
| match_array(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy())) | |||||
| match_array(mnp.cumsum(x, axis=0).asnumpy(), | |||||
| onp.cumsum(x.asnumpy(), axis=0)) | |||||
| match_meta(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy())) | |||||
| x = rand_int(3, 4, 5) | |||||
| match_array(mnp.cumsum(mnp.asarray(x), dtype="bool").asnumpy(), | |||||
| onp.cumsum(x, dtype="bool")) | |||||
| match_array(mnp.cumsum(mnp.asarray(x), axis=-1).asnumpy(), | |||||
| onp.cumsum(x, axis=-1)) | |||||
| def mnp_diagonal(arr): | def mnp_diagonal(arr): | ||||
| return mnp.diagonal(arr, offset=2, axis1=-1, axis2=0) | return mnp.diagonal(arr, offset=2, axis1=-1, axis2=0) | ||||
| @@ -20,7 +20,7 @@ import numpy as onp | |||||
| import mindspore.numpy as mnp | import mindspore.numpy as mnp | ||||
| from .utils import rand_int, rand_bool, run_binop_test, run_unary_test, run_multi_test, \ | from .utils import rand_int, rand_bool, run_binop_test, run_unary_test, run_multi_test, \ | ||||
| run_single_test, match_res, match_array | |||||
| run_single_test, match_res, match_array, match_meta | |||||
| class Cases(): | class Cases(): | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -1138,6 +1138,26 @@ def test_negative(): | |||||
| match_array(mnp_neg.asnumpy(), onp_neg, 1e-5) | match_array(mnp_neg.asnumpy(), onp_neg, 1e-5) | ||||
| @pytest.mark.level1 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_cumsum(): | |||||
| x = mnp.ones((16, 16), dtype="bool") | |||||
| match_array(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy())) | |||||
| match_array(mnp.cumsum(x, axis=0).asnumpy(), | |||||
| onp.cumsum(x.asnumpy(), axis=0)) | |||||
| match_meta(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy())) | |||||
| x = rand_int(3, 4, 5) | |||||
| match_array(mnp.cumsum(mnp.asarray(x), dtype="bool").asnumpy(), | |||||
| onp.cumsum(x, dtype="bool")) | |||||
| match_array(mnp.cumsum(mnp.asarray(x), axis=-1).asnumpy(), | |||||
| onp.cumsum(x, axis=-1)) | |||||
| @pytest.mark.level1 | @pytest.mark.level1 | ||||
| @pytest.mark.platform_arm_ascend_training | @pytest.mark.platform_arm_ascend_training | ||||
| @pytest.mark.platform_x86_ascend_training | @pytest.mark.platform_x86_ascend_training | ||||