|
|
|
@@ -339,6 +339,12 @@ def get_broadcast_matmul_shape(x_shape, y_shape): |
|
|
|
@constexpr |
|
|
|
def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2): |
|
|
|
"""check col and row equal""" |
|
|
|
if len(x1_shape) == 1: |
|
|
|
transpose_x1 = False |
|
|
|
x1_shape = (1,) + x1_shape |
|
|
|
if len(x2_shape) == 1: |
|
|
|
transpose_x2 = False |
|
|
|
x2_shape = x2_shape + (1,) |
|
|
|
x1_last = x1_shape[-2:] |
|
|
|
x2_last = x2_shape[-2:] |
|
|
|
x1_col = x1_last[not transpose_x1] # x1_col = x1_last[1] if (not transpose_a) else x1_last[0] |
|
|
|
@@ -348,27 +354,48 @@ def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2): |
|
|
|
+ f'the row of matrix dimensions of x2, but got {x1_col} and {x2_row}.') |
|
|
|
|
|
|
|
|
|
|
|
@constexpr
def matmul_op_select(x1_shape, x2_shape, transpose_x1, transpose_x2):
    """
    Pick the primitive used to multiply two operands, based on their ranks.

    Args:
        x1_shape (tuple): Shape of the first operand.
        x2_shape (tuple): Shape of the second operand.
        transpose_x1 (bool): Requested transpose flag for the first operand.
        transpose_x2 (bool): Requested transpose flag for the second operand.

    Returns:
        `P.Mul()` when both operands are 1-dimensional, `P.MatMul` when neither
        rank exceeds 2, and `P.BatchMatMul` in every other case. The transpose
        flag of a 1-dimensional operand is always passed as False.
    """
    rank_x1 = len(x1_shape)
    rank_x2 = len(x2_shape)

    # Two vectors: element-wise multiply is selected instead of a matmul primitive.
    if rank_x1 == 1 and rank_x2 == 1:
        return P.Mul()

    # A 1-dimensional operand never keeps its transpose flag, whichever
    # matmul primitive ends up being selected below.
    effective_transpose_x1 = transpose_x1 if rank_x1 != 1 else False
    effective_transpose_x2 = transpose_x2 if rank_x2 != 1 else False

    # No operand has batch dimensions: plain matrix multiply.
    if max(rank_x1, rank_x2) <= 2:
        return P.MatMul(effective_transpose_x1, effective_transpose_x2)

    # At least one operand has rank greater than 2: batched matrix multiply.
    return P.BatchMatMul(effective_transpose_x1, effective_transpose_x2)
|
|
|
|
|
|
|
|
|
|
|
class MatMul(Cell): |
|
|
|
""" |
|
|
|
Multiplies matrix `x1` by matrix `x2`. |
|
|
|
|
|
|
|
The rank of input tensors must be not less than `2`. The none-matrix dimensions(batch) of inputs |
|
|
|
will be broadcasted and must be broadcastable. |
|
|
|
- If both x1 and x2 are 1-dimensional, the dot product is returned. |
|
|
|
- If the dimensions of x1 and x2 are all not greater than 2, the matrix-matrix product will be returned. Note if |
|
|
|
one of 'x1' and 'x2' is 1-dimensional, the argument will first be expanded to 2 dimension. After the matrix |
|
|
|
multiply, the expanded dimension will be removed. |
|
|
|
- If at least one of x1 and x2 is N-dimensional (N>2), the none-matrix dimensions(batch) of inputs will be |
|
|
|
broadcasted and must be broadcastable. Note if one of 'x1' and 'x2' is 1-dimensional, the argument will first be |
|
|
|
expanded to 2 dimension and then the none-matrix dimensions will be broadcasted. After the matrix multiply, the |
|
|
|
expanded dimension will be removed. |
|
|
|
|
|
|
|
Args: |
|
|
|
transpose_x1 (bool): If true, `a` is transposed before multiplication. Default: False. |
|
|
|
transpose_x2 (bool): If true, `b` is transposed before multiplication. Default: False. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input_x1** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(*A, N, C)`, |
|
|
|
where :math:`*A` represents the batch size of `x1` which can be multidimensional. |
|
|
|
If `transpose_a` is True, its shape must be :math:`(*A, N, C)` after transposing. |
|
|
|
- **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(*B, C, M)`, |
|
|
|
where :math:`*B` represents the batch size of `x2` which can be multidimensional. |
|
|
|
If `transpose_b` is True, its shape must be :math:`(*B, C, M)` after transposing. |
|
|
|
- **input_x1** (Tensor) - The first tensor to be multiplied. |
|
|
|
- **input_x2** (Tensor) - The second tensor to be multiplied. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, the shape of the output tensor is :math:`(*L, N, M)`. :math:`*L` is the batch size after broadcasting. |
|
|
|
Tensor, the shape of the output tensor depends on the dimension of input tensors. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> net = nn.MatMul() |
|
|
|
@@ -387,13 +414,26 @@ class MatMul(Cell): |
|
|
|
self.transpose_x1 = transpose_x1 |
|
|
|
self.transpose_x2 = transpose_x2 |
|
|
|
self.shape_op = P.Shape() |
|
|
|
self.matmul_op = P.MatMul(self.transpose_x1, self.transpose_x2) |
|
|
|
self.batch_matmul_op = P.BatchMatMul(self.transpose_x1, self.transpose_x2) |
|
|
|
self.expand_op = P.ExpandDims() |
|
|
|
self.squeeze_left_op = P.Squeeze(-2) |
|
|
|
self.squeeze_right_op = P.Squeeze(-1) |
|
|
|
self.reduce_sum_op = P.ReduceSum(keep_dims=False) |
|
|
|
|
|
|
|
def construct(self, x1, x2): |
|
|
|
x1_shape = self.shape_op(x1) |
|
|
|
x2_shape = self.shape_op(x2) |
|
|
|
check_col_row_equal(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2) |
|
|
|
matmul_op = matmul_op_select(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2) |
|
|
|
|
|
|
|
x1_dim, x2_dim = len(x1_shape), len(x2_shape) |
|
|
|
if x1_dim == x2_dim and x2_dim == 1: |
|
|
|
return self.reduce_sum_op(matmul_op(x1, x2), -1) |
|
|
|
if x1_dim == 1: |
|
|
|
x1 = self.expand_op(x1, 0) |
|
|
|
x1_shape = self.shape_op(x1) |
|
|
|
if x2_dim == 1: |
|
|
|
x2 = self.expand_op(x2, 1) |
|
|
|
x2_shape = self.shape_op(x2) |
|
|
|
|
|
|
|
x1_broadcast_shape, x2_broadcast_shape = get_broadcast_matmul_shape(x1_shape, x2_shape) |
|
|
|
x1_broadcast_to = P.BroadcastTo(x1_broadcast_shape) |
|
|
|
@@ -402,8 +442,12 @@ class MatMul(Cell): |
|
|
|
x1 = x1_broadcast_to(x1) |
|
|
|
if x2_broadcast_shape != x2_shape: |
|
|
|
x2 = x2_broadcast_to(x2) |
|
|
|
if len(x1_broadcast_shape) == 2: |
|
|
|
matmul_broadcast = self.matmul_op(x1, x2) |
|
|
|
else: |
|
|
|
matmul_broadcast = self.batch_matmul_op(x1, x2) |
|
|
|
|
|
|
|
matmul_broadcast = matmul_op(x1, x2) |
|
|
|
|
|
|
|
if x1_dim == 1: |
|
|
|
matmul_broadcast = self.squeeze_left_op(matmul_broadcast) |
|
|
|
if x2_dim == 1: |
|
|
|
matmul_broadcast = self.squeeze_right_op(matmul_broadcast) |
|
|
|
|
|
|
|
return matmul_broadcast |