From: @jachua
Tag: tags/v1.2.0-rc1
@@ -31,8 +31,8 @@ from .array_ops import ravel, expand_dims
 from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \
     _check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \
-    _raise_value_error, _check_matmul_shapes, _promote, _check_axis_type, _canonicalize_axis, \
-    _max, _is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range
+    _raise_value_error, _promote, _check_axis_type, _canonicalize_axis, \
+    _is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range
 from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \
     _check_input_tensor
@@ -1285,44 +1285,7 @@ def matmul(x1, x2, dtype=None):
           [ 550.  620.  690.  760.  830.]
           [ 670.  756.  842.  928. 1014.]]]
     """
-    # performs type promotion
-    dtype1 = F.dtype(x1)
-    dtype2 = F.dtype(x2)
-    dtype_out = _promote(dtype1, dtype2)
-    if not _check_same_type(dtype1, dtype_out):
-        x1 = F.cast(x1, dtype_out)
-    if not _check_same_type(dtype2, dtype_out):
-        x2 = F.cast(x2, dtype_out)
-    ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2)
-    shape1_orig, shape2_orig = F.shape(x1), F.shape(x2)
-    _check_matmul_shapes(shape1_orig, shape2_orig)
-    ndim_aligned = _max(ndim1_orig, ndim2_orig)
-    transpose_b = ndim2_orig == 1
-    shape_backbone = _infer_out_shape(
-        shape1_orig[:-2], shape2_orig[:-2])
-    # infers the shape of the output
-    shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig,
-                                                  ndim1_orig, ndim2_orig, transpose_b)
-    x1 = _expand(x1, _max(ndim_aligned, 2))
-    x2 = _expand(x2, _max(ndim_aligned, 2))
-    shape1_aligned, shape2_aligned = F.shape(x1), F.shape(x2)
-    if ndim_aligned <= 2:
-        res = P.MatMul(False, transpose_b)(x1, x2)
-    else:
-        # broadcasts x1.shape[:-2] with x2.shape[:-2]
-        shape_aligned = shape_backbone + _infer_shape_rem(shape1_aligned, shape2_aligned,
-                                                          ndim_aligned, ndim_aligned,
-                                                          transpose_b)
-        x1 = _broadcast_to(x1, shape1_aligned[:-2], shape_aligned[:-2], ndim_aligned)
-        x2 = _broadcast_to(x2, shape2_aligned[:-2], shape_aligned[:-2], ndim_aligned)
-        res = P.BatchMatMul(False, transpose_b)(x1, x2)
-    if dtype is not None and not _check_same_type(dtype_out, dtype):
-        res = F.cast(res, dtype)
-    return F.reshape(res, shape_out)
+    return C.matmul(x1, x2, dtype=dtype)
 def square(x, out=None, where=True, dtype=None):
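With this hunk, `mindspore.numpy.matmul` becomes a thin wrapper over the composite implementation and keeps the `numpy.matmul` calling convention. As a reference for that convention only (plain NumPy, not MindSpore, so the snippet merely illustrates the semantics the wrapper is expected to follow):

    >>> import numpy as np
    >>> a = np.arange(3).astype('float32')
    >>> print(np.matmul(a, a))    # 1-d x 1-d collapses to a scalar
    5.0
    >>> print(np.matmul(np.ones((2, 3, 4), 'float32'), np.ones(4, 'float32')).shape)
    (2, 3)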
@@ -2256,20 +2219,6 @@ def _shape_reduced(shape, axes):
     return tuple(shape_out)
-def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
-    """Infers the shape of the last two dimensions after performing matmul."""
-    shape_rem = ()
-    if ndim1 >= 2:
-        shape_rem += (shape1[-2],)
-    if transpose_b:
-        if ndim2 >= 2:
-            shape_rem += (shape2[-2],)
-    else:
-        if ndim1 >= 1:
-            shape_rem += (shape2[-1],)
-    return shape_rem
 def _reduce(a, reduce_fn, cmp_fn, axis=None, keepdims=False, initial=None, where=True):
     """Applies comparison based on cmp_fn and reduction based on reduce_fn"""
     _check_input_tensor(a)
@@ -278,17 +278,6 @@ def _check_is_int(dtype):
     return isinstance(dtype, typing.Int)
-@constexpr
-def _check_matmul_shapes(shape1, shape2):
-    """Checks shape1 and shape2 are valid shapes to perform matmul"""
-    ndim1, ndim2 = len(shape1), len(shape2)
-    if ndim1 < 1 or ndim2 < 1:
-        raise ValueError('input operands must have at least 1 dimension')
-    if ndim2 >= 2 and shape1[-1] != shape2[-2]:
-        raise ValueError(f'mismatch in core dimension of input operands (size '
-                         f'{shape1[-1]} is different from {shape2[-2]})')
 @constexpr
 def _check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
     """Check axis argument type."""
@@ -27,7 +27,7 @@ from .multitype_ops.add_impl import hyper_add
 from .multitype_ops.ones_like_impl import ones_like
 from .multitype_ops.zeros_like_impl import zeros_like
 from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
-from .math_ops import count_nonzero, tensor_dot, dot, batch_dot
+from .math_ops import count_nonzero, tensor_dot, dot, batch_dot, matmul
 from .array_ops import repeat_elements, sequence_mask
@@ -56,4 +56,5 @@ __all__ = [
     'dot',
     'batch_dot',
     'repeat_elements',
-    'sequence_mask']
+    'sequence_mask',
+    'matmul']
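With `matmul` exported here, it is reachable the same way as the other composite ops; a minimal check, assuming a build that includes this change:

    >>> from mindspore.ops import composite as C
    >>> callable(C.matmul)
    True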
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """math Operations."""
+from itertools import zip_longest
+from collections import deque
 import numpy as np
 from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
 from mindspore.common import dtype as mstype
@@ -486,3 +488,147 @@ def batch_dot(x1, x2, axes=None):
         final_result = squeeze_minus_one_op(final_result)
     return final_result
+@constexpr
+def _check_same_type(dtype1, dtype2):
+    return dtype1 == dtype2
+@constexpr
+def _max(*args):
+    """Returns the maximum value."""
+    return max(*args)
+@constexpr
+def _min(*args):
+    """Returns the minimum value."""
+    return min(*args)
+@constexpr
+def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
+    """Infers the shape of the last two dimensions after performing matmul."""
+    shape_rem = []
+    if ndim1 >= 2:
+        shape_rem.append(shape1[-2])
+    if transpose_b:
+        if ndim2 >= 2:
+            shape_rem.append(shape2[-2])
+    else:
+        if ndim1 >= 1:
+            shape_rem.append(shape2[-1])
+    return tuple(shape_rem)
+@constexpr
+def _check_matmul_shapes(shape1, shape2):
+    """Checks shape1 and shape2 are valid to perform matmul, and returns output shape after broadcasting."""
+    ndim1, ndim2 = len(shape1), len(shape2)
+    if ndim1 < 1 or ndim2 < 1:
+        raise ValueError('input operands must have at least 1 dimension')
+    if ndim2 >= 2 and shape1[-1] != shape2[-2]:
+        raise ValueError(f'mismatch in core dimension of input operands (size '
+                         f'{shape1[-1]} is different from {shape2[-2]})')
+    shape_out = deque()
+    for items in zip_longest(reversed(shape1[:-2]), reversed(shape2[:-2]), fillvalue=1):
+        max_size = max(items)
+        if any(item not in (1, max_size) for item in items):
+            raise ValueError(f'operands could not be broadcast together with shapes {shape1} {shape2}')
+        shape_out.appendleft(max_size)
+    return tuple(shape_out)
+@constexpr
+def _tile_size(shape, out_shape, ndim):
+    """Returns tile_size such that shape*tile_size = out_shape"""
+    size = [1]*ndim
+    for idx, (i, j) in enumerate(zip(shape, out_shape)):
+        if i != j:
+            size[idx] = j
+    return tuple(size)
+@constexpr
+def _check_need_broadcast(shape1, shape2):
+    """Returns True if broadcast is necessary for batchmatmul."""
+    return shape1[:-2] != shape2[:-2]
+def _expand(x, ndim):
+    """Expand x to ndim from axis, which can be 0 or -1."""
+    while F.rank(x) < ndim:
+        x = F.expand_dims(x, 0)
+    return x
+def _broadcast_to(x, shape_cur, shape_to, ndim_to):
+    """Broadcasts x from shape_cur to shape_to."""
+    size = _tile_size(shape_cur, shape_to, ndim_to)
+    return F.tile(x, size)
+def matmul(x1, x2, dtype=None):
+    """
+    Returns the matrix product of two arrays.
+    Note:
+        Numpy arguments `out`, `casting`, `order`, `subok`, `signature`, and `extobj` are
+        not supported.
+        On GPU, the supported dtypes are np.float16 and np.float32.
+        On CPU, the supported dtypes are np.float16 and np.float32.
+    Args:
+        x1 (Tensor): Input tensor, scalar not allowed.
+        x2 (Tensor): Input tensor, scalar not allowed.
+        dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
+            output Tensor.
+    Returns:
+        Tensor or scalar, the matrix product of the inputs. This is a scalar only
+        when both `x1`, `x2` are 1-d vectors.
+    Raises:
+        ValueError: If the last dimension of `x1` is not the same size as the
+            second-to-last dimension of `x2`, or if a scalar value is passed in.
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+    Examples:
+        >>> x1 = np.arange(2*3*4).reshape(2, 3, 4).astype('float32')
+        >>> x2 = np.arange(4*5).reshape(4, 5).astype('float32')
+        >>> output = np.matmul(x1, x2)
+        >>> print(output)
+        [[[  70.   76.   82.   88.   94.]
+          [ 190.  212.  234.  256.  278.]
+          [ 310.  348.  386.  424.  462.]]
+         [[ 430.  484.  538.  592.  646.]
+          [ 550.  620.  690.  760.  830.]
+          [ 670.  756.  842.  928. 1014.]]]
+    """
+    # performs type promotion
+    dtype1 = F.dtype(x1)
+    dtype2 = F.dtype(x2)
+    if not _check_same_type(dtype1, dtype2):
+        x1 = x1.astype(mstype.float32)
+        x2 = x2.astype(mstype.float32)
+    ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2)
+    shape1_orig, shape2_orig = F.shape(x1), F.shape(x2)
+    transpose_b = ndim2_orig == 1
+    shape_backbone = _check_matmul_shapes(shape1_orig, shape2_orig)
+    # infers the shape of the output
+    shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig,
+                                                  ndim1_orig, ndim2_orig, transpose_b)
+    x1 = _expand(x1, 2)
+    x2 = _expand(x2, 2)
+    if F.rank(x2) == 2:
+        if F.rank(x1) > 2:
+            x1 = F.reshape(x1, (-1, shape1_orig[-1]))
+        res = P.MatMul(False, transpose_b)(x1, x2)
+    else:
+        # broadcasts x1.shape[:-2] with x2.shape[:-2]
+        ndim_aligned = _max(ndim1_orig, ndim2_orig)
+        x1 = _expand(x1, ndim_aligned)
+        x2 = _expand(x2, ndim_aligned)
+        shape1_aligned, shape2_aligned = F.shape(x1), F.shape(x2)
+        x1 = _broadcast_to(x1, shape1_aligned[:-2], shape_backbone, ndim_aligned)
+        x2 = _broadcast_to(x2, shape2_aligned[:-2], shape_backbone, ndim_aligned)
+        res = P.BatchMatMul(False, transpose_b)(x1, x2)
+    if dtype is not None:
+        res = res.astype(dtype)
+    return F.reshape(res, shape_out)
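To make the shape bookkeeping above concrete: for `x1` of shape `(2, 1, 3, 4)` and `x2` of shape `(5, 4, 6)`, `_check_matmul_shapes` broadcasts the batch dimensions `(2, 1)` and `(5,)` into the backbone `(2, 5)`, `_infer_shape_rem` contributes the core dimensions `(3, 6)`, and `shape_out` becomes `(2, 5, 3, 6)`. This matches the NumPy behaviour the composite is modelled on (plain-NumPy illustration only):

    >>> import numpy as np
    >>> print(np.matmul(np.ones((2, 1, 3, 4), np.float32), np.ones((5, 4, 6), np.float32)).shape)
    (2, 5, 3, 6)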
@@ -20,6 +20,7 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
+from mindspore.ops import composite as C
 from mindspore.ops.operations import _inner_ops as inner
 class MatMulNet(nn.Cell):
@@ -43,6 +44,15 @@ class MatMul_d(nn.Cell):
         return self.matmul(x, y)
+class MatMulComposite(nn.Cell):
+    def __init__(self):
+        super(MatMulComposite, self).__init__()
+        self.matmul = C.matmul
+    def construct(self, x, y):
+        return self.matmul(x, y)
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -77,3 +87,37 @@ def test_matmul_float64():
     output = net(Tensor(x), Tensor(y))
     expect = np.matmul(x, y)
     np.testing.assert_array_almost_equal(output.asnumpy(), expect)
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_matmul_composite():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    net = MatMulComposite()
+    scalars = [np.random.randn(1).astype(np.float32), np.random.randn(1).astype(np.float32),
+               np.random.randn(1, 1).astype(np.float32),
+               np.random.randn(1, 1, 1).astype(np.float32)]
+    for x in scalars:
+        for y in scalars:
+            output = net(Tensor(x), Tensor(y))
+            expect = np.matmul(x, y)
+            np.testing.assert_array_almost_equal(output.asnumpy(), expect)
+    broadcastables = [
+        np.random.randn(3).astype(np.float32), np.random.randn(3).astype(np.float32),
+        np.random.randn(6).astype(np.float32), np.random.randn(6, 4).astype(np.float32),
+        np.random.randn(5, 2).astype(np.float32), np.random.randn(2).astype(np.float32),
+        np.random.randn(2, 9).astype(np.float32), np.random.randn(9, 8).astype(np.float32),
+        np.random.randn(6).astype(np.float32), np.random.randn(2, 6, 5).astype(np.float32),
+        np.random.randn(9, 2, 7).astype(np.float32), np.random.randn(7).astype(np.float32),
+        np.random.randn(5, 2, 4).astype(np.float32), np.random.randn(6, 1, 4, 9).astype(np.float32),
+        np.random.randn(7, 1, 5, 3, 2).astype(np.float32), np.random.randn(8, 1, 6, 1, 2, 9).astype(np.float32)
+    ]
+    for i in range(8):
+        x = broadcastables[2*i]
+        y = broadcastables[2*i + 1]
+        output = net(Tensor(x), Tensor(y))
+        expect = np.matmul(x, y)
+        np.testing.assert_array_almost_equal(output.asnumpy(), expect)
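Beyond the GPU test harness above, the composite can also be exercised directly; a minimal sketch, assuming a build with this change installed and PyNative mode so no `nn.Cell` wrapper is required (default device settings):

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import composite as C

    context.set_context(mode=context.PYNATIVE_MODE)
    x = Tensor(np.random.randn(2, 3, 4).astype(np.float32))
    y = Tensor(np.random.randn(4, 5).astype(np.float32))
    out = C.matmul(x, y)    # 2-d `y` takes the reshaped P.MatMul path
    print(out.shape)        # (2, 3, 5)
    np.testing.assert_array_almost_equal(out.asnumpy(), np.matmul(x.asnumpy(), y.asnumpy()), decimal=5)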