|
|
|
@@ -99,7 +99,7 @@ def cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3, |
|
|
|
"""cus_cube_matmul_right_mul""" |
|
|
|
diag_size = 128 |
|
|
|
ko, mo, _, _ = input_x1.shape |
|
|
|
no, ko, ki, _ = input_x2.shape |
|
|
|
no, ko, _, _ = input_x2.shape |
|
|
|
c0 = input_x1.shape[-1] |
|
|
|
diag_outer = diag_size // c0 |
|
|
|
if [input_x1.shape[-1], input_x1.shape[-2], input_x2.shape[-1], input_x2.shape[-2]] != [c0, c0, c0, c0]: |
|
|
|
|