From: @yanglf1121 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,6 +1,6 @@ | |||
| # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| # | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -19,10 +19,10 @@ Examples: | |||
| >>> import mindspore.numpy as np | |||
| Note: | |||
| - array_ops.py define all the array generation and operation interfaces. | |||
| - math_ops.py define all the math operations on tensors. | |||
| - dtypes.py define all the mindspore.numpy dtypes (mainly redirected from mindspore) | |||
| - random/ defines all the random operations. | |||
| - array_ops.py defines all the array operation interfaces. | |||
| - array_creations.py defines all the array generation interfaces. | |||
| - math_ops.py defines all the math operations on tensors. | |||
| - dtypes.py defines all the mindspore.numpy dtypes (mainly redirected from mindspore) | |||
| """ | |||
| from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, reshape, | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -153,7 +153,7 @@ def asarray(a, dtype=None): | |||
| elif a.dtype is onp.dtype('float'): | |||
| dtype = mstype.float32 | |||
| elif a.dtype is onp.dtype('object'): | |||
| raise TypeError(f"For Tensor convertion, the input_data is {a} that contains unsupported element.") | |||
| raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") | |||
| a = Tensor.from_numpy(a) | |||
| # If a is already a tensor and we don't need to cast dtype, return a | |||
| @@ -208,7 +208,7 @@ def asfarray(a, dtype=mstype.float32): | |||
| a = _deep_tensor_to_nparray(a) | |||
| a = onp.asarray(a) | |||
| if a.dtype is onp.dtype('object'): | |||
| raise TypeError(f"For Tensor convertion, the input_data is {a} that contains unsupported element.") | |||
| raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") | |||
| if isinstance(a, onp.ndarray): | |||
| a = Tensor.from_numpy(a) | |||
| @@ -952,7 +952,7 @@ def tril(m, k=0): | |||
| Returns a copy of an array with elements above the k-th diagonal zeroed. | |||
| Args: | |||
| m(array_like): The shape and data-type of a define these same | |||
| m(array_like): The shape and data-type of m define these same | |||
| attributes of the returned array. | |||
| k(int, optional): Diagonal above which to zero elements. k = 0 (the default) | |||
| is the main diagonal, k < 0 is below it and k > 0 is above. | |||
| @@ -987,16 +987,16 @@ def triu(m, k=0): | |||
| """ | |||
| Returns an upper triangle of an array. | |||
| Returns a copy of an array with elements above the k-th diagonal zeroed. | |||
| Returns a copy of an array with elements below the k-th diagonal zeroed. | |||
| Args: | |||
| m(array_like): The shape and data-type of a define these same | |||
| m(array_like): The shape and data-type of m define these same | |||
| attributes of the returned array. | |||
| k(int, optional): Diagonal above which to zero elements. k = 0 (the default) | |||
| k(int, optional): Diagonal below which to zero elements. k = 0 (the default) | |||
| is the main diagonal, k < 0 is below it and k > 0 is above. | |||
| Returns: | |||
| triu(Tensor): Lower triangle of m, of same shape and data-type as m. | |||
| triu(Tensor): Upper triangle of m, of same shape and data-type as m. | |||
| Raises: | |||
| TypeError: If input arguments have types not specified above. | |||
| @@ -1175,7 +1175,7 @@ def trace(a, offset=0, axis1=0, axis2=1): | |||
| >>> print(output) | |||
| [6 8] | |||
| >>> a = np.arange(24).reshape((2,2,2,3)) | |||
| >>> output = np.trace.shape | |||
| >>> output = np.trace(a).shape | |||
| >>> print(output) | |||
| (2, 3) | |||
| """ | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -20,12 +20,11 @@ from ..ops import functional as F | |||
| from ..ops.primitive import constexpr | |||
| from ..nn import Cell | |||
| from .utils import _covert_list_tensor_to_tuple_tensor, _expand, _broadcast_to, \ | |||
| from .utils import _convert_list_tensor_to_tuple_tensor, _expand, _broadcast_to, \ | |||
| _is_empty | |||
| from .utils_const import _check_is_int, _check_axes_range, _check_start_normalize, \ | |||
| _check_is_tensor, _check_is_tuple, _check_is_list, _raise_type_error, _raise_value_error, \ | |||
| _infer_out_shape, _get_index_for_unique, _get_counts_for_unique, _empty, _promote, \ | |||
| _min, _check_same_type, _check_input_tensor | |||
| _infer_out_shape, _empty, _promote, _check_same_type, _check_input_tensor | |||
| # According to official numpy reference, the dimension of a numpy array must be less | |||
| # than 32 | |||
| @@ -336,7 +335,7 @@ def ravel(x): | |||
| Flattened tensor, has the same data type as the original tensor x. | |||
| Raises: | |||
| If x is not tensor. | |||
| TypeError: If x is not tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -450,7 +449,7 @@ def concatenate(arrays, axis=0): | |||
| return P.Concat(axis)(flattened_arrays) | |||
| # convert a list of tensor to a tuple of tensor | |||
| arrays = _covert_list_tensor_to_tuple_tensor(arrays) | |||
| arrays = _convert_list_tensor_to_tuple_tensor(arrays) | |||
| arr_shape = F.shape(arrays[0]) | |||
| _check_axes_range((axis,), len(arr_shape)) | |||
| @@ -503,12 +502,11 @@ def column_stack(tup): | |||
| trans_tup = () | |||
| for tensor in tup: | |||
| shape = F.shape(tensor) | |||
| if F.tuple_len(shape) == 1: | |||
| reshape_tensor = F.reshape(tensor, shape+(1,)) | |||
| trans_tup += (reshape_tensor,) | |||
| else: | |||
| trans_tup += (tensor,) | |||
| if tensor.ndim < 1: | |||
| tensor = F.expand_dims(tensor, 0) | |||
| if tensor.ndim == 1: | |||
| tensor = F.expand_dims(tensor, 1) | |||
| trans_tup += (tensor,) | |||
| return P.Concat(axis=1)(trans_tup) | |||
| @@ -552,12 +550,9 @@ def vstack(tup): | |||
| trans_tup = () | |||
| for tensor in tup: | |||
| shape = F.shape(tensor) | |||
| if F.tuple_len(shape) == 1: | |||
| reshape_tensor = F.reshape(tensor, (1,)+shape) | |||
| trans_tup += (reshape_tensor,) | |||
| else: | |||
| trans_tup += (tensor,) | |||
| if tensor.ndim <= 1: | |||
| tensor = _expand(tensor, 2, 0) | |||
| trans_tup += (tensor,) | |||
| return P.Concat(axis=0)(trans_tup) | |||
| @@ -600,13 +595,12 @@ def hstack(tup): | |||
| _raise_value_error("Need at least one tensor to concatenate.") | |||
| tuple_of_tensor = () | |||
| if _check_is_list(tup): | |||
| for tensor in tup: | |||
| tuple_of_tensor += (tensor,) | |||
| else: | |||
| tuple_of_tensor = tup | |||
| for tensor in tup: | |||
| if tensor.ndim < 1: | |||
| tensor = F.expand_dims(tensor, 0) | |||
| tuple_of_tensor += (tensor,) | |||
| if F.tuple_len(F.shape(tup[0])) == 1: | |||
| if tuple_of_tensor[0].ndim <= 1: | |||
| return P.Concat(axis=0)(tuple_of_tensor) | |||
| return P.Concat(axis=1)(tuple_of_tensor) | |||
| @@ -652,15 +646,11 @@ def dstack(tup): | |||
| trans_tup = () | |||
| for tensor in tup: | |||
| shape = F.shape(tensor) | |||
| if F.tuple_len(shape) == 1: | |||
| reshape_tensor = F.reshape(tensor, (1,)+shape+(1,)) | |||
| trans_tup += (reshape_tensor,) | |||
| elif F.tuple_len(shape) == 2: | |||
| reshape_tensor = F.reshape(tensor, shape+(1,)) | |||
| trans_tup += (reshape_tensor,) | |||
| else: | |||
| trans_tup += (tensor,) | |||
| if tensor.ndim <= 1: | |||
| tensor = _expand(tensor, 2, 0) | |||
| if tensor.ndim == 2: | |||
| tensor = F.expand_dims(tensor, 2) | |||
| trans_tup += (tensor,) | |||
| return P.Concat(axis=2)(trans_tup) | |||
| @@ -670,10 +660,6 @@ def where(condition, x=None, y=None): | |||
| Note: | |||
| As nonzero is not supported, neither x or y can be None. | |||
| On CPU, the supported dtypes are np.float16, np.float32, np.int16, | |||
| and np.int32. | |||
| On GPU, the supported dtypes are np.float16, np.float32, np.int16, | |||
| and np.int32. | |||
| Args: | |||
| condition (Tensor): where True, yield x, otherwise yield y. | |||
| @@ -724,6 +710,9 @@ def where(condition, x=None, y=None): | |||
| shape_out = _infer_out_shape(F.shape(condition), | |||
| F.shape(x), F.shape(y)) | |||
| ndim_out = len(shape_out) | |||
| if not _check_same_type(F.dtype(condition), mstype.float32): | |||
| # tiling with bool is not supported on GPU | |||
| condition = F.cast(condition, mstype.float32) | |||
| condition = _expand(condition, ndim_out) | |||
| x = _expand(x, ndim_out) | |||
| y = _expand(y, ndim_out) | |||
| @@ -739,24 +728,16 @@ def where(condition, x=None, y=None): | |||
| return res | |||
| def _expand_atleast(arr, ndim): | |||
| """Expands arr to at least ndim.""" | |||
| arr = _expand(arr, _min(ndim, 2)) | |||
| if ndim > 2: | |||
| arr = _expand(arr, ndim, axis=-1) | |||
| return arr | |||
| def _atleast_xd(ndim, arys): | |||
| """Returns arys with at least ndim.""" | |||
| for arr in arys: | |||
| _check_input_tensor(F.typeof(arr)) | |||
| if F.tuple_len(arys) == 1: | |||
| return _expand_atleast(*arys, ndim) | |||
| res = [] | |||
| for arr in res: | |||
| res.append(_expand_atleast(arr, ndim)) | |||
| for arr in arys: | |||
| arr = _expand(arr, ndim) | |||
| res.append(arr) | |||
| if len(res) == 1: | |||
| return res[0] | |||
| return res | |||
| @@ -770,10 +751,6 @@ def atleast_1d(*arys): | |||
| Note: | |||
| In graph mode, returns a tuple of tensor instead of a list of | |||
| tensors. | |||
| On CPU, the supported dtypes are np.float16, np.float32, np.int16, | |||
| and np.int32. | |||
| On GPU, the supported dtypes are np.float16, np.float32, np.int16, | |||
| and np.int32. | |||
| Args: | |||
| arys1, arys2, … (Tensor): one or more input tensors. | |||
| @@ -810,10 +787,6 @@ def atleast_2d(*arys): | |||
| Note: | |||
| In graph mode, returns a tuple of tensor instead of a list of | |||
| tensors. | |||
| On CPU, the supported dtypes are np.float16, np.float32, np.int16, | |||
| and np.int32. | |||
| On GPU, the supported dtypes are np.float16, np.float32, np.int16, | |||
| and np.int32. | |||
| Args: | |||
| arys1, arys2, … (Tensor): one or more input tensors. | |||
| @@ -850,10 +823,7 @@ def atleast_3d(*arys): | |||
| Note: | |||
| In graph mode, returns a tuple of tensor instead of a list of | |||
| tensors. | |||
| On CPU, the supported dtypes are np.float16, np.float32, np.int16, | |||
| and np.int32. | |||
| On GPU, the supported dtypes are np.float16, np.float32, np.int16, | |||
| and np.int32. | |||
| Args: | |||
| arys1, arys2, … (Tensor): one or more input tensors. | |||
| @@ -882,7 +852,19 @@ def atleast_3d(*arys): | |||
| value= [[[1.00000000e+000], [1.00000000e+000], [1.00000000e+000], | |||
| [1.00000000e+000], [1.00000000e+000]]])) | |||
| """ | |||
| return _atleast_xd(3, arys) | |||
| res = [] | |||
| for arr in arys: | |||
| ndim = F.rank(arr) | |||
| if ndim == 0: | |||
| arr = F.reshape(arr, (1, 1, 1)) | |||
| elif ndim == 1: | |||
| arr = F.reshape(arr, (1, F.size(arr), 1)) | |||
| elif ndim == 2: | |||
| arr = F.reshape(arr, F.shape(arr) + (1,)) | |||
| res.append(arr) | |||
| if len(res) == 1: | |||
| return res[0] | |||
| return res | |||
| def stack(arrays, axis=0): | |||
| @@ -960,31 +942,24 @@ class UniqueNet(Cell): | |||
| return self.unique(x) | |||
| def unique(x, return_index=False, return_inverse=False, return_counts=False): | |||
| def unique(x, return_inverse=False): | |||
| """ | |||
| Finds the unique elements of a tensor. The input tensor will be flattened first | |||
| when it has more than one dimension. | |||
| Note: | |||
| The operation is derived from mindspore.ops.Unique. | |||
| Numpy arguments `axis` is not supported. | |||
| Numpy arguments `axis`, `return_index` and `return_counts` are not supported. | |||
| This operator must be executed in graph mode. | |||
| Args: | |||
| x (Tensor): The input tensor to be processed. | |||
| return_index (bool): If True, also return the indices of tensor x (along | |||
| the specified axis, if provided, or in the flattened tensor) that result | |||
| in the unique tensor. Default: False. | |||
| return_inverse (bool): If True, also return the indices of the unique tensor. | |||
| Default: False. | |||
| return_counts (bool): If True, also return the number of times each unique | |||
| item appears in input tensor `x`. Default: False. | |||
| Returns: | |||
| Tensor or tuple of Tensors. | |||
| - If all of the three bool arguments (`return_index`, `return_inverse`, `return_counts`) | |||
| are False, just return the unique tensor. | |||
| - If parts of the three bool arguments are True, the corresponding results (Tensor) | |||
| will be added in the tuple. | |||
| - If `return_inverse` is False, just return the unique tensor. | |||
| - If `return_inverse` is True, return tuple of tensors. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| @@ -995,14 +970,12 @@ def unique(x, return_index=False, return_inverse=False, return_counts=False): | |||
| Examples: | |||
| >>> import mindspore.numpy as mnp | |||
| >>> import numpy as onp | |||
| >>> from mindspore import context | |||
| >>> context.set_context(mode=context.GRAPH_MODE) | |||
| >>> input_x = mnp.asarray(onp.array([1, 2, 2, 2, 3, 4, 5]).astype('float32')) | |||
| >>> output_x = mnp.unique(input_x) | |||
| >>> print(output_x) | |||
| [1. 2. 3. 4. 5.] | |||
| >>> output_x = mnp.unique(input_x, return_index=True) | |||
| >>> print(output_x) | |||
| (Tensor(shape=[5], dtype=Float32, value= [ 1. 2. 3. 4. 5.]), Tensor(shape=[5], dtype=Float32, | |||
| value= [ 0. 1. 4. 5. 6.])) | |||
| >>> output_x = mnp.unique(input_x, return_inverse=True) | |||
| >>> print(output_x) | |||
| (Tensor(shape=[5], dtype=Float32, value= [ 1. 2. 3. 4. 5.]), Tensor(shape=[7], dtype=Int32, | |||
| @@ -1013,16 +986,7 @@ def unique(x, return_index=False, return_inverse=False, return_counts=False): | |||
| if F.tuple_len(F.shape(x)) > 1: | |||
| x = ravel(x) | |||
| uniq = UniqueNet() | |||
| unique_x, inverse_index = uniq(x) | |||
| if not return_index and not return_inverse and not return_counts: | |||
| return unique_x | |||
| res_tup = (unique_x,) | |||
| if return_index: | |||
| res_index = _get_index_for_unique(x, unique_x) | |||
| res_tup += (res_index,) | |||
| if return_inverse: | |||
| res_tup += (inverse_index,) | |||
| if return_counts: | |||
| res_counts = _get_counts_for_unique(x, unique_x) | |||
| res_tup += (res_counts,) | |||
| return res_tup | |||
| res = uniq(x) | |||
| if not return_inverse: | |||
| return res[0] | |||
| return res | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -146,6 +146,9 @@ promotion_rule = { | |||
| (int32, float16): float16, | |||
| (int32, float32): float32, | |||
| (int32, float64): float64, | |||
| (int64, float16): float16, | |||
| (int64, float32): float32, | |||
| (int64, float64): float64, | |||
| (float16, float32): float32, | |||
| (float16, float64): float64, | |||
| (float32, float64): float64, | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -72,7 +72,8 @@ def absolute(x, out=None, where=True, dtype=None): | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> x = np.asarray([1, 2, 3, -4, -5], np.float64) | |||
| >>> import mindspore.numpy as np | |||
| >>> x = np.asarray([1, 2, 3, -4, -5], np.float32) | |||
| >>> output = np.absolute(x) | |||
| >>> print(output) | |||
| [1. 2. 3. 4. 5.] | |||
| @@ -97,10 +98,6 @@ def add(x1, x2, out=None, where=True, dtype=None): | |||
| Argument out is not supported for storing the result, however it can be | |||
| used in combination with argument where to set the value at indices for | |||
| which where is set to False. | |||
| On GPU, the supported dtypes are np.float16, np.float32, np.int32, | |||
| and np.int64. | |||
| On CPU, the supported dtypes are np.float16, np.float32, np.float64, | |||
| np.int16, np.int32, and np.int64. | |||
| Args: | |||
| x1 (Tensor): input to be added. | |||
| @@ -154,10 +151,6 @@ def subtract(x1, x2, out=None, where=True, dtype=None): | |||
| Argument out is not supported for storing the result, however it can be | |||
| used in combination with argument where to set the value at indices for | |||
| which where is set to False. | |||
| On GPU, the supported dtypes are np.float16, np.float32, np.int32, | |||
| and np.int64. | |||
| On CPU, the supported dtypes are np.float16, np.float32, np.float64, | |||
| np.int16, np.int32, and np.int64. | |||
| Args: | |||
| x1 (Tensor): the input to be subtracted from. | |||
| @@ -207,10 +200,6 @@ def multiply(x1, x2, out=None, where=True, dtype=None): | |||
| Argument out is not supported for storing the result, however it can be | |||
| used in combination with argument where to set the value at indices for | |||
| which where is set to False. | |||
| On GPU, the supported dtypes are np.float16, np.float32, np.int32, | |||
| and np.int64. | |||
| On CPU, the supported dtypes are np.float16, np.float32, np.float64, | |||
| np.int16, np.int32, and np.int64. | |||
| Args: | |||
| x1 (Tensor): input tensor to be multiplied. | |||
| @@ -273,8 +262,6 @@ def divide(x1, x2, out=None, where=True, dtype=None): | |||
| used in combination with argument where to set the value at indices for | |||
| which where is set to False. | |||
| On GPU, the supported dtypes are np.float16, and np.float32. | |||
| On CPU, the supported dtypes are np.float16, np.float32, np.float64, | |||
| np.int16, np.int32, and np.int64. | |||
| Args: | |||
| x1 (Tensor): the divident. | |||
| @@ -325,12 +312,10 @@ def power(x1, x2, out=None, where=True, dtype=None): | |||
| Numpy arguments casting, order, dtype, subok, signature, and extobj are | |||
| not supported. | |||
| On GPU, the supported dtypes are np.float16, and np.float32. | |||
| On CPU, the supported dtypes are np.float16, np.float32, np.float64, | |||
| np.int16, np.int32, and np.int64. | |||
| Args: | |||
| x1 (Tensor): the bases. | |||
| x2 (Tensor): the exponenets. | |||
| x2 (Tensor): the exponents. | |||
| out (Tensor or None): optional, defaults to None. | |||
| where (Tensor or None): optional. For any non-default value of type other | |||
| than Tensor or None, the output retains its original value. | |||
| @@ -345,7 +330,7 @@ def power(x1, x2, out=None, where=True, dtype=None): | |||
| Returns: | |||
| Tensor or scalar, the bases in x1 raised to the exponents in x2. This | |||
| is a scalarif both x1 and x2 are scalars. | |||
| is a scalar if both x1 and x2 are scalars. | |||
| Raises: | |||
| TypeError: if the input is not a tensor. | |||
| @@ -354,8 +339,8 @@ def power(x1, x2, out=None, where=True, dtype=None): | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> x1 = np.full((3, 2), [1, 2]) | |||
| >>> x2 = np.full((3, 2), [3, 4]) | |||
| >>> x1 = np.full((3, 2), [1, 2]).astype('float32') | |||
| >>> x2 = np.full((3, 2), [3, 4]).astype('float32') | |||
| >>> output = np.power(x1, x2) | |||
| >>> print(output) | |||
| [[ 1, 16], | |||
| @@ -548,8 +533,8 @@ def dot(a, b): | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = np.full((1, 3), 7) | |||
| >>> b = np.full((2, 3, 4), 5) | |||
| >>> a = np.full((1, 3), 7).astype('float32') | |||
| >>> b = np.full((2, 3, 4), 5).astype('float32') | |||
| >>> output = np.dot(a, b) | |||
| >>> print(output) | |||
| [[[105, 105, 105, 105], | |||
| @@ -597,8 +582,8 @@ def outer(a, b): | |||
| Examples: | |||
| >>> import mindspore.numpy as np | |||
| >>> a = np.full(7, 2) | |||
| >>> b = np.full(4, 3) | |||
| >>> a = np.full(7, 2).astype('float32') | |||
| >>> b = np.full(4, 3).astype('float32') | |||
| >>> output = np.outer(a, b) | |||
| >>> print(output) | |||
| [[6, 6, 6, 6], | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -77,7 +77,7 @@ def _get_device(): | |||
| return context.get_context('device_target') | |||
| def _covert_list_tensor_to_tuple_tensor(list_of_tensor): | |||
| def _convert_list_tensor_to_tuple_tensor(list_of_tensor): | |||
| """Convert a list of tensor to a tuple of tensor""" | |||
| if isinstance(list_of_tensor, list): | |||
| tuple_of_tensor = () | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -15,10 +15,7 @@ | |||
| """internal graph-compatible utility functions""" | |||
| from functools import partial | |||
| import numpy as onp | |||
| import mindspore.context as context | |||
| from ..common import Tensor | |||
| from ..ops import functional as F | |||
| from ..ops.primitive import constexpr | |||
| from ..common import dtype as mstype | |||
| @@ -262,65 +259,6 @@ def _empty(dtype, shape): | |||
| return Tensor_(dtype, shape) | |||
| def _get_index_for_unique(input_x, unique_x): | |||
| """ | |||
| Return the indices of the first occurrences of the unique values in the original array. | |||
| Args: | |||
| input_x (Tensor): The flattened input tensor of `mindspore.numpy.unique`. | |||
| unique_x (Tensor): The tensor contains the unique elements in `input_x`, sorted in ascending order. | |||
| Returns: | |||
| Tensor. The indices of the unique values in the original array. Has the same shape as `unique_x`. | |||
| """ | |||
| o_array = input_x.asnumpy() | |||
| dic = {} | |||
| for idx in range(o_array.size): | |||
| val = o_array[idx] | |||
| if val not in dic: | |||
| dic[val] = idx | |||
| index_lst = [] | |||
| u_array = unique_x.asnumpy() | |||
| for idx in range(u_array.size): | |||
| index_lst.append(dic[u_array[idx]]) | |||
| return Tensor(onp.array(index_lst), input_x.dtype) | |||
| @constexpr | |||
| def _get_counts_for_unique(input_x, unique_x): | |||
| """ | |||
| Return the number of times each of the unique values comes up in the original tensor. | |||
| Args: | |||
| input_x (Tensor): The flattened input tensor of `mindspore.numpy.unique`. | |||
| unique_x (Tensor): The tensor contains the unique elements in `input_x`, sorted in ascending order. | |||
| Returns: | |||
| Tensor. The number of times each of the unique values comes up in the original tensor. | |||
| """ | |||
| dic = {} | |||
| o_array = input_x.asnumpy() | |||
| for idx in range(o_array.size): | |||
| val = o_array[idx] | |||
| if val not in dic: | |||
| dic[val] = 1 | |||
| else: | |||
| dic[val] += 1 | |||
| u_array = unique_x.asnumpy() | |||
| counts_lst = [dic[val] for val in u_array] | |||
| return Tensor(onp.array(counts_lst), input_x.dtype) | |||
| @constexpr | |||
| def _get_max_value(x): | |||
| """Returns the maximum value of the input tensor `x`. """ | |||
| return int(max(x.asnumpy())) | |||
| @constexpr | |||
| def _promote(dtype1, dtype2): | |||
| if dtype1 == dtype2: | |||
| @@ -355,7 +293,8 @@ def _check_same_type(dtype1, dtype2): | |||
| @constexpr | |||
| def _check_is_float(dtype): | |||
| return dtype in mstype.float_type | |||
| """Returns whether dtype is float16 or float32.""" | |||
| return dtype in (mstype.float16, mstype.float32) | |||
| @constexpr | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -194,7 +194,7 @@ def match_all_arrays(mnp_res, onp_res, error=0): | |||
| def match_meta(actual, expected): | |||
| # float64 and int64 are not supported, and the defualt type for | |||
| # float64 and int64 are not supported, and the default type for | |||
| # float and int are float32 and int32, respectively | |||
| if expected.dtype == onp.float64: | |||
| expected = expected.astype(onp.float32) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -168,6 +168,7 @@ def match_res(mnp_fn, onp_fn, *arrs, **kwargs): | |||
| def match_all_arrays(mnp_res, onp_res, error=0): | |||
| if isinstance(mnp_res, (tuple, list)): | |||
| assert len(mnp_res) == len(onp_res) | |||
| for actual, expected in zip(mnp_res, onp_res): | |||
| match_array(actual.asnumpy(), expected, error) | |||
| else: | |||
| @@ -175,7 +176,7 @@ def match_all_arrays(mnp_res, onp_res, error=0): | |||
| def match_meta(actual, expected): | |||
| # float64 and int64 are not supported, and the defualt type for | |||
| # float64 and int64 are not supported, and the default type for | |||
| # float and int are float32 and int32, respectively | |||
| if expected.dtype == onp.float64: | |||
| expected = expected.astype(onp.float32) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||