From f34756c3bd1c1f1c323f3887268da5ad0299e38c Mon Sep 17 00:00:00 2001
From: huangmengxi
Date: Tue, 2 Mar 2021 19:26:40 +0800
Subject: [PATCH] additional typecheck

---
 mindspore/numpy/array_ops.py   |  7 ++++++-
 mindspore/numpy/math_ops.py    |  4 +++-
 mindspore/numpy/utils_const.py | 15 +++++++--------
 3 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/mindspore/numpy/array_ops.py b/mindspore/numpy/array_ops.py
index ec4d14bb54..17b5d0e205 100644
--- a/mindspore/numpy/array_ops.py
+++ b/mindspore/numpy/array_ops.py
@@ -1200,6 +1200,8 @@ def moveaxis(a, source, destination):
     ndim = F.rank(a)
     source = _check_axis_valid(source, ndim)
     destination = _check_axis_valid(destination, ndim)
+    if len(source) != len(destination):
+        _raise_value_error('`source` and `destination` arguments must have the same number of elements')
     perm = _get_moved_perm(ndim, source, destination)
 
     shape = F.shape(a)
@@ -1305,7 +1307,7 @@ def broadcast_to(array, shape):
     """
     shape_a = F.shape(array)
     if not _check_can_broadcast_to(shape_a, shape):
-        return _raise_value_error('cannot broadcaast with {shape_a} {shape}')
+        return _raise_value_error('cannot broadcast with ', shape)
     return _broadcast_to_shape(array, shape)
 
 
@@ -1386,6 +1388,7 @@ def split(x, indices_or_sections, axis=0):
         Tensor(shape=[3], dtype=Float32, value=
         [ 6.00000000e+00,  7.00000000e+00,  8.00000000e+00]))
     """
+    _check_input_tensor(x)
     _ = _check_axis_type(axis, True, False, False)
     axis = _canonicalize_axis(axis, x.ndim)
     res = None
@@ -1827,6 +1830,8 @@ def take(a, indices, axis=None, mode='raise'):
     [5 7]]
     """
     _check_input_tensor(a, indices)
+    if mode not in ('raise', 'wrap', 'clip'):
+        _raise_value_error('mode should be one of "raise", "wrap", or "clip"')
     if axis is None:
         a = ravel(a)
         axis = 0
diff --git a/mindspore/numpy/math_ops.py b/mindspore/numpy/math_ops.py
index 4bf0d331e9..46f54389fd 100644
--- a/mindspore/numpy/math_ops.py
+++ b/mindspore/numpy/math_ops.py
@@ -1149,11 +1149,13 @@ def ptp(x, axis=None, out=None, keepdims=False):
     [2. 0. 5. 2.]
     """
     _check_input_tensor(x)
+    if not isinstance(keepdims, bool):
+        _raise_type_error('keepdims should be boolean')
     if axis is None:
         axis = ()
     else:
         _check_axis_type(axis, True, True, False)
-        axis = _canonicalize_axis(axis, x.ndim)
+        axis = _check_axis_valid(axis, x.ndim)
 
     if keepdims:
         x_min = _reduce_min_keepdims(x, axis)
diff --git a/mindspore/numpy/utils_const.py b/mindspore/numpy/utils_const.py
index 35e96ebcc6..872aa6399b 100644
--- a/mindspore/numpy/utils_const.py
+++ b/mindspore/numpy/utils_const.py
@@ -165,19 +165,18 @@ def _check_axis_valid(axes, ndim):
     Checks axes are valid given ndim, and returns axes that can be
     passed to the built-in operator (non-negative, int or tuple)
     """
-    if isinstance(axes, int):
-        _check_axis_in_range(axes, ndim)
-        return (axes % ndim,)
+    if axes is None:
+        axes = F.make_range(ndim)
+        return axes
     if isinstance(axes, (tuple, list)):
         for axis in axes:
             _check_axis_in_range(axis, ndim)
         axes = tuple(map(lambda x: x % ndim, axes))
-        if all(axes.count(el) <= 1 for el in axes):
-            return axes
-    if axes is None:
-        axes = F.make_range(ndim)
+        if any(axes.count(el) > 1 for el in axes):
+            raise ValueError('duplicate value in "axis"')
         return axes
-    raise ValueError('duplicate value in "axis"')
+    _check_axis_in_range(axes, ndim)
+    return (axes % ndim,)
 
 
 @constexpr
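
Note (not part of the patch): the `_check_axis_valid` rewrite above is the core of the change. The `None`, tuple/list, and scalar-axis cases are now handled in separate, non-overlapping branches, so the duplicate check runs only on the tuple path and `None` no longer falls through the scalar handling. Below is a minimal plain-Python sketch of the reordered logic, for illustration only: `range(ndim)` stands in for `F.make_range(ndim)` and an inline bounds check stands in for `_check_axis_in_range`, since the MindSpore helpers are not available outside the package.

    def check_axis_valid(axes, ndim):
        # None means "all axes", normalized to a tuple of non-negative ints.
        if axes is None:
            return tuple(range(ndim))
        if isinstance(axes, (tuple, list)):
            for axis in axes:
                if not -ndim <= axis < ndim:   # stands in for _check_axis_in_range
                    raise ValueError(f'axis {axis} out of range for ndim {ndim}')
            axes = tuple(axis % ndim for axis in axes)
            if any(axes.count(el) > 1 for el in axes):
                raise ValueError('duplicate value in "axis"')
            return axes
        # Scalar axis handled last, as in the patched ordering.
        if not -ndim <= axes < ndim:
            raise ValueError(f'axis {axes} out of range for ndim {ndim}')
        return (axes % ndim,)

    assert check_axis_valid(None, 3) == (0, 1, 2)
    assert check_axis_valid(-1, 3) == (2,)
    assert check_axis_valid((0, -1), 3) == (0, 2)
    # check_axis_valid((0, 0), 3) raises ValueError: duplicate value in "axis"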