From 59e5afd2bb210a3fa1c170a16dd18ca7407b3b15 Mon Sep 17 00:00:00 2001 From: yanglf1121 Date: Mon, 23 Nov 2020 11:20:14 +0800 Subject: [PATCH] Add mindnumpy to mindspore --- mindspore/_extends/parse/standard_method.py | 27 +- mindspore/numpy/__init__.py | 44 + mindspore/numpy/array_ops.py | 1022 +++++++++++++++++ mindspore/numpy/dtypes.py | 96 ++ mindspore/numpy/math_ops.py | 160 +++ mindspore/numpy/utils.py | 316 +++++ tests/ut/python/numpy_native/__init__.py | 21 + .../ut/python/numpy_native/test_array_ops.py | 577 ++++++++++ tests/ut/python/numpy_native/test_math_ops.py | 88 ++ 9 files changed, 2346 insertions(+), 5 deletions(-) create mode 100644 mindspore/numpy/__init__.py create mode 100644 mindspore/numpy/array_ops.py create mode 100644 mindspore/numpy/dtypes.py create mode 100644 mindspore/numpy/math_ops.py create mode 100644 mindspore/numpy/utils.py create mode 100644 tests/ut/python/numpy_native/__init__.py create mode 100644 tests/ut/python/numpy_native/test_array_ops.py create mode 100644 tests/ut/python/numpy_native/test_math_ops.py diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index 7f80f65305..3ca2de17c0 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -68,13 +68,22 @@ def any_(x, axis=(), keep_dims=False): return reduce_any(x, axis) -def transpose(x): +def transpose(x, *axis): """Implementation of `transpose`.""" + new_order = None shape = F.shape(x) length = F.tuple_len(shape) - perm = F.make_range(0, length) - revert_perm = F.tuple_reversed(perm) - out = trans(x, revert_perm) + if not axis: + perm = F.make_range(0, length) + new_order = F.tuple_reversed(perm) + + elif len(axis) == 1: + new_order = convert_list_to_tuple(axis[0]) + + elif len(axis) == length: + new_order = axis + + out = trans(x, new_order) return out @@ -194,7 +203,7 @@ def check_type_same(x_type, base_type): @constexpr def check_is_tensor(x): - """check whether x 
is list or tuple.""" + """check whether x is tensor.""" if isinstance(x, mstype.tensor_type): return True return False @@ -250,6 +259,14 @@ def check_view_shape(x): x = x[0] return x +@constexpr +def convert_list_to_tuple(shp): + """Check the type of the shape, if is list, convert to tuple""" + if not isinstance(shp, (list, tuple)): + raise ValueError(f"The shape variable should be a list or tuple, but got {type(shp)}") + if isinstance(shp, list): + shp = tuple(shp) + return shp def tensor_bool(x): """tensor as conditon, if is constant, return immediate bool value""" diff --git a/mindspore/numpy/__init__.py b/mindspore/numpy/__init__.py new file mode 100644 index 0000000000..04a2414199 --- /dev/null +++ b/mindspore/numpy/__init__.py @@ -0,0 +1,44 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Numpy-like interfaces in mindspore. + +Examples: + >>> import mindspore.numpy as np + +Note: + - array_ops.py define all the array generation and operation interfaces. + - math_ops.py define all the math operations on tensors. + - dtypes.py define all the mindspore.numpy dtypes (mainly redirected from mindspore) + - random/ defines all the random operations. 
+""" + +from .array_ops import (array, asarray, asfarray, ones, zeros, full, arange, + linspace, logspace, eye, identity, transpose, expand_dims, + squeeze, rollaxis, swapaxes, reshape, ravel, concatenate) +from .array_ops import copy_ as copy +from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16, + uint32, uint64, float_, float16, float32, float64, bool_, inf, + numeric_types) +from .math_ops import mean, inner + + +array_ops_module = ['array', 'asarray', 'asfarray', 'copy', 'ones', 'zeros', 'arange', + 'linspace', 'logspace', 'eye', 'identity', 'transpose', 'expand_dims', + 'squeeze', 'rollaxis', 'swapaxes', 'reshape', 'ravel', 'concatenate'] + +math_module = ['mean', 'inner'] + +__all__ = array_ops_module + math_module + numeric_types diff --git a/mindspore/numpy/array_ops.py b/mindspore/numpy/array_ops.py new file mode 100644 index 0000000000..e2640be470 --- /dev/null +++ b/mindspore/numpy/array_ops.py @@ -0,0 +1,1022 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""array operations, the function docs are adapted from Numpy API.""" +from copy import copy as py_copy + +import numpy as onp + +import mindspore +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops.primitive import constexpr + +from .utils import _check_shape, _check_shape_compile, _check_dtype, _check_is_int, \ + _check_axes_range, _check_start_normalize, _check_shape_contain_zero, _check_is_tensor, \ + _check_input_for_asarray + +DEFAULT_FLOAT_DTYPE = mindspore.float32 +DEFAULT_INT_DTYPE = mindspore.int32 + + +def array(obj, dtype=None, copy=True, ndmin=0): + """ + Create a tensor. + + This function creat tensors from an array-like object. + + Args: + obj (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in + any form that can be converted to an array. This includes lists, lists of + tuples, tuples, tuples of tuples, tuples of lists and ndarrays. + dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can + be in format of np.int32, or `int32`. If dtype is None, the data type + of the new tensor will be inferred from obj. Default is None. + copy (bool): If true, then the object is copied. Otherwise, a copy will + only be made if necessary. Default: True. + ndmin (int): Specifies the minimum number of dimensions that the resulting + array should have. Ones will be pre-pended to the shape as needed to + meet this requirement. Default: 0 + + Returns: + Tensor, generated tensor with the specified dtype. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.array([1,2,3])) + [1 2 3] + """ + if ndmin > 0: + # Fall back to original numpy creation. 
+ if isinstance(obj, Tensor): + obj = obj.asnumpy() + return asarray(onp.array(obj, dtype, copy=copy, ndmin=ndmin)) + + if not copy: + return asarray(obj, dtype=dtype) + + obj = py_copy(obj) + return asarray(obj, dtype=dtype) + + +def asarray(a, dtype=None): + """ + Convert the input to tensor. + + This function convert tensors from an array-like object. + + Args: + a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in + any form that can be converted to an array. This includes lists, lists of + tuples, tuples, tuples of tuples, tuples of lists and ndarrays. + dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can + be in format of np.int32, or `int32`. If dtype is None, the data type + of the new tensor will be inferred from a. Default is None. + + Returns: + Tensor, generated tensor with the specified dtype. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.asarray([1,2,3])) + [1 2 3] + """ + + if dtype is not None: + dtype = _check_dtype(dtype) + + _ = _check_input_for_asarray(a) + + if isinstance(a, float) and (dtype is None): + dtype = DEFAULT_FLOAT_DTYPE + + if isinstance(a, int) and not isinstance(a, bool) and (dtype is None): + dtype = DEFAULT_INT_DTYPE + + if isinstance(a, bool) and (dtype is None): + dtype = mindspore.bool_ + + if isinstance(a, (list, tuple)): + a = onp.asarray(a) + # If dtype is not specified, we keep consistent with numpy decision + # only exceptions are: we use int/float32 + if dtype is None: + if a.dtype is onp.dtype('int64'): + dtype = DEFAULT_INT_DTYPE + elif a.dtype is onp.dtype('float64'): + dtype = DEFAULT_FLOAT_DTYPE + + if isinstance(a, onp.ndarray) and dtype is None: + if a.dtype is onp.dtype('bool'): + dtype = mindspore.bool_ + elif a.dtype is onp.dtype('int'): + dtype = DEFAULT_INT_DTYPE + elif a.dtype is onp.dtype('float'): + dtype = DEFAULT_FLOAT_DTYPE + a = Tensor.from_numpy(a) + + # If a is already an tensor and we 
don't need to cast dtype, return a + if isinstance(a, Tensor): + if dtype is None: + return a + dtype = _check_dtype(dtype) + if dtype == a.dtype: + return a + + return Tensor(a, dtype=dtype) + + +def asfarray(a, dtype=DEFAULT_FLOAT_DTYPE): + """ + Similar to asarray, convert the input to an float array. + + If non-float dtype is defined, this function will return a float32 Tensor instead. + + Args: + a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in + any form that can be converted to an array. This includes lists, lists of + tuples, tuples, tuples of tuples, tuples of lists and ndarrays. + dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can + be in format of np.float32, or `float32`. Default is mindspore.float32. + + Returns: + Tensor, generated tensor with the specified float dtype. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.asfarray([1,2,3])) + [1. 2. 3.] + """ + dtype = _check_dtype(dtype) + _ = _check_input_for_asarray(a) + + if dtype not in (mindspore.float16, mindspore.float32, mindspore.float64): + dtype = DEFAULT_FLOAT_DTYPE + + if isinstance(a, (list, tuple)): + a = onp.asarray(a) + + if isinstance(a, onp.ndarray): + a = Tensor.from_numpy(a) + + return Tensor(a, dtype) + + +def copy_(a): + """ + Return an tensor copy of the given object. + + Args: + a (Tensor): Input tensor. + + Returns: + Tensor, has the same data as `a`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.copy([1,2,3])) + [1. 2. 3.] + """ + return py_copy(a) + + +def ones(shape, dtype=DEFAULT_FLOAT_DTYPE): + """ + Return a new array of given shape and type, filled with ones. + + Args: + shape (Union[int, tuple, list]): the shape of the new array. + dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can + be in format of np.float32, or `float32`. Default is mindspore.float32. 
+ + Returns: + Tensor, with the designated shape and dtype, filled with ones. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.ones((2,2))) + [[1. 1.] + [1. 1.]] + """ + if _check_shape_contain_zero(shape): + return asarray(onp.ones(shape), dtype=dtype) + shape = _check_shape(shape) + dtype = _check_dtype(dtype) + fill = P.Fill() + output = fill(dtype, shape, 1) + return Tensor(output, dtype=dtype) + + +def zeros(shape, dtype=DEFAULT_FLOAT_DTYPE): + """ + Return a new array of given shape and type, filled with zeros. + + Args: + shape (Union[int, tuple, list]): the shape of the new array. + dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can + be in format of np.float32, or `float32`. Default is mindspore.float32. + + Returns: + Tensor, with the designated shape and dtype, filled with zeros. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.zeros((2,2))) + [[0. 0.] + [0. 0.]] + """ + if _check_shape_contain_zero(shape): + return asarray(onp.zeros(shape), dtype=dtype) + shape = _check_shape(shape) + dtype = _check_dtype(dtype) + fill = P.Fill() + output = fill(dtype, shape, 0) + return Tensor(output, dtype=dtype) + + +def full(shape, fill_value, dtype=None): + """ + Return a new array of given shape and type, filled with fill_value. + + Args: + shape (Union[int, tuple(int), list(int)]): Shape of the new array, e.g., + (2, 3) or 2. + fill_value (Union[int, float, bool, list, tuple]): scalar or array_like + fill value. + dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can + be in format of np.float32, or `float32`, if dtype is None, the data type + of the new tensor will be inferred from fill_value. Default is None. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Returns: + Tensor, with the designated shape and dtype, filled with `fill_value`. 
+ + Examples: + >>> import mindspore.numpy as np + >>> print(np.full((2,2), True)) + [[True True] + [True True]] + """ + if dtype is None: + dtype = array(fill_value).dtype + + shape = _check_shape(shape) + _ = _check_input_for_asarray(fill_value) + dtype = _check_dtype(dtype) + + if isinstance(fill_value, (int, float, bool)) and not _check_shape_contain_zero(shape): + return P.Fill()(dtype, shape, fill_value) + + # if fill_value is array_like or shape contains zero. fall back to original + # numpy creation + return Tensor(onp.full(shape, fill_value, mindspore.dtype_to_nptype(dtype))) + + +def arange(*args, **kwargs): + """ + Return evenly spaced values within a given interval. + + Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`]. + The endpoint of the interval can optionally be excluded. + The current implementation is a direct wrapper on top of numpy.arange, except + the default dtype is float32 and int32, compare to float64 and int64 for numpy + implementation. + + Args: + start(Union[int, float], optional): Start of interval. The interval includes + this value. Default is 0. + stop(Union[int, float], optional): End of interval. The interval does not + include this value, except in some cases where step is not an integer + and floating point round-off affects the length of out. + step(Union[int, float], optional): Spacing between values. For any output + out, this is the distance between two adjacent values, out[i+1] - out[i]. + The default step size is 1. If step is specified as a position argument, + start must also be given. + dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can + be in format of np.float32, or `float32`. If dtype is None, the data type + of the new tensor will be inferred from start, stop and step. Default is None. + + Returns: + arangend Tensor, array of evenly spaced values. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.arange(0, 5, 1)) + [0 1 2 3 4] + """ + # infer the dtype, if either of start, end, step is float, default dtype is + # float32, else int32. + int_flag = True + final_dtype = None + + if args: + for item in args: + if isinstance(item, float): + int_flag = False + if kwargs: + if ('start' in kwargs and isinstance(kwargs['start'], float)) or \ + ('stop' in kwargs and isinstance(kwargs['stop'], float)) or \ + ('step' in kwargs and isinstance(kwargs['step'], float)): + int_flag = False + + if int_flag: + final_dtype = onp.int32 + else: + final_dtype = onp.float32 + + if 'dtype' in kwargs and kwargs['dtype'] is not None: + final_dtype = _check_dtype(kwargs['dtype']) + final_dtype = mindspore.dtype_to_nptype(final_dtype) + kwargs['dtype'] = final_dtype + out = onp.arange(*args, **kwargs) + out = Tensor.from_numpy(out) + return Tensor(out) + + +def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0): + """ + Return evenly spaced values within a given interval. + + The current implementation is a direct wrapper on top of numpy.linspace, except + the default dtype is float32, compare to float64 for numpy, + + Args: + start (Union[int, list(int), tuple(int),tensor]):The starting value of the sequence. + stop (Union[int, list(int), tuple(int),tensor]):The end value of the sequence, + unless `endpoint` is set to False. In that case, the sequence consists + of all but the last of ``num + 1` evenly spaced samples, so that `stop` + is excluded. Note that the step size changes when `endpoint` is False. + num (int, optional): Number of samples to generate. Default is 50. + endpoint (bool, optional): If True, `stop` is the last sample. Otherwise, it is + not included. Default is True. + retstep (bool, optional): If True, return (`samples`, `step`), where `step` is + the spacing between samples. 
+ dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can + be in format of np.float32, or `float32`.If `dtype` is None, infer the data + type from other input arguments. Default is None. + axis (int, optional): The axis in the result to store the samples. Relevant + only if start or stop are array-like. By default (0), the samples will + be along a new axis inserted at the beginning. Use -1 to get an axis at the end. + Default is 0. + + Returns: + samples (Tensor): There are `num` equally spaced samples in the closed interval + ``[start, stop]`` or the half-open interval ``[start, stop)`` + (depending on whether `endpoint` is True or False). + + step (float, optional): Only returned if `retstep` is True. + Size of spacing between samples. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.linspace(0, 5, 6)) + [0. 1. 2. 3. 4. 5.] + """ + + if isinstance(start, Tensor): + start = start.asnumpy() + + if isinstance(stop, Tensor): + stop = stop.asnumpy() + + final_dtype = None + if dtype is not None: + final_dtype = _check_dtype(dtype) + final_dtype = mindspore.dtype_to_nptype(final_dtype) + else: + final_dtype = onp.float32 + + dtype = final_dtype + out = onp.linspace(start, stop, num, endpoint, retstep, dtype, axis) + + if retstep: + array_out, step_out = out[0], out[1] + tensor_out = Tensor.from_numpy(array_out) + return tensor_out, step_out + + tensor_out = Tensor.from_numpy(out) + return tensor_out + + +def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): + """ + Return numbers spaced evenly on a log scale. + + In linear space, the sequence starts at base ** start (base to the power of + start) and ends with base ** stop (see endpoint below). 
+ The current implementation is a direct wrapper on top of numpy.logspace, except + the default dtype is float32, compare to float64 for numpy, + + Args: + start (Union[int, list(int), tuple(int), tensor]):The starting value of the sequence. + stop (Union[int, list(int), tuple(int), tensor]):The end value of the sequence, + unless `endpoint` is set to False. In that case, the sequence consists + of all but the last of ``num + 1` evenly spaced samples, so that `stop` + is excluded. Note that the step size changes when `endpoint` is False. + num (int, optional): Number of samples to generate. Default is 50. + endpoint (bool, optional): If True, `stop` is the last sample. Otherwise, it is + not included. Default is True. + base (Union[int, float], optional): The base of the log space. The step size + between the elements in ln(samples) / ln(base) (or log_base(samples)) + is uniform. Default is 10.0. + dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can + be in format of np.float32, or `float32`.If `dtype` is None, infer the data + type from other input arguments. Default is None. + axis (int, optional): The axis in the result to store the samples. Relevant + only if start or stop are array-like. By default (0), the samples will + be along a new axis inserted at the beginning. Use -1 to get an axis at the end. + Default is 0. + + Returns: + samples (Tensor): num samples, equally spaced on a log scale. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.logspace(0, 5, 6, base=2.0)) + [ 1. 2. 4. 8. 16. 32.] 
+ """ + + if isinstance(start, Tensor): + start = start.asnumpy() + + if isinstance(stop, Tensor): + stop = stop.asnumpy() + + final_dtype = None + if dtype is not None: + final_dtype = _check_dtype(dtype) + final_dtype = mindspore.dtype_to_nptype(final_dtype) + else: + final_dtype = onp.float32 + + dtype = final_dtype + out = onp.logspace(start, stop, num, endpoint, base, dtype, axis) + + tensor_out = Tensor.from_numpy(out) + return tensor_out + + +def eye(N, M=None, k=0, dtype=DEFAULT_FLOAT_DTYPE): + """ + Return a 2-D array with ones on the diagnoal and zeros elsewhere. + + Args: + N (int): Number of rows in the output, must be larger than 0. + M (int, optional): Number of columns in the output. If None, defaults to N, + if defined, must be larger than 0. Deault is None. + k (int, optional): Index of the diagonal: 0 (the default) refers to the main + diagonal, a positive value refers to an upper diagonal, and a negative value + to a lower diagonal. Default is 0. + dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can + be in format of np.float32, or `float32`. Default is mindspore.float32. + + Returns: + result (Tensor): A tensor array of shape (N,M). An array where all elements + are equal to zero, except for the k-th diagonal, whose values are equal to one. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.eye(2, 2)) + [[1. 0.] + [0. 1.]] + """ + dtype = _check_dtype(dtype) + make_eye = P.Eye() + if M is None: + M = N + M = int(M) + N = int(N) + k = int(k) + out = None + if k != 0 or N == 0 or M == 0: + # Fall back to original numpy creation method + out = onp.eye(N, M, k) + else: + out = make_eye(N, M, dtype) + return asarray(out, dtype=dtype) + + +def identity(n, dtype=DEFAULT_FLOAT_DTYPE): + """ + Return the identity array. + + Args: + n (int): Number of rows and columns in the output, must be larger than 0. 
+ dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can + be in format of np.float32, or `float32`. Default is mindspore.float32. + + Returns: + result (Tensor): A tensor array of shape (n,n). An array where all elements + are equal to zero, except for the diagonal, whose values are equal to one. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.identity(2)) + [[1. 0.] + [0. 1.]] + """ + dtype = _check_dtype(dtype) + return eye(n, dtype=dtype) + + +@constexpr +def _prepare_shape_for_expand_dims(shape, axes): + """ + Creat the expanded new shape based on the shape and given axes + + Args: + shape (tuple): the shape of the tensor + axes Union(int, tuple(int), list(int)): the axes with dimensions expanded. + + Returns: + new_shape(tuple): the shape with dimensions expanded. + """ + + new_shape = [] + shape_idx = 0 + new_shape_length = len(shape) + + # Convert to set + if isinstance(axes, int): + new_shape_length += 1 + if axes >= new_shape_length or axes < -new_shape_length: + raise ValueError( + f"axis {axes} is out of bounds for array of dimension {new_shape_length}") + axes = {axes} + + elif isinstance(axes, (list, tuple)): + new_shape_length += len(axes) + for axis in axes: + if axis >= new_shape_length or axis < -new_shape_length: + raise ValueError( + f"axis {axis} is out of bounds for array of dimension {new_shape_length}") + axes = set(axes) + + else: + raise TypeError( + f"only int, tuple and list are allowed for axes, but got {type(axes)}") + + for new_shape_idx in range(new_shape_length): + if new_shape_idx in axes or new_shape_idx - new_shape_length in axes: + new_shape.append(1) + else: + new_shape.append(shape[shape_idx]) + shape_idx += 1 + return tuple(new_shape) + + +def expand_dims(a, axis): + """ + Expand the shape of an array. + + Insert a new axis that will appear at the axis position in the expanded array shape. + + Args: + a (Tensor): Input tensor array. 
+ axis Union[int, list(int), tuple(int)]: Position in the expanded axes where + the new axis is placed, + + Returns: + Tensor, view of a tensor with the number of dimensions increased. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.ones((2,2)) + >>> x = np.expand_dims(x,0) + >>> print(x,shape) + (2,2,1) + """ + shape = F.shape(a) + # yield expanded shape based on the axes + new_shape = _prepare_shape_for_expand_dims(shape, axis) + return P.Reshape()(a, new_shape) + + +@constexpr +def _prepare_shape_for_squeeze(shape, axes): + """ + Creat the squeezed new shape based on the tensor and given axes. + + Args: + shape (tuple): the shape of the tensor + axes Union(None, int, tuple(int), list(int)): the axes with dimensions squeezed. + + Returns: + new_shape(tuple): the shape with dimensions squeezed. + """ + new_shape = [] + ndim = len(shape) + + # Convert to set + if isinstance(axes, int): + if axes >= ndim or axes < -ndim: + raise ValueError( + f"axis {axes} is out of bounds for array of dimension {ndim}") + axes = {axes} + + elif isinstance(axes, (list, tuple)): + for axis in axes: + if axis >= ndim or axis < -ndim: + raise ValueError( + f"axis {axis} is out of bounds for array of dimension {ndim}") + axes = set(axes) + + elif axes is not None: + raise TypeError( + f"only int, tuple and list are allowed for axes, but got {type(axes)}") + + if axes is None: + new_shape = [s for s in shape if s != 1] + else: + for idx, s in enumerate(shape): + if s != 1 or (idx not in axes) and (idx - ndim not in axes): + new_shape.append(s) + # if an axis is selected with shape entry greater than one, an error is raised. + if s != 1 and ((idx in axes) or (idx - ndim in axes)): + raise ValueError( + f"axis {axes} has shape entry {s} > 1, cannot be squeezed.") + return tuple(new_shape) + + +def squeeze(a, axis=None): + """ + Remove single-dimensional entries from the shape of an array. 
+ + This is a temporary solution to support CPU backend. Will be changed + once CPU backend supports P.Squeeze(). + + Args: + a (Tensor): Input tensor array. + axis: Union[None, int, list(int), tuple(list)]. Default is None. + + Returns: + Tensor, with all or a subset of the dimensions of length 1 removed. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.ones((1,2,2,1)) + >>> x = np.squeeze(x) + >>> print(x,shape) + (2,2) + """ + shape = F.shape(a) + # yield squeezed shape based on the axes + new_shape = _prepare_shape_for_squeeze(shape, axis) + return P.Reshape()(a, new_shape) + + +def transpose(a, axes=None): + """ + Reverse or permute the axes of an array; returns the modified array. + + Args: + a (Tensor): a tensor to be transposed + axes Union[None, tuple, list]: the axes order, if axes is None, transpose + the entire tensor. Default is None. + + Returns: + Tensor, the transposed tensor array. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.ones((1,2,3)) + >>> x = np.transpose(x) + >>> print(x,shape) + (3,2,1) + """ + if axes is None: + shape = F.shape(a) + length = F.tuple_len(shape) + perm = F.make_range(0, length) + new_order = F.tuple_reversed(perm) + return P.Transpose()(a, new_order) + + axes = _check_shape_compile(axes) + return P.Transpose()(a, axes) + + +def rollaxis(x, axis, start=0): + """ + Roll the specified axis backwards, until it lies in the given position. + The positions of the other axes do not change relative to one another. + + Args: + x (Tensor): A Tensor to be transposed. + axis (int): The axis to be rolled. + start (int): + - When start >= 0: + - When start <= axis: the axis is rolled back until it lies in this position (start). + - When start > axis: the axis is rolled until it lies before this position (start). + - When start < 0: the start will be normalized as follows: + start ........... 
Normalized start + -(x.ndim+1) raise ValueError + -x.ndim 0 + ... ... + -1 x.ndim-1 + 0 0 + ... ... + x.ndim x.ndim + x.ndim+1 raise ValueError + + Returns: + Transposed Tensor. Has the same data type as the original tensor x. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Raises: + TypeError: If axis or start is not integer. + ValueError: If axis is not in the range from -ndim to ndim-1 or + start is not in the range from -ndim to ndim. + + Examples: + >>> import mindspore + >>> import mindspore.numpy as mnp + >>> from mindspore import Tensor + >>> import numpy as onp + >>> input_x = Tensor(onp.ones((2,3,4)), mindspore.float32) + >>> output = mnp.rollaxis(x, 0, 2) + >>> print(output.shape) + (3,2,4) + """ + _check_is_int(axis) + _check_is_int(start) + + shape = F.shape(x) + ndim = F.tuple_len(shape) + + axis = _check_axes_range(axis, ndim) + start = _check_start_normalize(start, ndim) + if start - axis >= 0 and start - axis <= 1: + return x + perm = F.make_range(0, ndim) + new_perm = None + if start < axis: + if axis + 1 < ndim: + new_perm = perm[0:start] + perm[axis:axis+1] + \ + perm[start:axis] + perm[axis+1:] + else: + new_perm = perm[0:start] + perm[axis:axis+1] + perm[start:axis] + if start > axis: + if start < ndim: + new_perm = perm[0:axis] + perm[axis+1:start] + \ + perm[axis:axis+1] + perm[start:] + else: + new_perm = perm[0:axis] + perm[axis+1:start] + \ + perm[axis:axis+1] + + return P.Transpose()(x, new_perm) + + +def swapaxes(x, axis1, axis2): + """ + Interchange two axes of a tensor. + + Args: + x (Tensor): A Tensor to be transposed. + axis1 (int): First axis. + axis2 (int): Second axis. + + Returns: + Transposed Tensor. Has the same data type as the original tensor x. + + Raises: + TypeError: If axis1 or axis2 is not integer. + ValueError: If axis1 or axis2 is not in the range from -ndim to ndim-1. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore + >>> import mindspore.numpy as mnp + >>> from mindspore import Tensor + >>> import numpy as onp + >>> input_x = Tensor(onp.ones((2,3,4)), mindspore.float32) + >>> output = mnp.swapaxes(x, 0, 2) + >>> print(output.shape) + (4,3,2) + """ + _check_is_int(axis1) + _check_is_int(axis2) + + shape = F.shape(x) + ndim = F.tuple_len(shape) + + axes = _check_axes_range((axis1, axis2), ndim) + axis1, axis2 = axes[0], axes[1] + + if axis1 == axis2: + return x + if axis1 > axis2: + axis1, axis2 = axis2, axis1 + + perm = F.make_range(0, ndim) + new_perm = None + if axis2 + 1 < ndim: + new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \ + perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:] + else: + new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \ + perm[axis1+1:axis2] + perm[axis1:axis1+1] + + return P.Transpose()(x, new_perm) + + +def reshape(x, new_shape): + """ + Reshape a tensor without changing its data. + + Args: + x (Tensor): A Tensor to be reshaped. + new_shape (Union[int, list(int), tuple(int)]): The new shape should be + compatible with the original shape. If the tuple has only one element, + the result will be a 1-D tensor of that length. One shape dimension + can be -1. In this case, the value is inferred from the length of + the tensor and remaining dimensions. + + Returns: + Reshaped Tensor. Has the same data type as the original tensor x. + + Raises: + TypeError: If new_shape is not integer, list or tuple. + ValueError: If new_shape does not compatible with the original shape. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32) + >>> reshape = mindspore.numpy.reshape() + >>> output = reshape(x, (3, 2)) + >>> print(output) + [[-0.1 0.3] + [ 3.6 0.4] + [ 0.5 -3.2]] + >>> output = reshape(x, (3, -1)) + >>> print(output) + [[-0.1 0.3] + [ 3.6 0.4] + [ 0.5 -3.2]] + >>> output = reshape(x, (6, )) + >>> print(output) + [-0.1 0.3 3.6 0.4 0.5 -3.2] + """ + new_shape = _check_shape_compile(new_shape) + return P.Reshape()(x, new_shape) + + +def ravel(x): + """ + Return a contiguous flattened tensor. + + A 1-D tensor, containing the elements of the input, is returned. + + Args: + x (Tensor): A tensor to be flattened. + + Returns: + Flattened Tensor. Has the same data type as the original tensor x. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore + >>> import mindspore.numpy as mnp + >>> from mindspore import Tensor + >>> import numpy as onp + >>> input_x = Tensor(onp.ones((2,3,4)), mindspore.float32) + >>> output = mnp.ravel(x) + >>> print(output.shape) + (24,) + """ + return reshape(x, (-1,)) + + +@constexpr +def _move_axes_for_concatenate(arr_shape, axis): + """ + move axis 0 to the disiganated position, while keep other axes' relative + positions unchanged, only used if a single array is concatenated. + """ + + original_axes = tuple(range(len(arr_shape))) + new_axes = original_axes[1:axis+1] + (0,) + original_axes[axis+1:] + new_shape = arr_shape[1:axis+1] + (arr_shape[0] * arr_shape[axis+1],) + \ + arr_shape[axis+2:] + return new_axes, new_shape + + +def concatenate(arrays, axis=0): + """ + Join a sequence of arrays along an existing axis. + + Args: + arrays: Union[Tensor, tuple(Tensor), list(Tensor)], a Tensor or a list + of Tensor to be concatenated. + + axis (int, optional): The axis along which the arrays will be joined, + if axis is None, arrays are flattened before use. Default is 0. 
+
+    Returns:
+        Tensor, a Tensor concatenated from a Tensor or a list of Tensors.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore.numpy as np
+        >>> x1 = np.ones((1,2,3))
+        >>> x2 = np.ones((1,2,1))
+        >>> x = np.concatenate((x1, x2), axis=-1)
+        >>> print(x.shape)
+        (1, 2, 4)
+    """
+    array_type = F.typeof(arrays)
+    if _check_is_tensor(array_type):
+        # if the input is a single tensor
+        # if only one tensor is provided, it is treated as a tuple along the
+        # first dimension. For example, a tensor of shape (3,4,5) will be treated
+        # as: tuple(tensor_1(4,5), tensor_2(4,5), tensor_3(4,5))
+        if axis is None:
+            return ravel(arrays)
+        arr_shape = F.shape(arrays)
+        _check_axes_range((axis,), len(arr_shape))
+        # move axis 0 to the designated position, while keep other axes' relative
+        # positions unchanged
+        new_axes, new_shape = _move_axes_for_concatenate(arr_shape, axis)
+        arrays = transpose(arrays, new_axes)
+        arrays = reshape(arrays, new_shape)
+        return arrays
+
+    flattened_arrays = ()
+    if axis is None:
+        for arr in arrays:
+            flattened_arrays += (ravel(arr),)
+        axis = -1
+        return P.Concat(axis)(flattened_arrays)
+    arr_shape = F.shape(arrays[0])
+    _check_axes_range((axis,), len(arr_shape))
+
+    # if only one tensor in the tuple/list, return the tensor itself
+    if len(arrays) == 1:
+        return arrays[0]
+
+    return P.Concat(axis)(arrays)
diff --git a/mindspore/numpy/dtypes.py b/mindspore/numpy/dtypes.py
new file mode 100644
index 0000000000..ce52cbbea7
--- /dev/null
+++ b/mindspore/numpy/dtypes.py
@@ -0,0 +1,96 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Dtypes and utilities""" + +from mindspore import (int8, int16, int32, int64, uint8, uint16, uint32, uint64, \ + float16, float32, float64, bool_) + +# original numpy has int->int64, float->float64, uint->uint64 mapping. we map +# them to 32 bit, since 64 bit calculation is not supported from mindspore +# backend for now. + +inf = float('inf') + +int_ = int32 +uint = uint32 +float_ = float32 + +numeric_types = [ + 'int_', + 'int8', + 'int16', + 'int32', + 'int64', + 'uint8', + 'uint16', + 'uint32', + 'uint64', + 'float_', + 'float16', + 'float32', + 'float64', + 'bool_'] + +dtype_tuple = ( + int_, + int8, + int16, + int32, + int64, + uint, + uint8, + uint16, + uint32, + uint64, + float_, + float16, + float32, + float64, + bool_) + +dtype_map = { + 'int': int_, + 'int8': int8, + 'int16': int16, + 'int32': int32, + 'int64': int64, + 'uint': uint, + 'uint8': uint8, + 'uint16': uint16, + 'uint32': uint32, + 'uint64': uint64, + 'float': float_, + 'float16': float16, + 'float32': float32, + 'float64': float64, + 'bool': bool_ +} + +all_types = [ + 'np.int', + 'np.int8', + 'np.int16', + 'np.int32', + 'np.int64', + 'np.uint', + 'np.uint8', + 'np.uint16', + 'np.uint32', + 'np.uint64', + 'np.float', + 'np.float16', + 'np.float32', + 'np.float64', + 'np.bool'] diff --git a/mindspore/numpy/math_ops.py b/mindspore/numpy/math_ops.py new file mode 100644 index 0000000000..9e681c4c84 --- /dev/null +++ b/mindspore/numpy/math_ops.py @@ -0,0 +1,160 @@ +# Copyright 2020 Huawei Technologies Co., 
Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""math operations, the function docs are adapted from Numpy API.""" +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from .array_ops import squeeze +from .utils import _infer_out_shape, _is_scalar, _check_axis_valid, _get_device_compile, \ + _check_shape_aligned + + +def mean(a, axis=None, keepdims=False): + """ + Compute the arithmetic mean along the specified axis. + + Returns the average of the array elements. The average is taken + over the flattened array by default, otherwise over the specified + axis. + + Note: + Numpy arguments dtype and out are not supported. + On GPU, the supported dtypes are np.float16, and np.float32. + On CPU, the supported dtypes are np.float16, and np.float32. + + Args: + a (Tensor): input tensor containing numbers whose mean is desired. + If a is not an array, a conversion is attempted. + axis (None or int or tuple of ints): optional. Axis or axes along + which the means are computed. The default is to compute + the mean of the flattened array. If this is a tuple of + ints, a mean is performed over multiple axes. + keepdims(bool): optional. If this is set to True, the axes which + are reduced are left in the result as dimensions with + size one. With this option, the result will broadcast + correctly against the input tensor. 
+ + Returns: + Tensor or scalar, an array containing the mean values. + + Raises: + ValueError: if axes are out of the range of [-a.ndim, a.ndim), or + if the axes contain duplicates. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> a = np.arange(6, dtype='float32') + >>> output = np.mean(a, 0) + >>> print(output) + 2.5 + """ + axis = _check_axis_valid(axis, P.Rank()(a)) + if _is_scalar(a.shape): + if keepdims: + return a + return squeeze(a) + if keepdims: + res = P.ReduceMean(True)(a, axis) + else: + res = P.ReduceMean(False)(a, axis) + return res + + +def inner(a, b): + """ + Inner product of two tensors. + + Ordinary inner product of vectors for 1-D tensors (without complex + conjugation), in higher dimensions a sum product over the last + axes. + + Note: + Numpy argument out is not supported. + On GPU, the supported dtypes are np.float16, and np.float32. + On CPU, the supported dtype is np.float32. + + Args: + a (Tensor): input tensor. If a and b are nonscalar, their last + dimensions must match. + b (Tensor): input tensor. If a and b are nonscalar, their last + dimensions must match. + + Returns: + Tensor or scalar, out.shape = a.shape[:-1] + b.shape[:-1]. + + Raises: + ValueError: if x1.shape[-1] != x2.shape[-1]. + + Supported Platforms: + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> a = np.ones((5, 3)) + >>> b = np.ones((2, 7, 3)) + >>> output = np.inner(a, b) + >>> print(output) + [[[3. 3. 3. 3. 3. 3. 3.] + [3. 3. 3. 3. 3. 3. 3.]] + + [[3. 3. 3. 3. 3. 3. 3.] + [3. 3. 3. 3. 3. 3. 3.]] + + [[3. 3. 3. 3. 3. 3. 3.] + [3. 3. 3. 3. 3. 3. 3.]] + + [[3. 3. 3. 3. 3. 3. 3.] + [3. 3. 3. 3. 3. 3. 3.]] + + [[3. 3. 3. 3. 3. 3. 3.] + [3. 3. 3. 3. 3. 3. 
3.]]] + """ + if P.Rank()(a) == 0 or P.Rank()(b) == 0: + if _is_scalar(a.shape): + a, b = b, a + return _apply_bin_op(P.Mul(), a, b) + + _ = _check_shape_aligned(a.shape, b.shape) + aligned_shape_a = (F.shape_mul(a.shape[:-1]), a.shape[-1]) + aligned_shape_b = (F.shape_mul(b.shape[:-1]), a.shape[-1]) + a_aligned = P.Reshape()(a, aligned_shape_a) + b_aligned = P.Reshape()(b, aligned_shape_b) + + res = P.MatMul(False, True)(a_aligned, b_aligned) + res = P.Reshape()(res, a.shape[:-1] + b.shape[:-1]) + return res + + +def _expand(x, ndim): + """Expand x to ndim""" + while P.Rank()(x) < ndim: + x = P.ExpandDims()(x, 0) + return x + + +def _apply_bin_op(fn, x1, x2): + """apply binary operations based on fn.""" + device = _get_device_compile() + out_shape = _infer_out_shape(device, x1.shape, x2.shape) + if device == 'CPU': + # built-in operations on CPU does not support operands with + # dimensions of size 1 or with shape 0, therefore squeeze + # and scalar promotion is performed + x1, x2 = squeeze(x1), squeeze(x2) + x1, x2 = _expand(x1, 1), _expand(x2, 1) + res = fn(x1, x2) + res = P.Reshape()(res, out_shape) + return res diff --git a/mindspore/numpy/utils.py b/mindspore/numpy/utils.py new file mode 100644 index 0000000000..bafc43f314 --- /dev/null +++ b/mindspore/numpy/utils.py @@ -0,0 +1,316 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""internal utility functions""" +from functools import partial + +import numpy as onp + +import mindspore +import mindspore.context as context +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops.primitive import constexpr +from mindspore.common import dtype as mstype + +from .dtypes import dtype_tuple, all_types, dtype_map + +@constexpr +def _check_shape_compile(shape): + """check the shape param to match the numpy style inside the graph""" + if not isinstance(shape, (int, tuple, list)): + raise TypeError( + f"only int, tuple and list are allowed for shape, but got {type(shape)}") + if isinstance(shape, int): + shape = (shape,) + if isinstance(shape, list): + shape = tuple(shape) + return shape + + +@constexpr +def _check_is_int(x): + """Check the type of x is int.""" + if isinstance(x, int): + return True + raise TypeError(f"integer argument expected, but got {type(x)}.") + + +@constexpr +def _check_start_normalize(start, ndim): + """check and normalize start argument for rollaxis.""" + if start < -ndim or start > ndim: + raise ValueError( + f"For rollaxis, start {start} is out of bounds. Ranging from {-ndim} to {ndim} is allowed.") + if start < 0: + start = start + ndim + return start + + +@constexpr +def _check_axes_range(axes, ndim): + """ + Check axes are within the number of dimensions of tensor x and normalize the negative axes. + Args: + axes (Union[int, tuple(int), list(int)]): Axes of the tensor. + ndim (int): The number of dimensions of the tensor. + Return: + Axes (Union[int, tuple(int)]). If input is integer, return integer, else tuple. 
+ """ + if not isinstance(axes, int) and not isinstance(axes, tuple) and not isinstance(axes, list): + raise TypeError( + f"int, tuple(int) or list(int) expected, but got {type(axes)}.") + low = -ndim + up = ndim - 1 + if low > up: + raise ValueError( + f"Lower bound {low} and upper bound {up} of axes are not allowed.") + if isinstance(axes, int): + if axes < low or axes > up: + raise TypeError( + f"axis {axes} is out of bounds for tensor of dimension {ndim}.") + return axes if axes >= 0 else axes + ndim + new_axes = [] + for item in axes: + if not isinstance(item, int): + raise TypeError( + f"int in tuple or list expected, but got {type(item)}.") + if item < low or item > up: + raise TypeError( + f"axis {item} in {axes} is out of bounds for tensor of dimension {ndim}.") + new_axes.append(item if item >= 0 else item + ndim) + return tuple(new_axes) + + +def _check_shape_contain_zero(shp): + """Check whether shape contains 0""" + if isinstance(shp, int): + return shp == 0 + if isinstance(shp, (list, tuple)): + for s in shp: + if s == 0: + return True + return False + + +def _check_shape(shape): + """check the shape param to match the numpy style outside the graph""" + if not isinstance(shape, (int, tuple, list)): + raise TypeError( + f"only int, tuple and list are allowed for shape, but got {type(shape)}") + if isinstance(shape, int): + shape = (shape,) + if isinstance(shape, list): + shape = tuple(shape) + return shape + + +def _check_dtype(dtype): + """check the input dtype and make conversions""" + # convert the string dtype to mindspore.dtype + if isinstance(dtype, str): + dtype = dtype.lower() + dtype = dtype_map[dtype] + elif isinstance(dtype, type): + if dtype is int: + dtype = mindspore.int32 + if dtype is float: + dtype = mindspore.float32 + if dtype is bool: + dtype = mindspore.bool_ + if dtype not in dtype_tuple: + raise TypeError( + f"only {all_types} are allowed for dtype, but got {type(dtype)}") + return dtype + + +def 
_check_input_for_asarray(array_like): + """check whether array_like argument is a valid type for np.asarray conversion""" + if isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)): + return True + raise TypeError( + "input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \ + f"or numpy.ndarray, but got {type(array_like)}") + + +def _cast_to(array, dtype): + """cast the input to specified dtype""" + cast = P.Cast() + return cast(array, dtype) + + +def _is_scalar(shape): + """check whether input shape is a scalar""" + return F.shape_mul(shape) == 1 + + +@constexpr +def _get_device_compile(): + """Get the current device (`GPU`, `CPU`, `Ascend`)""" + return context.get_context('device_target') + + +def _get_device(): + """Get the current device (`GPU`, `CPU`, `Ascend`)""" + return context.get_context('device_target') + + +def _get_mode(): + """Get the current mode (0 is Graph mode, 1 is PyNative mode)""" + return context.get_context('mode') + + +@constexpr +def _reverse_index(idx, arr): + """ + Returns 1 if shape[idx:] is broadcastable to shape_out[idx:], + 2 situations if the function returns 1: + - 1. Tensor's shape has 1 at the designated dimension. + - 2. Tensor's dimension is less than the designated idx. (The Tensor shape + has been reversed) + For both cases, 2 tensors are broadcastable. 
+ otherwise returns the element at position of shape + """ + if len(arr) <= idx: + return 1 + return arr[-1 - idx] + + +@constexpr +def _infer_out_shape(device, *shapes): + """ + Returns shape of output after broadcasting + Raises ValueError if shape1 and shape2 cannot be broadcast + """ + shapes_unbroadcastable = False + cpu_shapes_different = False + contains_scalar = any(_is_scalar(shape) for shape in shapes) + ndim_max = max(map(len, shapes)) + shape_out = [0]*ndim_max + i = 0 + for i in range(ndim_max): + shape_out[-1 - i] = max(map(partial(_reverse_index, i), shapes)) + for shape in shapes: + if _reverse_index(i, shape) != shape_out[-1 - i]: + if _reverse_index(i, shape) != 1: + shapes_unbroadcastable = True + if device == 'CPU' and not contains_scalar: + cpu_shapes_different = True + if not shapes_unbroadcastable and not cpu_shapes_different: + return tuple(shape_out) + if shapes_unbroadcastable: + raise ValueError( + f'operands could not be broadcast together with shapes {*shapes,}') + raise ValueError('broadcasting is currently not supported on CPU. 
Non-scalar' + \ + f'operands must have the same shape, but got {*shapes,}') + + +@constexpr +def _check_axis_in_range(axis, ndim): + """Checks axes are with the bounds of ndim""" + if -ndim <= axis < ndim: + return True + raise ValueError( + f'axis {axis} is out of bounds for array of dimension {ndim}') + + +@constexpr +def _check_axis_valid(axes, ndim): + """ + Checks axes are valid given ndim, and returns axes that can be passed + to the built-in operator (non-negative, int or tuple) + """ + if isinstance(axes, int): + _ = _check_axis_in_range(axes, ndim) + return (axes % ndim,) + if isinstance(axes, tuple): + for axis in axes: + _ = _check_axis_in_range(axis, ndim) + axes = tuple(map(lambda x: x % ndim, axes)) + if all(axes.count(el) <= 1 for el in axes): + return axes + if axes is None: + axes = F.make_range(ndim) + return axes + raise ValueError('duplicate value in \'axis\'') + + +@constexpr +def _check_shape_aligned(shape1, shape2): + """Checks shape1 and shape2 are valid shapes to perform inner product""" + if shape1[-1] == shape2[-1]: + return True + raise ValueError( + f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)') + + +@constexpr +def _check_dim_cpu(shape, bound): + """Checks input shape is upper-bounded by parameter bound""" + ndim = len(shape) + if _is_scalar(shape): + return True + if ndim <= bound: + return True + raise ValueError( + f'dimension {ndim} larger than {bound} is not supported on CPU') + + +@constexpr +def _tile_size(shape, out_shape, ndim): + """Returns tile_size such that shape*tile_size = out_shape""" + size = [1]*ndim + for idx, (i, j) in enumerate(zip(shape, out_shape)): + if i != j: + size[idx] = j + return tuple(size) + + +@constexpr +def _check_core_match(shape1, shape2): + """Checks shape1 and shape2 are valid shapes to perform matmul""" + ndim1, ndim2 = len(shape1), len(shape2) + if ndim1 < 1 or ndim2 < 2: + return True + if shape1[-1] == shape2[-2]: + return True + raise 
ValueError(f'mismatch in core dimension of input operands (size {shape1[-1]} ' + + f'is different from {shape2[-2]})') + + +@constexpr +def _cpu_not_support(name): + """Checks if a function not supported on cpu is executed on cpu device""" + if _get_device() != 'CPU': + return True + raise ValueError(f'{name} is not supported on CPU') + + +@constexpr +def _check_is_tuple(obj): + """Check whether obj is a tuple""" + return isinstance(obj, mstype.tuple_type) + + +@constexpr +def _check_is_list(obj): + """Check whether obj is a list""" + return isinstance(obj, mstype.list_type) + + +@constexpr +def _check_is_tensor(obj): + """Check whether obj is a tensor""" + return isinstance(obj, mstype.tensor_type) diff --git a/tests/ut/python/numpy_native/__init__.py b/tests/ut/python/numpy_native/__init__.py new file mode 100644 index 0000000000..43327c3354 --- /dev/null +++ b/tests/ut/python/numpy_native/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""setup for pytest in mindspore.numpy""" +import mindspore.context as context + + +# pylint: disable=unused-argument +def setup_module(module): + context.set_context(mode=context.GRAPH_MODE) diff --git a/tests/ut/python/numpy_native/test_array_ops.py b/tests/ut/python/numpy_native/test_array_ops.py new file mode 100644 index 0000000000..643bed282d --- /dev/null +++ b/tests/ut/python/numpy_native/test_array_ops.py @@ -0,0 +1,577 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""unit tests for array operations""" + +import functools + +import pytest +import numpy as onp + +import mindspore.context as context +import mindspore.numpy as mnp +from mindspore.nn import Cell + +from ..ut_filter import non_graph_engine +from ....mindspore_test_framework.mindspore_test import mindspore_test +from ....mindspore_test_framework.pipeline.forward.compile_forward \ + import pipeline_for_compile_forward_ge_graph_for_case_by_case_config + + +class Cases(): + def __init__(self): + self.all_shapes = [ + 0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3] + ] + self.onp_dtypes = [onp.int32, 'int32', int, + onp.float32, 'float32', float, + onp.uint32, 'uint32', + onp.bool_, 'bool', bool] + + self.mnp_dtypes = [mnp.int32, 'int32', int, + mnp.float32, 'float32', float, + mnp.uint32, 'uint32', + mnp.bool_, 'bool', bool] + + self.array_sets = [1, 1.1, True, [1, 0, True], [1, 1.0, 2], (1,), + [(1, 2, 3), (4, 5, 6)], onp.random.random( + (100, 100)).astype(onp.float32), + onp.random.random((100, 100)).astype(onp.bool)] + + +def match_array(actual, expected, error=0): + if error > 0: + onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(), + decimal=error) + else: + onp.testing.assert_equal(actual.tolist(), expected.tolist()) + + +def check_all_results(onp_results, mnp_results): + """Check all results from numpy and mindspore.numpy""" + for i, _ in enumerate(onp_results): + match_array(onp_results[i], mnp_results[i].asnumpy()) + + +def test_asarray(): + test_case = Cases() + for array in test_case.array_sets: + # Check for dtype matching + actual = onp.asarray(array) + expected = mnp.asarray(array).asnumpy() + # Since we set float32/int32 as the default dtype in mindspore, we need + # to make a conversion between numpy.asarray and mindspore.numpy.asarray + if actual.dtype is onp.dtype('float64'): + assert expected.dtype == onp.dtype('float32') + elif actual.dtype 
is onp.dtype('int64'): + assert expected.dtype == onp.dtype('int32') + else: + assert actual.dtype == expected.dtype + match_array(actual, expected, error=7) + + for i in range(len(test_case.onp_dtypes)): + actual = onp.asarray(array, test_case.onp_dtypes[i]) + expected = mnp.asarray(array, test_case.mnp_dtypes[i]).asnumpy() + match_array(actual, expected, error=7) + + +def test_array(): + # array's function is very similar to asarray, so we mainly test the + # `copy` argument. + test_case = Cases() + for array in test_case.array_sets: + arr1 = mnp.asarray(array) + arr2 = mnp.array(arr1, copy=False) + arr3 = mnp.array(arr1) + arr4 = mnp.asarray(array, dtype='int32') + arr5 = mnp.asarray(arr4, dtype=mnp.int32) + assert arr1 is arr2 + assert arr1 is not arr3 + assert arr4 is arr5 + + +def test_asfarray(): + test_case = Cases() + for array in test_case.array_sets: + # Check for dtype matching + actual = onp.asfarray(array) + expected = mnp.asfarray(array).asnumpy() + # Since we set float32/int32 as the default dtype in mindspore, we need + # to make a conversion between numpy.asarray and mindspore.numpy.asarray + if actual.dtype is onp.dtype('float64'): + assert expected.dtype == onp.dtype('float32') + else: + assert actual.dtype == expected.dtype + match_array(actual, expected, error=7) + + for i in range(len(test_case.onp_dtypes)): + actual = onp.asfarray(array, test_case.onp_dtypes[i]) + expected = mnp.asfarray(array, test_case.mnp_dtypes[i]).asnumpy() + match_array(actual, expected, error=7) + + +def test_zeros(): + test_case = Cases() + for shape in test_case.all_shapes: + for i in range(len(test_case.onp_dtypes)): + actual = onp.zeros(shape, test_case.onp_dtypes[i]) + expected = mnp.zeros(shape, test_case.mnp_dtypes[i]).asnumpy() + match_array(actual, expected) + actual = onp.zeros(shape) + expected = mnp.zeros(shape).asnumpy() + match_array(actual, expected) + + +def test_ones(): + test_case = Cases() + for shape in test_case.all_shapes: + for i in 
range(len(test_case.onp_dtypes)): + actual = onp.ones(shape, test_case.onp_dtypes[i]) + expected = mnp.ones(shape, test_case.mnp_dtypes[i]).asnumpy() + match_array(actual, expected) + actual = onp.ones(shape) + expected = mnp.ones(shape).asnumpy() + match_array(actual, expected) + + +def test_full(): + actual = onp.full((2, 2), [1, 2]) + expected = mnp.full((2, 2), [1, 2]).asnumpy() + match_array(actual, expected) + + actual = onp.full((2, 0), onp.inf) + expected = mnp.full((2, 0), mnp.inf).asnumpy() + match_array(actual, expected) + + actual = onp.full((2, 3), True) + expected = mnp.full((2, 3), True).asnumpy() + match_array(actual, expected) + + actual = onp.full((3, 4, 5), 7.5) + expected = mnp.full((3, 4, 5), 7.5).asnumpy() + match_array(actual, expected) + + +def test_eye(): + test_case = Cases() + for i in range(len(test_case.onp_dtypes)): + for m in range(1, 5): + actual = onp.eye(m, dtype=test_case.onp_dtypes[i]) + expected = mnp.eye(m, dtype=test_case.mnp_dtypes[i]).asnumpy() + match_array(actual, expected) + for n in range(1, 5): + for k in range(0, 5): + actual = onp.eye(m, n, k, dtype=test_case.onp_dtypes[i]) + expected = mnp.eye( + m, n, k, dtype=test_case.mnp_dtypes[i]).asnumpy() + match_array(actual, expected) + + +def test_identity(): + test_case = Cases() + for i in range(len(test_case.onp_dtypes)): + for m in range(1, 5): + actual = onp.identity(m, dtype=test_case.onp_dtypes[i]) + expected = mnp.identity(m, dtype=test_case.mnp_dtypes[i]).asnumpy() + match_array(actual, expected) + + +def test_arange(): + actual = onp.arange(10) + expected = mnp.arange(10).asnumpy() + match_array(actual, expected) + + actual = onp.arange(0, 10) + expected = mnp.arange(0, 10).asnumpy() + match_array(actual, expected) + + actual = onp.arange(start=10) + expected = mnp.arange(start=10).asnumpy() + match_array(actual, expected) + + actual = onp.arange(start=10, step=0.1) + expected = mnp.arange(start=10, step=0.1).asnumpy() + match_array(actual, expected, error=6) + + 
actual = onp.arange(10, step=0.1) + expected = mnp.arange(10, step=0.1).asnumpy() + match_array(actual, expected, error=6) + + actual = onp.arange(0.1, 9.9) + expected = mnp.arange(0.1, 9.9).asnumpy() + match_array(actual, expected, error=6) + + +def test_linspace(): + actual = onp.linspace(2.0, 3.0, dtype=onp.float32) + expected = mnp.linspace(2.0, 3.0).asnumpy() + match_array(actual, expected, error=7) + + actual = onp.linspace(2.0, 3.0, num=5, dtype=onp.float32) + expected = mnp.linspace(2.0, 3.0, num=5).asnumpy() + match_array(actual, expected, error=7) + + actual = onp.linspace( + 2.0, 3.0, num=5, endpoint=False, dtype=onp.float32) + expected = mnp.linspace(2.0, 3.0, num=5, endpoint=False).asnumpy() + match_array(actual, expected, error=7) + + actual = onp.linspace(2.0, 3.0, num=5, retstep=True, dtype=onp.float32) + expected = mnp.linspace(2.0, 3.0, num=5, retstep=True) + match_array(actual[0], expected[0].asnumpy()) + assert actual[1] == expected[1] + + actual = onp.linspace(2.0, [3, 4, 5], num=5, + endpoint=False, dtype=onp.float32) + expected = mnp.linspace( + 2.0, [3, 4, 5], num=5, endpoint=False).asnumpy() + match_array(actual, expected) + + +def test_logspace(): + actual = onp.logspace(2.0, 3.0, dtype=onp.float32) + expected = mnp.logspace(2.0, 3.0).asnumpy() + match_array(actual, expected) + + actual = onp.logspace(2.0, 3.0, num=5, dtype=onp.float32) + expected = mnp.logspace(2.0, 3.0, num=5).asnumpy() + match_array(actual, expected) + + actual = onp.logspace( + 2.0, 3.0, num=5, endpoint=False, dtype=onp.float32) + expected = mnp.logspace(2.0, 3.0, num=5, endpoint=False).asnumpy() + match_array(actual, expected) + + actual = onp.logspace(2.0, [3, 4, 5], num=5, + endpoint=False, dtype=onp.float32) + expected = mnp.logspace( + 2.0, [3, 4, 5], num=5, endpoint=False).asnumpy() + match_array(actual, expected) + + +# Test np.transpose and np.ndarray.transpose + + +def mnp_transpose(input_tensor): + a = mnp.transpose(input_tensor, (0, 2, 1)) + b = 
mnp.transpose(input_tensor, [2, 1, 0]) + c = mnp.transpose(input_tensor, (1, 0, 2)) + d = mnp.transpose(input_tensor) + return a, b, c, d + + +def onp_transpose(input_array): + a = onp.transpose(input_array, (0, 2, 1)) + b = onp.transpose(input_array, [2, 1, 0]) + c = onp.transpose(input_array, (1, 0, 2)) + d = onp.transpose(input_array) + return a, b, c, d + +# Test np.expand_dims + + +def mnp_expand_dims(input_tensor): + a = mnp.expand_dims(input_tensor, 0) + b = mnp.expand_dims(input_tensor, -1) + c = mnp.expand_dims(input_tensor, axis=2) + d = mnp.expand_dims(input_tensor, axis=-2) + return a, b, c, d + + +def onp_expand_dims(input_array): + a = onp.expand_dims(input_array, 0) + b = onp.expand_dims(input_array, -1) + c = onp.expand_dims(input_array, axis=2) + d = onp.expand_dims(input_array, axis=-2) + return a, b, c, d + +# Test np.squeeze + + +def mnp_squeeze(input_tensor): + a = mnp.squeeze(input_tensor) + b = mnp.squeeze(input_tensor, 0) + c = mnp.squeeze(input_tensor, axis=None) + d = mnp.squeeze(input_tensor, axis=-3) + e = mnp.squeeze(input_tensor, (2,)) + f = mnp.squeeze(input_tensor, (0, 2)) + return a, b, c, d, e, f + + +def onp_squeeze(input_array): + a = onp.squeeze(input_array) + b = onp.squeeze(input_array, 0) + c = onp.squeeze(input_array, axis=None) + d = onp.squeeze(input_array, axis=-3) + e = onp.squeeze(input_array, (2,)) + f = onp.squeeze(input_array, (0, 2)) + return a, b, c, d, e, f + +# Test np.rollaxis + + +def mnp_rollaxis(input_tensor): + a = mnp.rollaxis(input_tensor, 0, 1) + b = mnp.rollaxis(input_tensor, 0, 2) + c = mnp.rollaxis(input_tensor, 2, 1) + d = mnp.rollaxis(input_tensor, 2, 2) + e = mnp.rollaxis(input_tensor, 0) + f = mnp.rollaxis(input_tensor, 1) + return a, b, c, d, e, f + + +def onp_rollaxis(input_array): + a = onp.rollaxis(input_array, 0, 1) + b = onp.rollaxis(input_array, 0, 2) + c = onp.rollaxis(input_array, 2, 1) + d = onp.rollaxis(input_array, 2, 2) + e = onp.rollaxis(input_array, 0) + f = onp.rollaxis(input_array, 
1) + return a, b, c, d, e, f + +# Test np.swapaxes + + +def mnp_swapaxes(input_tensor): + a = mnp.swapaxes(input_tensor, 0, 1) + b = mnp.swapaxes(input_tensor, 1, 0) + c = mnp.swapaxes(input_tensor, 1, 1) + d = mnp.swapaxes(input_tensor, 2, 1) + e = mnp.swapaxes(input_tensor, 1, 2) + f = mnp.swapaxes(input_tensor, 2, 2) + return a, b, c, d, e, f + + +def onp_swapaxes(input_array): + a = onp.swapaxes(input_array, 0, 1) + b = onp.swapaxes(input_array, 1, 0) + c = onp.swapaxes(input_array, 1, 1) + d = onp.swapaxes(input_array, 2, 1) + e = onp.swapaxes(input_array, 1, 2) + f = onp.swapaxes(input_array, 2, 2) + return a, b, c, d, e, f + +# Test np.reshape + + +def mnp_reshape(input_tensor): + a = mnp.reshape(input_tensor, (3, 8)) + b = mnp.reshape(input_tensor, [3, -1]) + c = mnp.reshape(input_tensor, (-1, 12)) + d = mnp.reshape(input_tensor, (-1,)) + e = mnp.reshape(input_tensor, 24) + f = mnp.reshape(input_tensor, [2, 4, -1]) + return a, b, c, d, e, f + + +def onp_reshape(input_array): + a = onp.reshape(input_array, (3, 8)) + b = onp.reshape(input_array, [3, -1]) + c = onp.reshape(input_array, (-1, 12)) + d = onp.reshape(input_array, (-1,)) + e = onp.reshape(input_array, 24) + f = onp.reshape(input_array, [2, 4, -1]) + return a, b, c, d, e, f + +# Test np.ravel + + +def mnp_ravel(input_tensor): + a = mnp.ravel(input_tensor) + return a + + +def onp_ravel(input_array): + a = onp.ravel(input_array) + return a + +# Test np.concatenate + + +def mnp_concatenate(input_tensor): + a = mnp.concatenate(input_tensor, None) + b = mnp.concatenate(input_tensor, 0) + c = mnp.concatenate(input_tensor, 1) + d = mnp.concatenate(input_tensor, 2) + return a, b, c, d + + +def onp_concatenate(input_array): + a = onp.concatenate(input_array, None) + b = onp.concatenate(input_array, 0) + c = onp.concatenate(input_array, 1) + d = onp.concatenate(input_array, 2) + return a, b, c, d + + +def test_transpose(): + onp_array = onp.random.random((3, 4, 5)).astype('float32') + mnp_array = 
def _run_case(shape, mnp_fn, onp_fn):
    """Build a random float32 array of `shape`, run the numpy reference and
    the mindspore implementation on it, and assert all outputs match."""
    data = onp.random.random(shape).astype('float32')
    check_all_results(onp_fn(data), mnp_fn(mnp.asarray(data)))


def test_transpose():
    _run_case((3, 4, 5), mnp_transpose, onp_transpose)


def test_expand_dims():
    _run_case((3, 4, 5), mnp_expand_dims, onp_expand_dims)


def test_squeeze():
    # second shape squeezes every axis down to a scalar
    for shape in ((1, 3, 1, 4, 2), (1, 1, 1, 1, 1)):
        _run_case(shape, mnp_squeeze, onp_squeeze)


def test_rollaxis():
    _run_case((3, 4, 5), mnp_rollaxis, onp_rollaxis)


def test_swapaxes():
    _run_case((3, 4, 5), mnp_swapaxes, onp_swapaxes)


def test_reshape():
    _run_case((2, 3, 4), mnp_reshape, onp_reshape)


def test_ravel():
    data = onp.random.random((2, 3, 4)).astype('float32')
    match_array(onp_ravel(data), mnp_ravel(mnp.asarray(data)).asnumpy())


def test_concatenate():
    _run_case((5, 4, 3, 2), mnp_concatenate, onp_concatenate)


class ReshapeExpandSqueeze(Cell):
    """expand_dims -> reshape -> squeeze chained for graph-mode compilation."""

    def __init__(self):
        super(ReshapeExpandSqueeze, self).__init__()

    def construct(self, x):
        expanded = mnp.expand_dims(x, 2)
        reshaped = mnp.reshape(expanded, (1, 2, 3, 4, 1, 1))
        return mnp.squeeze(reshaped)


class TransposeConcatRavel(Cell):
    """transpose (free function and tensor method) -> concatenate -> ravel."""

    def __init__(self):
        super(TransposeConcatRavel, self).__init__()

    def construct(self, x1, x2, x3):
        first = mnp.transpose(x1, [0, 2, 1])
        second = x2.transpose(0, 2, 1)
        merged = mnp.concatenate((first, second, x3), -1)
        return mnp.ravel(merged)


class RollSwap(Cell):
    """rollaxis followed by swapaxes for graph-mode compilation."""

    def __init__(self):
        super(RollSwap, self).__init__()

    def construct(self, x):
        rolled = mnp.rollaxis(x, 2)
        return mnp.swapaxes(rolled, 0, 1)


test_case_array_ops = [
    ('ReshapeExpandSqueeze', {
        'block': ReshapeExpandSqueeze(),
        'desc_inputs': [mnp.ones((2, 3, 4))]}),

    ('TransposeConcatRavel', {
        'block': TransposeConcatRavel(),
        'desc_inputs': [mnp.ones((2, 3, 4)),
                        mnp.ones((2, 3, 4)),
                        mnp.ones((2, 4, 1))]}),

    ('RollSwap', {
        'block': RollSwap(),
        'desc_inputs': [mnp.ones((2, 3, 4))]})
]

test_case_lists = [test_case_array_ops]
test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
# use -k to select certain testcase, e.g.
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm


@non_graph_engine
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
    context.set_context(mode=context.GRAPH_MODE)
    return test_exec_case


raise_set = [
    # NOTE(review): the lambda returns mnp.expand_dims itself instead of
    # calling it — presumably intended as `lambda x, y: mnp.expand_dims(x, y)`;
    # confirm against the mindspore_test pipeline before changing.
    ('Expand_dims_Error', {
        'block': (lambda x: mnp.expand_dims, {'exception': ValueError}),
        'desc_inputs': [mnp.ones((2, 3, 4)), 0]}),
]


def expand_dims_exception(input_tensor):
    return mnp.expand_dims(input_tensor, 1.2)


def test_expand_dims_exception():
    """A non-integer axis must be rejected with TypeError."""
    with pytest.raises(TypeError):
        expand_dims_exception(mnp.ones((3, 3)))


def test_asarray_exception():
    """A set is not an accepted input type for asarray."""
    with pytest.raises(TypeError):
        mnp.asarray({1, 2, 3})
def swapaxes_exception(input_tensor):
    return mnp.swapaxes(input_tensor, 1, 10)


def test_swapaxes_exception():
    """An out-of-bounds axis (10 for a 2-D tensor) is expected to raise TypeError."""
    with pytest.raises(TypeError):
        swapaxes_exception(mnp.ones((3, 3)))


# ---------------------------------------------------------------------------
# tests/ut/python/numpy_native/test_math_ops.py
# ---------------------------------------------------------------------------

import pytest
import numpy as onp

import mindspore.context as context
import mindspore.numpy as mnp


def rand_int(*shape):
    """Return random integers in [1, 5) of the given shape, cast to float32
    whenever numpy hands back an ndarray."""
    drawn = onp.random.randint(low=1, high=5, size=shape)
    if isinstance(drawn, onp.ndarray):
        drawn = drawn.astype(onp.float32)
    return drawn


class Cases():
    """Shared fixture arrays for the math-op tests."""

    def __init__(self):
        self.device_cpu = context.get_context('device_target') == 'CPU'

        # plain arrays of increasing rank
        self.arrs = [rand_int(*dims) for dims in
                     ((2,), (2, 3), (2, 3, 4), (2, 3, 4, 5))]

        # scalars expanded across the 0th dimension
        self.scalars = [rand_int(*dims) for dims in
                        ((), (1,), (1, 1), (1, 1, 1))]

        # arrays with last dimension aligned
        self.aligned_arrs = [rand_int(*dims) for dims in
                             ((2, 3), (1, 4, 3), (5, 1, 2, 3), (4, 2, 1, 1, 3))]


test_case = Cases()


def mnp_inner(a, b):
    return mnp.inner(a, b)


def onp_inner(a, b):
    return onp.inner(a, b)


def test_inner():
    """mnp.inner must agree with onp.inner for aligned arrays and scalars."""
    for lhs in test_case.aligned_arrs:
        for rhs in test_case.aligned_arrs:
            match_res(mnp_inner, onp_inner, lhs, rhs)

    for lhs in test_case.scalars:
        for rhs in test_case.scalars:
            match_res(mnp_inner, onp_inner, lhs, rhs)


def match_res(mnp_fn, onp_fn, arr1, arr2):
    """Apply the mnp and onp functions to the same inputs and compare outputs."""
    actual = mnp_fn(mnp.asarray(arr1, dtype='float32'),
                    mnp.asarray(arr2, dtype='float32')).asnumpy()
    match_array(actual, onp_fn(arr1, arr2))


def match_array(actual, expected, error=5):
    """Assert near-equality to `error` decimals, or exact equality if error <= 0."""
    if error > 0:
        onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
                                        decimal=error)
    else:
        onp.testing.assert_equal(actual.tolist(), expected.tolist())


# NOTE(review): "innner" typo kept so the collected pytest id stays stable.
def test_exception_innner():
    """inner with mismatched last dimensions must raise ValueError."""
    with pytest.raises(ValueError):
        mnp.inner(mnp.asarray(test_case.arrs[0]),
                  mnp.asarray(test_case.arrs[1]))