|
|
|
@@ -354,7 +354,6 @@ def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2): |
|
|
|
+ f'the row of matrix dimensions of x2, but got {x1_col} and {x2_row}.') |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def matmul_op_select(x1_shape, x2_shape, transpose_x1, transpose_x2): |
|
|
|
"""select matmul op""" |
|
|
|
x1_dim, x2_dim = len(x1_shape), len(x2_shape) |
|
|
|
|