|
|
|
@@ -92,57 +92,8 @@ def _inner_matmul_new_1_64_32_64(tik_instance, dtype, input1, input1_index, inpu |
|
|
|
matmul_hybrid_f_t_local_UB, 0, 1, 4, 0, 0) |
|
|
|
|
|
|
|
|
|
|
|
@op_info_register(cus_batchmatmul_op_info) |
|
|
|
def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"): |
|
|
|
"""CusBatchMatMul""" |
|
|
|
if util.get_product_version() == util.VERSION_MINI: |
|
|
|
tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) |
|
|
|
else: |
|
|
|
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) |
|
|
|
x1_shape = input_x1.get("shape") |
|
|
|
dtype = input_x1.get("dtype").lower() |
|
|
|
x2_shape = input_x2.get("shape") |
|
|
|
if dtype != input_x2.get("dtype").lower(): |
|
|
|
raise RuntimeError("dtype of input_x1 and input_x2 must be same, but got %s vs %s" % ( |
|
|
|
dtype, input_x2.get("dtype").lower())) |
|
|
|
input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b) |
|
|
|
support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True), |
|
|
|
((36, 128, 128), (36, 128, 128), "float32", False, True), |
|
|
|
((5, 128, 128), (5, 128, 128), "float32", False, True), |
|
|
|
((18, 128, 128), (18, 128, 128), "float32", False, True), |
|
|
|
((16, 128, 128), (16, 128, 128), "float32", False, True), |
|
|
|
((9, 128, 128), (9, 128, 128), "float32", False, True), |
|
|
|
((1, 64, 64), (1, 64, 64), "float32", False, True), |
|
|
|
((1, 128, 128), (1, 128, 128), "float32", False, True), |
|
|
|
((4, 128, 128), (4, 128, 128), "float32", False, True), |
|
|
|
((2, 128, 128), (2, 128, 128), "float32", False, True), |
|
|
|
((32, 128, 128), (32, 128, 128), 'float32', False, True)] |
|
|
|
if input_shape not in support_shape: |
|
|
|
raise RuntimeError("input_shape %s is not supported" % str(input_shape)) |
|
|
|
|
|
|
|
# if not transpose_a and transpose_b: |
|
|
|
batch, m, k = x1_shape |
|
|
|
|
|
|
|
input1_shape = _get_flattern_shape(x1_shape) |
|
|
|
input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm) |
|
|
|
input2_shape = _get_flattern_shape(x2_shape) |
|
|
|
input2 = tik_instance.Tensor(dtype, input2_shape, name="input2", scope=tik.scope_gm) |
|
|
|
|
|
|
|
output_shape = x1_shape |
|
|
|
res_shape = _get_flattern_shape(output_shape) |
|
|
|
res = tik_instance.Tensor(dtype, res_shape, name="res", scope=tik.scope_gm) |
|
|
|
|
|
|
|
if input_shape == ((36, 128, 128), (36, 128, 128), "float32", False, True): |
|
|
|
with tik_instance.for_range(0, 18, block_num=18) as block_idx: |
|
|
|
with tik_instance.for_range(0, 2) as cc0: |
|
|
|
with tik_instance.for_range(0, 128, thread_num=2) as cc1: |
|
|
|
input1_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128 |
|
|
|
input2_index = block_idx * 32768 + cc0 * 16384 |
|
|
|
res_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128 |
|
|
|
_inner_matmul_new(tik_instance, dtype, |
|
|
|
input1, input1_index, |
|
|
|
input2, input2_index, |
|
|
|
res, res_index) |
|
|
|
def process_input_shape_640(input_shape, tik_instance, dtype, input1, input2, res): |
|
|
|
"""process input shape of 640""" |
|
|
|
if input_shape == ((5, 128, 128), (5, 128, 128), "float32", False, True): |
|
|
|
with tik_instance.for_range(0, 30, block_num=30) as block_idx: |
|
|
|
with tik_instance.for_range(0, 11) as cc1_db: |
|
|
|
@@ -189,17 +140,9 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr |
|
|
|
thread_idx * 128 + thread_idx2 * 64], |
|
|
|
matmul_hybrid_f_t_local_UB, 0, 1, 8, 0, 0) |
|
|
|
|
|
|
|
if input_shape == ((18, 128, 128), (18, 128, 128), "float32", False, True): |
|
|
|
with tik_instance.for_range(0, 18, block_num=18) as block_idx: |
|
|
|
with tik_instance.for_range(0, 128, thread_num=2) as cc0: |
|
|
|
input1_index = block_idx * 16384 + cc0 * 128 |
|
|
|
input2_index = block_idx * 16384 |
|
|
|
res_index = block_idx * 16384 + cc0 * 128 |
|
|
|
_inner_matmul_new(tik_instance, dtype, |
|
|
|
input1, input1_index, |
|
|
|
input2, input2_index, |
|
|
|
res, res_index) |
|
|
|
|
|
|
|
def process_input_shape_1152(input_shape, tik_instance, dtype, input1, input2, res): |
|
|
|
"""process input shape of 1152""" |
|
|
|
if input_shape == ((9, 128, 128), (9, 128, 128), "float32", False, True): |
|
|
|
with tik_instance.for_range(0, 27, block_num=27) as block_idx: |
|
|
|
with tik_instance.for_range(0, 42, thread_num=2) as cc0: |
|
|
|
@@ -219,6 +162,76 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr |
|
|
|
input2, input2_index, |
|
|
|
res, res_index) |
|
|
|
|
|
|
|
|
|
|
|
@op_info_register(cus_batchmatmul_op_info) |
|
|
|
def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"): |
|
|
|
"""CusBatchMatMul""" |
|
|
|
if util.get_product_version() == util.VERSION_MINI: |
|
|
|
tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) |
|
|
|
else: |
|
|
|
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) |
|
|
|
x1_shape = input_x1.get("shape") |
|
|
|
dtype = input_x1.get("dtype").lower() |
|
|
|
x2_shape = input_x2.get("shape") |
|
|
|
if dtype != input_x2.get("dtype").lower(): |
|
|
|
raise RuntimeError("dtype of input_x1 and input_x2 must be same, but got %s vs %s" % ( |
|
|
|
dtype, input_x2.get("dtype").lower())) |
|
|
|
input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b) |
|
|
|
support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True), |
|
|
|
((36, 128, 128), (36, 128, 128), "float32", False, True), |
|
|
|
((5, 128, 128), (5, 128, 128), "float32", False, True), |
|
|
|
((18, 128, 128), (18, 128, 128), "float32", False, True), |
|
|
|
((16, 128, 128), (16, 128, 128), "float32", False, True), |
|
|
|
((9, 128, 128), (9, 128, 128), "float32", False, True), |
|
|
|
((1, 64, 64), (1, 64, 64), "float32", False, True), |
|
|
|
((1, 128, 128), (1, 128, 128), "float32", False, True), |
|
|
|
((4, 128, 128), (4, 128, 128), "float32", False, True), |
|
|
|
((2, 128, 128), (2, 128, 128), "float32", False, True), |
|
|
|
((6, 128, 128), (6, 128, 128), "float32", False, True), |
|
|
|
((24, 128, 128), (24, 128, 128), "float32", False, True), |
|
|
|
((32, 128, 128), (32, 128, 128), 'float32', False, True)] |
|
|
|
if input_shape not in support_shape: |
|
|
|
raise RuntimeError("input_shape %s is not supported" % str(input_shape)) |
|
|
|
|
|
|
|
# if not transpose_a and transpose_b: |
|
|
|
batch, m, k = x1_shape |
|
|
|
|
|
|
|
input1_shape = _get_flattern_shape(x1_shape) |
|
|
|
input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm) |
|
|
|
input2_shape = _get_flattern_shape(x2_shape) |
|
|
|
input2 = tik_instance.Tensor(dtype, input2_shape, name="input2", scope=tik.scope_gm) |
|
|
|
|
|
|
|
output_shape = x1_shape |
|
|
|
res_shape = _get_flattern_shape(output_shape) |
|
|
|
res = tik_instance.Tensor(dtype, res_shape, name="res", scope=tik.scope_gm) |
|
|
|
|
|
|
|
if input_shape == ((36, 128, 128), (36, 128, 128), "float32", False, True): |
|
|
|
with tik_instance.for_range(0, 18, block_num=18) as block_idx: |
|
|
|
with tik_instance.for_range(0, 2) as cc0: |
|
|
|
with tik_instance.for_range(0, 128, thread_num=2) as cc1: |
|
|
|
input1_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128 |
|
|
|
input2_index = block_idx * 32768 + cc0 * 16384 |
|
|
|
res_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128 |
|
|
|
_inner_matmul_new(tik_instance, dtype, |
|
|
|
input1, input1_index, |
|
|
|
input2, input2_index, |
|
|
|
res, res_index) |
|
|
|
|
|
|
|
process_input_shape_640(input_shape, tik_instance, dtype, input1, input2, res) |
|
|
|
|
|
|
|
if input_shape == ((18, 128, 128), (18, 128, 128), "float32", False, True): |
|
|
|
with tik_instance.for_range(0, 18, block_num=18) as block_idx: |
|
|
|
with tik_instance.for_range(0, 128, thread_num=2) as cc0: |
|
|
|
input1_index = block_idx * 16384 + cc0 * 128 |
|
|
|
input2_index = block_idx * 16384 |
|
|
|
res_index = block_idx * 16384 + cc0 * 128 |
|
|
|
_inner_matmul_new(tik_instance, dtype, |
|
|
|
input1, input1_index, |
|
|
|
input2, input2_index, |
|
|
|
res, res_index) |
|
|
|
|
|
|
|
process_input_shape_1152(input_shape, tik_instance, dtype, input1, input2, res) |
|
|
|
|
|
|
|
if input_shape == ((1, 64, 64), (1, 64, 64), "float32", False, True): |
|
|
|
with tik_instance.for_range(0, 32, block_num=32) as block_idx: |
|
|
|
with tik_instance.for_range(0, 2, thread_num=2) as cc0: |
|
|
|
@@ -233,8 +246,10 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr |
|
|
|
input_shape_list = [((1, 128, 128), (1, 128, 128), "float32", False, True), |
|
|
|
((2, 128, 128), (2, 128, 128), "float32", False, True), |
|
|
|
((4, 128, 128), (4, 128, 128), "float32", False, True), |
|
|
|
((6, 128, 128), (6, 128, 128), "float32", False, True), |
|
|
|
((8, 128, 128), (8, 128, 128), "float32", False, True), |
|
|
|
((16, 128, 128), (16, 128, 128), "float32", False, True), |
|
|
|
((24, 128, 128), (24, 128, 128), "float32", False, True), |
|
|
|
((32, 128, 128), (32, 128, 128), 'float32', False, True) |
|
|
|
] |
|
|
|
if input_shape in input_shape_list: |
|
|
|
|