Merge pull request !17626 from lingyunli63/matmul_to_multags/v1.3.0
| @@ -53,3 +53,4 @@ from .softmax_grad_ext import SoftmaxGradExt | |||
| from .square_sum_v1 import SquareSumV1 | |||
| from .fused_mul_add import FusedMulAdd | |||
| from .conv2d import Conv2D | |||
| from .matmul import MatMul, BatchMatMul | |||
| @@ -0,0 +1,74 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """generate json desc for BatchMatMul and MatMul""" | |||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||
| from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| @VLD.check_attrs('transpose_a', 'transpose_b', 'left_format', 'right_format') | |||
| class MatMul(Expander): | |||
| """ | |||
| MatMul expander | |||
| """ | |||
| def __init__(self, expand_info): | |||
| super().__init__(expand_info) | |||
| self.transpose_a = self.attrs['transpose_a'] | |||
| self.transpose_b = self.attrs['transpose_b'] | |||
| self.left_format = self.attrs['left_format'] | |||
| self.right_format = self.attrs['right_format'] | |||
| self.shape_a = self.inputs[0]['shape'] | |||
| self.shape_b = self.inputs[1]['shape'] | |||
| def _optimize_to_mul(self): | |||
| """check if matmul can be replace by mul""" | |||
| if self.left_format != DF.DEFAULT or self.right_format != DF.DEFAULT: | |||
| return False | |||
| k_a = self.shape_a[-2] if self.transpose_a else self.shape_a[-1] | |||
| k_b = self.shape_b[-1] if self.transpose_b else self.shape_b[-2] | |||
| if k_a != 1 or k_b != 1: | |||
| return False | |||
| return True | |||
| def _check(self): | |||
| input_num = len(self.inputs) | |||
| if input_num < 2: | |||
| raise GKException("matul inputs number should bigger than 1, but got {}.".format(input_num)) | |||
| def _trans_shape(self, shape): | |||
| trans_shape = list(shape) | |||
| trans_shape[-2] = shape[-1] | |||
| trans_shape[-1] = shape[-2] | |||
| return trans_shape | |||
| def _expand(self, graph_builder): | |||
| if not self._optimize_to_mul(): | |||
| raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul") | |||
| #Matmul is replaced by Mul([b m k], [b k n]) when k==1 | |||
| input_a = self.inputs[0] | |||
| input_b = self.inputs[1] | |||
| if self.transpose_a: | |||
| shape_a_trans = self._trans_shape(self.shape_a) | |||
| input_a = graph_builder.emit('Reshape', [input_a], attrs={'shape': shape_a_trans}) | |||
| if self.transpose_b: | |||
| shape_b_trans = self._trans_shape(self.shape_b) | |||
| input_b = graph_builder.emit('Reshape', [input_b], attrs={'shape': shape_b_trans}) | |||
| result = graph_builder.emit('Mul', [input_a, input_b]) | |||
| if 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']: | |||
| result = graph_builder.emit('Cast', [result], attrs={'dst_type': self.attrs['dst_type']}) | |||
| return result | |||
| class BatchMatMul(MatMul): | |||
| """BatchMatMul expander""" | |||
| @@ -16,7 +16,6 @@ | |||
| import copy | |||
| import sys | |||
| from functools import reduce | |||
| from .model import GraphKernelUnsupportedException as GKException | |||
| from .model import PrimLib, DataFormat as DF | |||
| @@ -102,19 +101,60 @@ class OpInfer: | |||
| class _Elemwise(OpInfer): | |||
| """Common infer for elementwise operators""" | |||
| def _infer_shape(self): | |||
| """returns the input shape with largest flatten size""" | |||
| shape = (1,) | |||
| max_flatten_size = 1 | |||
| for t in self.inputs: | |||
| if t.data_format != DF.DEFAULT: | |||
| return t.shape | |||
| flatten_size = reduce(lambda x, y: x * y, t.shape) | |||
| if flatten_size > max_flatten_size or (flatten_size == max_flatten_size and len(t.shape) > len(shape)): | |||
| max_flatten_size = flatten_size | |||
| shape = t.shape | |||
| def _broadcast_shape(self, shapes): | |||
| """deduce broadcast shape using same rules as numpy""" | |||
| dim_size = max([len(shape) for shape in shapes]) | |||
| align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes] | |||
| out_shape = [1] * dim_size | |||
| for i in range(dim_size): | |||
| for align_shape in align_shapes: | |||
| if align_shape[i] > 1: | |||
| if out_shape[i] == 1: | |||
| out_shape[i] = align_shape[i] | |||
| if out_shape[i] != align_shape[i]: | |||
| raise GKException("shape broadcast failed!") | |||
| return out_shape | |||
| def _to_nz(self, default_shape): | |||
| """default format shape to fractal_Nz format shape""" | |||
| if len(default_shape) not in (1, 2): | |||
| raise GKException("shape is too long!") | |||
| # (32) or (1, 32) -> (2, 1, 1, 16) | |||
| if len(default_shape) == 1 or (len(default_shape) == 2 and default_shape[0] == 1): | |||
| shape = [default_shape[-1] // 16, 1, 1, 16] | |||
| if default_shape[-1] % 16 != 0: | |||
| raise GKException("should be multiplies of 16") | |||
| return shape | |||
| #(32, 1) -> (1, 2, 16, 1) | |||
| if len(default_shape) == 2 and default_shape[1] == 1: | |||
| shape = [1, default_shape[0] // 16, 16, 1] | |||
| if default_shape[0] % 16 != 0: | |||
| raise GKException("should be multiples of 16") | |||
| return shape | |||
| #(32, 48) -> (3, 2, 16, 16) | |||
| shape = [default_shape[1] // 16, default_shape[0] // 16, 16, 16] | |||
| if default_shape[0] % 16 != 0 or defautl_shape[1] % 16 != 0: | |||
| raise GKException("should be multiples of 16") | |||
| return shape | |||
| def _infer_shape(self): | |||
| """returns the output shape with broadcast""" | |||
| # in case all inputs are default format/NHWC/NCHW | |||
| is_default = [input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW) for input in self.inputs] | |||
| if all(is_default): | |||
| return self._broadcast_shape([input.shape for input in self.inputs]) | |||
| # in case formats are fractal_nz, default_fromat/NHWC/HCHW(optional) | |||
| is_default_frac_nz = [input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW, DF.FRAC_NZ) \ | |||
| for input in self.inputs] | |||
| if all(is_default_frac_nz): | |||
| nz_shapes = [self._to_nz(input.shape) if input.data_format != DF.FRAC_NZ else input.shape \ | |||
| for input in self.inputs] | |||
| return self._broadcast_shape(nz_shapes) | |||
| raise GKException("Only support default and fractal_nz") | |||
| def _infer_format(self): | |||
| for tensor in self.inputs: | |||
| if tensor.data_format != DF.DEFAULT: | |||
| @@ -56,6 +56,8 @@ std::vector<PrimitivePtr> GetExpandOps() { | |||
| prim::kPrimLogSoftmax, | |||
| prim::kPrimLogSoftmaxGrad, | |||
| prim::kPrimTile, | |||
| prim::kPrimMatMul, | |||
| prim::kPrimBatchMatMul, | |||
| #if ENABLE_D | |||
| prim::kPrimSqrtGrad, | |||
| prim::kPrimClipByNormNoDivSum, | |||
| @@ -0,0 +1,91 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.nn import Cell | |||
| import mindspore.ops.operations as P | |||
| class Net(Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.matmul = P.MatMul(transpose_a=False, transpose_b=False) | |||
| def construct(self, x, y): | |||
| return self.matmul(x, y) | |||
| class Net1(Cell): | |||
| def __init__(self): | |||
| super(Net1, self).__init__() | |||
| self.bmm = P.BatchMatMul(transpose_a=False, transpose_b=False) | |||
| def construct(self, x, y): | |||
| return self.bmm(x, y) | |||
| def get_output(i0, i1, net_cls, enable_graph_kernel=False): | |||
| context.set_context(enable_graph_kernel=enable_graph_kernel) | |||
| net = net_cls() | |||
| output = net(i0, i1) | |||
| return output | |||
| def test_matmul(): | |||
| i0 = Tensor(np.random.normal(1, 0.01, [96, 1]).astype(np.float32)) | |||
| i1 = Tensor(np.random.normal(1, 0.01, [1, 128]).astype(np.float32)) | |||
| expect = get_output(i0, i1, Net, False) | |||
| output = get_output(i0, i1, Net, True) | |||
| expect_np = expect.asnumpy().copy() | |||
| output_np = output.asnumpy().copy() | |||
| assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7) | |||
| def test_batchmatmul(): | |||
| i0 = Tensor(np.random.normal(1, 0.01, [16, 96, 1]).astype(np.float32)) | |||
| i1 = Tensor(np.random.normal(1, 0.01, [16, 1, 128]).astype(np.float32)) | |||
| expect = get_output(i0, i1, Net1, False) | |||
| output = get_output(i0, i1, Net1, True) | |||
| expect_np = expect.asnumpy().copy() | |||
| output_np = output.asnumpy().copy() | |||
| assert np.allclose(expect_np, output_np, 6.e-4, 6.e-4) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_onecard | |||
| def test_matmul_ascend(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| test_matmul() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_onecard | |||
| def test_batchmatmul_ascend(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| test_batchmatmul() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_matmul_gpu(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| test_matmul() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_batchmatmul_gpu(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| test_batchmatmul() | |||