Browse Source

!10097 Fix bugs for np.mean, np.asarray, np.concatenate, np.linspace

From: @yanglf1121
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
893d59afd7
4 changed files with 118 additions and 23 deletions
  1. +34
    -6
      mindspore/numpy/array_ops.py
  2. +15
    -1
      mindspore/numpy/math_ops.py
  3. +39
    -0
      mindspore/numpy/utils.py
  4. +30
    -16
      tests/ut/python/numpy_native/test_array_ops.py

+ 34
- 6
mindspore/numpy/array_ops.py View File

@@ -25,11 +25,14 @@ from ..ops.primitive import constexpr


from .utils import _check_shape, _check_shape_compile, _check_dtype, _check_is_int, \ from .utils import _check_shape, _check_shape_compile, _check_dtype, _check_is_int, \
_check_axes_range, _check_start_normalize, _check_shape_contain_zero, _check_is_tensor, \ _check_axes_range, _check_start_normalize, _check_shape_contain_zero, _check_is_tensor, \
_check_input_for_asarray
_check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, _check_is_list, \
_covert_list_tensor_to_tuple_tensor


DEFAULT_FLOAT_DTYPE = mstype.float32 DEFAULT_FLOAT_DTYPE = mstype.float32
DEFAULT_INT_DTYPE = mstype.int32 DEFAULT_INT_DTYPE = mstype.int32

# According to official numpy reference, the dimension of a numpy array must be less
# than 32
MAX_NUMPY_DIMS = 32


def array(obj, dtype=None, copy=True, ndmin=0): def array(obj, dtype=None, copy=True, ndmin=0):
""" """
@@ -115,6 +118,10 @@ def asarray(a, dtype=None):
dtype = mstype.bool_ dtype = mstype.bool_


if isinstance(a, (list, tuple)): if isinstance(a, (list, tuple)):
# Convert all tuple/nested tuples to lists
a = _deep_list(a)
# Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a)
a = onp.asarray(a) a = onp.asarray(a)
# If dtype is not specified, we keep consistent with numpy decision # If dtype is not specified, we keep consistent with numpy decision
# only exceptions are: we use int/float32 # only exceptions are: we use int/float32
@@ -175,6 +182,10 @@ def asfarray(a, dtype=DEFAULT_FLOAT_DTYPE):
dtype = DEFAULT_FLOAT_DTYPE dtype = DEFAULT_FLOAT_DTYPE


if isinstance(a, (list, tuple)): if isinstance(a, (list, tuple)):
# Convert all tuple/nested tuples to lists
a = _deep_list(a)
# Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a)
a = onp.asarray(a) a = onp.asarray(a)


if isinstance(a, onp.ndarray): if isinstance(a, onp.ndarray):
@@ -317,8 +328,10 @@ def arange(*args, **kwargs):
implementation. implementation.


Args: Args:
start(Union[int, float], optional): Start of interval. The interval includes
this value. Default is 0.
start(Union[int, float]): Start of interval. The interval includes this value.
When stop is provided as a position argument, start must be given, when stop
is a normal argument, start can be optional, and default is 0.
Please see additional examples below.
stop(Union[int, float], optional): End of interval. The interval does not stop(Union[int, float], optional): End of interval. The interval does not
include this value, except in some cases where step is not an integer include this value, except in some cases where step is not an integer
and floating point round-off affects the length of out. and floating point round-off affects the length of out.
@@ -340,6 +353,13 @@ def arange(*args, **kwargs):
>>> import mindspore.numpy as np >>> import mindspore.numpy as np
>>> print(np.arange(0, 5, 1)) >>> print(np.arange(0, 5, 1))
[0 1 2 3 4] [0 1 2 3 4]
>>> print(np.arange(3))
[0 1 2]
>>> print(np.arange(start=0, stop=3))
[0 1 2]
>>> print(np.arange(0, stop=3, step=0.5))
[0. 0.5 1. 1.5 2. 2.5]
>>> print(np.arange(stop=3)) # This will lead to TypeError
""" """
# infer the dtype, if either of start, end, step is float, default dtype is # infer the dtype, if either of start, end, step is float, default dtype is
# float32, else int32. # float32, else int32.
@@ -419,6 +439,9 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
if isinstance(stop, Tensor): if isinstance(stop, Tensor):
stop = stop.asnumpy() stop = stop.asnumpy()


if not isinstance(num, int):
raise TypeError(f"num should be an integer, but got {type(num)}")

final_dtype = None final_dtype = None
if dtype is not None: if dtype is not None:
final_dtype = _check_dtype(dtype) final_dtype = _check_dtype(dtype)
@@ -990,7 +1013,7 @@ def concatenate(arrays, axis=0):
# if only one tensor is provided, it is treated as a tuple along the # if only one tensor is provided, it is treated as a tuple along the
# first dimension. For example, a tensor of shape (3,4,5) will be treated # first dimension. For example, a tensor of shape (3,4,5) will be treated
# as: tuple(tensor_1(4,5), tensor_2(4,5), tensor_3(4,5)) # as: tuple(tensor_1(4,5), tensor_2(4,5), tensor_3(4,5))
if axis is None:
if axis is None or axis >= MAX_NUMPY_DIMS:
return ravel(arrays) return ravel(arrays)
arr_shape = F.shape(arrays) arr_shape = F.shape(arrays)
_check_axes_range((axis,), len(arr_shape)) _check_axes_range((axis,), len(arr_shape))
@@ -1002,11 +1025,16 @@ def concatenate(arrays, axis=0):
return arrays return arrays


flattened_arrays = () flattened_arrays = ()
if axis is None:
if axis is None or axis >= MAX_NUMPY_DIMS:
for arr in arrays: for arr in arrays:
flattened_arrays += (ravel(arr),) flattened_arrays += (ravel(arr),)
axis = -1 axis = -1
return P.Concat(axis)(flattened_arrays) return P.Concat(axis)(flattened_arrays)

# convert a list of tensor to a tuple of tensor
if _check_is_list(array_type):
arrays = _covert_list_tensor_to_tuple_tensor(arrays)

arr_shape = F.shape(arrays[0]) arr_shape = F.shape(arrays[0])
_check_axes_range((axis,), len(arr_shape)) _check_axes_range((axis,), len(arr_shape))




+ 15
- 1
mindspore/numpy/math_ops.py View File

@@ -15,7 +15,8 @@
"""math operations, the function docs are adapted from Numpy API.""" """math operations, the function docs are adapted from Numpy API."""
from ..ops import operations as P from ..ops import operations as P
from ..ops import functional as F from ..ops import functional as F
from .array_ops import squeeze
from ..ops.primitive import constexpr
from .array_ops import squeeze, asarray
from .utils import _infer_out_shape, _is_scalar, _check_axis_valid, _get_device_compile, \ from .utils import _infer_out_shape, _is_scalar, _check_axis_valid, _get_device_compile, \
_check_shape_aligned _check_shape_aligned


@@ -63,6 +64,8 @@ def mean(a, axis=None, keepdims=False):
2.5 2.5
""" """
axis = _check_axis_valid(axis, P.Rank()(a)) axis = _check_axis_valid(axis, P.Rank()(a))
if _is_empty(F.shape(a)):
return _nan()
if _is_scalar(a.shape): if _is_scalar(a.shape):
if keepdims: if keepdims:
return a return a
@@ -140,6 +143,17 @@ def inner(a, b):
return res return res




@constexpr
def _nan():
"""Returns a Tensor with nan value"""
return asarray(float('nan'))


def _is_empty(shape):
"""Checks if the shape is empty"""
return F.shape_mul(shape) == 0


def _expand(x, ndim): def _expand(x, ndim):
"""Expand x to ndim""" """Expand x to ndim"""
while P.Rank()(x) < ndim: while P.Rank()(x) < ndim:


+ 39
- 0
mindspore/numpy/utils.py View File

@@ -135,6 +135,37 @@ def _check_dtype(dtype):
return dtype return dtype




def _deep_list(array_like):
"""convert nested tuple/list mixtures to pure nested list"""
if isinstance(array_like, (list, tuple)):
return list(map(_deep_list, array_like))
return array_like


def _deep_tensor_to_nparray(array_like):
"""
convert a nested list of tensor to nested list of np_array.

Args:
array_like(list(tensor)): In any format of nested lists that may contain
tensors.

Returns:
array_like(list(np_array)): Formatted array that can be directly processed
by numpy.array(), with all tensor elements converted to numpy_array.
"""
# Recursively check whether each element is a tensor or not, if is tensor,
# convert it to a numpy array in place
if isinstance(array_like, Tensor):
return array_like.asnumpy()

if isinstance(array_like, list):
for idx, value in enumerate(array_like):
array_like[idx] = _deep_tensor_to_nparray(value)

return array_like


def _check_input_for_asarray(array_like): def _check_input_for_asarray(array_like):
"""check whether array_like argument is a valid type for np.asarray conversion""" """check whether array_like argument is a valid type for np.asarray conversion"""
if isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)): if isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)):
@@ -166,6 +197,14 @@ def _get_device():
return context.get_context('device_target') return context.get_context('device_target')




def _covert_list_tensor_to_tuple_tensor(list_of_tensor):
"""Convert a list of tensor to a tuple of tensor"""
tuple_of_tensor = ()
for tensor in list_of_tensor:
tuple_of_tensor += (tensor,)
return tuple_of_tensor


def _get_mode(): def _get_mode():
"""Get the current mode (0 is Graph mode, 1 is PyNative mode)""" """Get the current mode (0 is Graph mode, 1 is PyNative mode)"""
return context.get_context('mode') return context.get_context('mode')


+ 30
- 16
tests/ut/python/numpy_native/test_array_ops.py View File

@@ -85,6 +85,14 @@ def test_asarray():
expected = mnp.asarray(array, test_case.mnp_dtypes[i]).asnumpy() expected = mnp.asarray(array, test_case.mnp_dtypes[i]).asnumpy()
match_array(actual, expected, error=7) match_array(actual, expected, error=7)


# Additional tests for nested tensor/numpy_array mixture
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]

actual = onp.asarray(onp_input)
expected = mnp.asarray(mnp_input).asnumpy()
match_array(actual, expected, error=7)



def test_array(): def test_array():
# array's function is very similar to asarray, so we mainly test the # array's function is very similar to asarray, so we mainly test the
@@ -100,6 +108,14 @@ def test_array():
assert arr1 is not arr3 assert arr1 is not arr3
assert arr4 is arr5 assert arr4 is arr5


# Additional tests for nested tensor/numpy_array mixture
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]

actual = onp.asarray(onp_input)
expected = mnp.asarray(mnp_input).asnumpy()
match_array(actual, expected, error=7)



def test_asfarray(): def test_asfarray():
test_case = Cases() test_case = Cases()
@@ -120,6 +136,14 @@ def test_asfarray():
expected = mnp.asfarray(array, test_case.mnp_dtypes[i]).asnumpy() expected = mnp.asfarray(array, test_case.mnp_dtypes[i]).asnumpy()
match_array(actual, expected, error=7) match_array(actual, expected, error=7)


# Additional tests for nested tensor/numpy_array mixture
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]

actual = onp.asarray(onp_input)
expected = mnp.asarray(mnp_input).asnumpy()
match_array(actual, expected, error=7)



def test_zeros(): def test_zeros():
test_case = Cases() test_case = Cases()
@@ -547,20 +571,9 @@ def test_exec():
return test_exec_case return test_exec_case




raise_set = [
('Expand_dims_Error', {
'block': (lambda x: mnp.expand_dims, {'exception': ValueError}),
'desc_inputs': [mnp.ones((2, 3, 4)), 0]}),
]


def expand_dims_exception(input_tensor):
return mnp.expand_dims(input_tensor, 1.2)


def test_expand_dims_exception(): def test_expand_dims_exception():
with pytest.raises(TypeError): with pytest.raises(TypeError):
expand_dims_exception(mnp.ones((3, 3)))
mnp.expand_dims(mnp.ones((3, 3)), 1.2)




def test_asarray_exception(): def test_asarray_exception():
@@ -568,10 +581,11 @@ def test_asarray_exception():
mnp.asarray({1, 2, 3}) mnp.asarray({1, 2, 3})




def swapaxes_exception(input_tensor):
return mnp.swapaxes(input_tensor, 1, 10)
def test_swapaxes_exception():
with pytest.raises(TypeError):
mnp.swapaxes(mnp.ones((3, 3)), 1, 10)




def test_swapaxes_exception():
def test_linspace_exception():
with pytest.raises(TypeError): with pytest.raises(TypeError):
swapaxes_exception(mnp.ones((3, 3)))
mnp.linspace(0, 1, num=2.5)

Loading…
Cancel
Save