GitOrigin-RevId: 15eb08bacb
tags/v1.3.0
@@ -13,19 +13,23 @@ import numbers
from typing import Optional, Sequence, Tuple, Union

from ..core._imperative_rt.core2 import apply
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops.special import Const
from ..core.tensor import utils
from ..tensor import Tensor
from .debug_param import get_conv_execution_strategy
from .elemwise import clip, exp, log, log1p
from .tensor import reshape, squeeze
from .tensor import broadcast_to, concat, expand_dims, reshape, squeeze

__all__ = [
    "argmax",
    "argmin",
    "argsort",
    "dot",
    "isinf",
    "isnan",
    "matmul",
    "max",
    "mean",
    "min",
@@ -36,6 +40,7 @@ __all__ = [
    "sort",
    "std",
    "sum",
    "svd",
    "topk",
    "var",
]
@@ -663,7 +668,7 @@ def topk(
    no_sort: bool = False,
) -> Tuple[Tensor, Tensor]:
    r"""
    Selects the ``Top-K``(by default) smallest elements of 2d matrix by row.
    Selects the ``Top-K`` (by default) smallest elements of 2d matrix by row.

    :param inp: input tensor. If input tensor is 2d, each row will be sorted.
    :param k: number of elements needed.
@@ -722,3 +727,204 @@ def topk(
    if descending:
        tns = -tns
    return tns, ind
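
A minimal usage sketch for ``topk`` (illustrative, not part of this patch;
assumes the v1.3 API in the hunk above, where the smallest ``k`` elements of
each row are selected by default and returned together with their indices):

.. testcode::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.array([[2., 4., 6., 8., 7., 5., 3., 1.]], dtype=np.float32))
    values, indices = F.topk(x, k=3)
    print(values.numpy(), indices.numpy())

Expected output (not a captured run):

.. testoutput::

    [[1. 2. 3.]] [[7 0 6]]
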
def matmul(
    inp1: Tensor,
    inp2: Tensor,
    transpose_a=False,
    transpose_b=False,
    compute_mode="DEFAULT",
    format="DEFAULT",
) -> Tensor:
    """
    Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.

    This function behaves differently depending on the input dimensions
    (a shape-level sketch follows the function body):

    - If both tensors are 1-D, it simply forwards to ``dot``.
    - If both tensors are 2-D, it performs an ordinary matrix multiplication.
    - If exactly one input tensor is 1-D, it performs a matrix-vector multiplication.
    - If at least one tensor is 3-D or higher, the other tensor must have ndim >= 2; a batched matrix-matrix product is returned, and the tensor with fewer dimensions will be broadcasted. For example:

      - inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
      - inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
      - inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`

    :param inp1: first matrix to be multiplied.
    :param inp2: second matrix to be multiplied.
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
        out = F.matmul(data1, data2)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[10. 13.]
         [28. 40.]]

    """
    remove_row, remove_col = False, False
    inp1, inp2 = utils.convert_inputs(inp1, inp2)

    dim1, dim2 = inp1.ndim, inp2.ndim
    # handle dim=1 cases, dot and matrix-vector multiplication
    if dim1 == 1 and dim2 == 1:
        return dot(inp1, inp2)
    # the underlying matmul op requires input dims to be at least 2
    if dim1 == 1:
        inp1 = expand_dims(inp1, 0)
        dim1 = 2
        remove_row = True
    if dim2 == 1:
        inp2 = expand_dims(inp2, 1)
        dim2 = 2
        remove_col = True
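
    # From here on both operands are at least 2-D; any axis inserted above is
    # squeezed away again at the end via remove_row / remove_col.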
    batch_shape = None
    shape1 = inp1.shape
    shape2 = inp2.shape

    maxdim = dim1 if dim1 > dim2 else dim2
    if dim1 >= 3 or dim2 >= 3:
        if use_symbolic_shape():
            if dim1 > dim2:
                shape2 = concat([shape1[:-2], shape2[-2:]])
                inp2 = broadcast_to(inp2, shape2)
            if dim1 < dim2:
                shape1 = concat([shape2[:-2], shape1[-2:]])
                inp1 = broadcast_to(inp1, shape1)
            if maxdim > 3:
                batch_shape = shape1[:-2]
                # compress inputs to 3d
                (inp1,) = apply(
                    builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
                )
                (inp2,) = apply(
                    builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
                )
        else:
            if dim1 > dim2:
                shape2 = shape1[:-2] + shape2[-2:]
                inp2 = broadcast_to(inp2, shape2)
            if dim1 < dim2:
                shape1 = shape2[:-2] + shape1[-2:]
                inp1 = broadcast_to(inp1, shape1)
            if maxdim > 3:
                batch_shape = shape1[:-2]
                # compress inputs to 3d
                inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
                inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))
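                # e.g. shape (a, b, k, m) collapses to (a*b, k, m); batch_shape
                # is used after the batched op to restore the leading dims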
        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=get_conv_execution_strategy(),
        )
    else:
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=get_conv_execution_strategy(),
        )

    (result,) = apply(op, inp1, inp2)
    if maxdim > 3:
        if use_symbolic_shape():
            (result,) = apply(
                builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
            )
        else:
            result = result.reshape(batch_shape + result.shape[-2:])
    if remove_row:
        result = squeeze(result, axis=-2)
    if remove_col:
        result = squeeze(result, axis=-1)
    return result
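
As noted in the docstring, a shape-level sketch of the broadcasting behaviour
(illustrative, not part of this patch):

.. testcode::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    a = tensor(np.zeros((4, 2, 3), dtype=np.float32))  # (n, k, m)
    b = tensor(np.zeros((3, 5), dtype=np.float32))     # (m, p), broadcast over n
    v = tensor(np.zeros((3,), dtype=np.float32))       # 1-D operand
    print(F.matmul(a, b).shape)  # batched matmul with broadcasting
    print(F.matmul(a, v).shape)  # matrix-vector: trailing axis squeezed

Expected output (not a captured run):

.. testoutput::

    (4, 2, 5)
    (4, 2)
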
def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
    """
    Computes the dot product of two vectors ``inp1`` and ``inp2``.

    Inputs must be 1-dimensional or scalar; a scalar input is automatically broadcasted.
    Refer to :func:`~.matmul` for more general usage.

    :param inp1: first vector.
    :param inp2: second vector.
    :return: output value.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data1 = tensor(np.arange(0, 6, dtype=np.float32))
        data2 = tensor(np.arange(0, 6, dtype=np.float32))
        out = F.dot(data1, data2)
        print(out.numpy())

    Outputs:

    .. testoutput::

        55.

    """
    op = builtin.Dot()
    inp1, inp2 = utils.convert_inputs(inp1, inp2)
    assert (
        inp1.ndim <= 1 and inp2.ndim <= 1
    ), "Input tensors for dot must be 1-dimensional or scalar"
    (result,) = apply(op, inp1, inp2)
    utils.setscalar(result)
    return result
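
The scalar-broadcast path mentioned in the docstring, sketched (illustrative,
not part of this patch; assumes a plain Python scalar is accepted through
``convert_inputs``):

.. testcode::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    v = tensor(np.arange(1, 4, dtype=np.float32))  # [1. 2. 3.]
    print(F.dot(v, 2.0).numpy())  # 2.0 is broadcast to [2. 2. 2.]

Expected output (not a captured run):

.. testoutput::

    12.0
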
def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
    """
    Computes the singular value decomposition of the input matrix.

    :param inp: input matrix; must have shape `[..., M, N]`.
    :return: output matrices, `(U, sigma, V)`.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3))
        _, y, _ = F.svd(x)
        print(y.numpy().round(decimals=3))

    Outputs:

    .. testoutput::

        [7.348 1.   ]

    """
    op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
    U, sigma, V = apply(op, inp)
    return U, sigma, V
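
A quick sanity check for ``svd`` (illustrative, not part of this patch): the
singular values should agree with numpy's reference implementation.

.. testcode::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x_np = np.arange(0, 6, dtype=np.float32).reshape(2, 3)
    _, s, _ = F.svd(tensor(x_np))
    print(np.allclose(s.numpy(), np.linalg.svd(x_np, compute_uv=False), atol=1e-5))

Expected output (not a captured run):

.. testoutput::

    True
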
@@ -25,7 +25,7 @@ from ..utils.tuple_function import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy
from .distributed import all_reduce_sum
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
from .math import argsort, max, prod, sum
from .math import argsort, matmul, max, prod, sum
from .tensor import (
    broadcast_to,
    concat,
@@ -46,7 +46,6 @@ __all__ = [
    "conv_transpose2d",
    "deformable_conv2d",
    "deformable_psroi_pooling",
    "dot",
    "dropout",
    "indexing_one_hot",
    "leaky_relu",
@@ -55,7 +54,6 @@ __all__ = [
    "logsumexp",
    "logsoftmax",
    "matinv",
    "matmul",
    "max_pool2d",
    "one_hot",
    "prelu",
@@ -63,7 +61,6 @@ __all__ = [
    "resize",
    "softmax",
    "softplus",
    "svd",
    "warp_affine",
    "warp_perspective",
    "conv1d",
@@ -1221,207 +1218,6 @@ def matinv(inp: Tensor) -> Tensor:
    return result
def matmul(
    inp1: Tensor,
    inp2: Tensor,
    transpose_a=False,
    transpose_b=False,
    compute_mode="DEFAULT",
    format="DEFAULT",
) -> Tensor:
    """
    Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.

    This function behaves differently depending on the input dimensions:

    - If both tensors are 1-D, it simply forwards to ``dot``.
    - If both tensors are 2-D, it performs an ordinary matrix multiplication.
    - If exactly one input tensor is 1-D, it performs a matrix-vector multiplication.
    - If at least one tensor is 3-D or higher, the other tensor must have ndim >= 2; a batched matrix-matrix product is returned, and the tensor with fewer dimensions will be broadcasted. For example:

      - inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
      - inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
      - inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`

    :param inp1: first matrix to be multiplied.
    :param inp2: second matrix to be multiplied.
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
        out = F.matmul(data1, data2)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[10. 13.]
         [28. 40.]]

    """
    remove_row, remove_col = False, False
    inp1, inp2 = utils.convert_inputs(inp1, inp2)

    dim1, dim2 = inp1.ndim, inp2.ndim
    # handle dim=1 cases, dot and matrix-vector multiplication
    if dim1 == 1 and dim2 == 1:
        return dot(inp1, inp2)
    # the underlying matmul op requires input dims to be at least 2
    if dim1 == 1:
        inp1 = expand_dims(inp1, 0)
        dim1 = 2
        remove_row = True
    if dim2 == 1:
        inp2 = expand_dims(inp2, 1)
        dim2 = 2
        remove_col = True

    batch_shape = None
    shape1 = inp1.shape
    shape2 = inp2.shape

    maxdim = dim1 if dim1 > dim2 else dim2
    if dim1 >= 3 or dim2 >= 3:
        if use_symbolic_shape():
            if dim1 > dim2:
                shape2 = concat([shape1[:-2], shape2[-2:]])
                inp2 = broadcast_to(inp2, shape2)
            if dim1 < dim2:
                shape1 = concat([shape2[:-2], shape1[-2:]])
                inp1 = broadcast_to(inp1, shape1)
            if maxdim > 3:
                batch_shape = shape1[:-2]
                # compress inputs to 3d
                (inp1,) = apply(
                    builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
                )
                (inp2,) = apply(
                    builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
                )
        else:
            if dim1 > dim2:
                shape2 = shape1[:-2] + shape2[-2:]
                inp2 = broadcast_to(inp2, shape2)
            if dim1 < dim2:
                shape1 = shape2[:-2] + shape1[-2:]
                inp1 = broadcast_to(inp1, shape1)
            if maxdim > 3:
                batch_shape = shape1[:-2]
                # compress inputs to 3d
                inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
                inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))
        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=get_conv_execution_strategy(),
        )
    else:
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=get_conv_execution_strategy(),
        )

    (result,) = apply(op, inp1, inp2)
    if maxdim > 3:
        if use_symbolic_shape():
            (result,) = apply(
                builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
            )
        else:
            result = result.reshape(batch_shape + result.shape[-2:])
    if remove_row:
        result = squeeze(result, axis=-2)
    if remove_col:
        result = squeeze(result, axis=-1)
    return result
def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
    """
    Computes the dot product of two vectors ``inp1`` and ``inp2``.

    Inputs must be 1-dimensional or scalar; a scalar input is automatically broadcasted.
    Refer to :func:`~.matmul` for more general usage.

    :param inp1: first vector.
    :param inp2: second vector.
    :return: output value.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data1 = tensor(np.arange(0, 6, dtype=np.float32))
        data2 = tensor(np.arange(0, 6, dtype=np.float32))
        out = F.dot(data1, data2)
        print(out.numpy())

    Outputs:

    .. testoutput::

        55.

    """
    op = builtin.Dot()
    inp1, inp2 = utils.convert_inputs(inp1, inp2)
    assert (
        inp1.ndim <= 1 and inp2.ndim <= 1
    ), "Input tensors for dot must be 1-dimensional or scalar"
    (result,) = apply(op, inp1, inp2)
    setscalar(result)
    return result
def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
    """
    Computes the singular value decomposition of the input matrix.

    :param inp: input matrix; must have shape `[..., M, N]`.
    :return: output matrices, `(U, sigma, V)`.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3))
        _, y, _ = F.svd(x)
        print(y.numpy().round(decimals=3))

    Outputs:

    .. testoutput::

        [7.348 1.   ]

    """
    op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
    U, sigma, V = apply(op, inp)
    return U, sigma, V
def interpolate(
    inp: Tensor,
    size: Optional[Union[int, Tuple[int, int]]] = None,
@@ -707,7 +707,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
    :param inp: input tensor.
    :param pattern: a list of integers including 0, 1, ... , ``ndim``-1,
        and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For example:
        and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For example:

        * (``'x'``) -> make a 0d (scalar) into a 1d vector
        * (0, 1) -> identity for 2d vectors