| @@ -1596,7 +1596,7 @@ def ix_(*args): | |||||
| Boolean masks are not supported. | Boolean masks are not supported. | ||||
| Args: | Args: | ||||
| *args (Tensor): 1-D, each sequence should be of integer type. | |||||
| *args (Tensor): 1-D sequences. | |||||
| Returns: | Returns: | ||||
| Tuple of Tensor, `N` arrays with `N` dimensions each, with `N` the | Tuple of Tensor, `N` arrays with `N` dimensions each, with `N` the | ||||
| @@ -1898,15 +1898,17 @@ def repeat(a, repeats, axis=None): | |||||
| [3 4]] | [3 4]] | ||||
| """ | """ | ||||
| _check_input_tensor(a) | _check_input_tensor(a) | ||||
| if not isinstance(repeats, (tuple, list)): | |||||
| repeats = (repeats,) | |||||
| _check_element_int(repeats) | |||||
| if axis is None: | if axis is None: | ||||
| a = ravel(a) | a = ravel(a) | ||||
| axis = 0 | axis = 0 | ||||
| ndim = F.rank(a) | ndim = F.rank(a) | ||||
| _check_axis_in_range(axis, ndim) | _check_axis_in_range(axis, ndim) | ||||
| axis = axis + ndim if axis < 0 else axis | axis = axis + ndim if axis < 0 else axis | ||||
| if isinstance(repeats, (tuple, list)) and len(repeats) == 1: | |||||
| if len(repeats) == 1: | |||||
| repeats = repeats[0] | repeats = repeats[0] | ||||
| if isinstance(repeats, int): | |||||
| if repeats == 0: | if repeats == 0: | ||||
| return _empty(F.dtype(a), (0,)) | return _empty(F.dtype(a), (0,)) | ||||
| return C.repeat_elements(a, repeats, axis) | return C.repeat_elements(a, repeats, axis) | ||||
| @@ -17,7 +17,9 @@ | |||||
| from .math_ops import _apply_tensor_op | from .math_ops import _apply_tensor_op | ||||
| from ..ops import functional as F | from ..ops import functional as F | ||||
| from ..ops.primitive import constexpr | |||||
| from ..common import dtype as mstype | from ..common import dtype as mstype | ||||
| from ..common import Tensor | |||||
| from .._c_expression import typing | from .._c_expression import typing | ||||
| from .array_creations import zeros, ones | from .array_creations import zeros, ones | ||||
| @@ -530,6 +532,13 @@ def isneginf(x): | |||||
| return _is_sign_inf(x, F.tensor_lt) | return _is_sign_inf(x, F.tensor_lt) | ||||
| @constexpr | |||||
| def _isscalar(x): | |||||
| """Returns True if x is a scalar type""" | |||||
| return isinstance(x, (typing.Number, typing.Int, typing.UInt, typing.Float, | |||||
| typing.Bool, typing.String)) | |||||
| def isscalar(element): | def isscalar(element): | ||||
| """ | """ | ||||
| Returns True if the type of element is a scalar type. | Returns True if the type of element is a scalar type. | ||||
| @@ -565,5 +574,5 @@ def isscalar(element): | |||||
| >>> print(output) | >>> print(output) | ||||
| True | True | ||||
| """ | """ | ||||
| return isinstance(F.typeof(element), (typing.Number, typing.Int, typing.UInt, | |||||
| typing.Float, typing.Bool, typing.String)) | |||||
| obj_type = F.typeof(element) | |||||
| return not isinstance(obj_type, Tensor) and _isscalar(obj_type) | |||||
| @@ -2237,6 +2237,10 @@ def _reduce(a, reduce_fn, cmp_fn, axis=None, keepdims=False, initial=None, where | |||||
| ndim = F.rank(a) | ndim = F.rank(a) | ||||
| dtype = F.dtype(a) | dtype = F.dtype(a) | ||||
| axes = _check_axis_valid(axis, ndim) | axes = _check_axis_valid(axis, ndim) | ||||
| if initial is not None: | |||||
| if ((isinstance(initial, Tensor) and F.rank(initial) > 0) or | |||||
| not isinstance(initial, (int, float, bool, Tensor))): | |||||
| _raise_type_error('initial should be scalar') | |||||
| if _is_shape_empty(shape): | if _is_shape_empty(shape): | ||||
| if not axes: | if not axes: | ||||