diff --git a/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py b/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py
index 0421de2dab..ff4a0ee521 100644
--- a/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py
+++ b/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py
@@ -92,57 +92,8 @@ def _inner_matmul_new_1_64_32_64(tik_instance, dtype, input1, input1_index, inpu
                            matmul_hybrid_f_t_local_UB, 0, 1, 4, 0, 0)
 
 
-@op_info_register(cus_batchmatmul_op_info)
-def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"):
-    """CusBatchMatMul"""
-    if util.get_product_version() == util.VERSION_MINI:
-        tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
-    else:
-        tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
-    x1_shape = input_x1.get("shape")
-    dtype = input_x1.get("dtype").lower()
-    x2_shape = input_x2.get("shape")
-    if dtype != input_x2.get("dtype").lower():
-        raise RuntimeError("dtype of input_x1 and input_x2 must be same, but got %s vs %s" % (
-            dtype, input_x2.get("dtype").lower()))
-    input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b)
-    support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True),
-                     ((36, 128, 128), (36, 128, 128), "float32", False, True),
-                     ((5, 128, 128), (5, 128, 128), "float32", False, True),
-                     ((18, 128, 128), (18, 128, 128), "float32", False, True),
-                     ((16, 128, 128), (16, 128, 128), "float32", False, True),
-                     ((9, 128, 128), (9, 128, 128), "float32", False, True),
-                     ((1, 64, 64), (1, 64, 64), "float32", False, True),
-                     ((1, 128, 128), (1, 128, 128), "float32", False, True),
-                     ((4, 128, 128), (4, 128, 128), "float32", False, True),
-                     ((2, 128, 128), (2, 128, 128), "float32", False, True),
-                     ((32, 128, 128), (32, 128, 128), 'float32', False, True)]
-    if input_shape not in support_shape:
-        raise RuntimeError("input_shape %s is not supported" % str(input_shape))
-
-    # if not transpose_a and transpose_b:
-    batch, m, k = x1_shape
-
-    input1_shape = _get_flattern_shape(x1_shape)
-    input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm)
-    input2_shape = _get_flattern_shape(x2_shape)
-    input2 = tik_instance.Tensor(dtype, input2_shape, name="input2", scope=tik.scope_gm)
-
-    output_shape = x1_shape
-    res_shape = _get_flattern_shape(output_shape)
-    res = tik_instance.Tensor(dtype, res_shape, name="res", scope=tik.scope_gm)
-
-    if input_shape == ((36, 128, 128), (36, 128, 128), "float32", False, True):
-        with tik_instance.for_range(0, 18, block_num=18) as block_idx:
-            with tik_instance.for_range(0, 2) as cc0:
-                with tik_instance.for_range(0, 128, thread_num=2) as cc1:
-                    input1_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
-                    input2_index = block_idx * 32768 + cc0 * 16384
-                    res_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
-                    _inner_matmul_new(tik_instance, dtype,
-                                      input1, input1_index,
-                                      input2, input2_index,
-                                      res, res_index)
+def process_input_shape_640(input_shape, tik_instance, dtype, input1, input2, res):
+    """process input shape of 640"""
     if input_shape == ((5, 128, 128), (5, 128, 128), "float32", False, True):
         with tik_instance.for_range(0, 30, block_num=30) as block_idx:
             with tik_instance.for_range(0, 11) as cc1_db:
@@ -189,17 +140,9 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr
                                                thread_idx * 128 + thread_idx2 * 64],
                                            matmul_hybrid_f_t_local_UB, 0, 1, 8, 0, 0)
 
 
-    if input_shape == ((18, 128, 128), (18, 128, 128), "float32", False, True):
-        with tik_instance.for_range(0, 18, block_num=18) as block_idx:
-            with tik_instance.for_range(0, 128, thread_num=2) as cc0:
-                input1_index = block_idx * 16384 + cc0 * 128
-                input2_index = block_idx * 16384
-                res_index = block_idx * 16384 + cc0 * 128
-                _inner_matmul_new(tik_instance, dtype,
-                                  input1, input1_index,
-                                  input2, input2_index,
-                                  res, res_index)
+def process_input_shape_1152(input_shape, tik_instance, dtype, input1, input2, res):
+    """process input shape of 1152"""
     if input_shape == ((9, 128, 128), (9, 128, 128), "float32", False, True):
         with tik_instance.for_range(0, 27, block_num=27) as block_idx:
             with tik_instance.for_range(0, 42, thread_num=2) as cc0:
@@ -219,6 +162,76 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr
                                   input2, input2_index,
                                   res, res_index)
 
+
+@op_info_register(cus_batchmatmul_op_info)
+def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"):
+    """CusBatchMatMul"""
+    if util.get_product_version() == util.VERSION_MINI:
+        tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
+    else:
+        tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
+    x1_shape = input_x1.get("shape")
+    dtype = input_x1.get("dtype").lower()
+    x2_shape = input_x2.get("shape")
+    if dtype != input_x2.get("dtype").lower():
+        raise RuntimeError("dtype of input_x1 and input_x2 must be same, but got %s vs %s" % (
+            dtype, input_x2.get("dtype").lower()))
+    input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b)
+    support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True),
+                     ((36, 128, 128), (36, 128, 128), "float32", False, True),
+                     ((5, 128, 128), (5, 128, 128), "float32", False, True),
+                     ((18, 128, 128), (18, 128, 128), "float32", False, True),
+                     ((16, 128, 128), (16, 128, 128), "float32", False, True),
+                     ((9, 128, 128), (9, 128, 128), "float32", False, True),
+                     ((1, 64, 64), (1, 64, 64), "float32", False, True),
+                     ((1, 128, 128), (1, 128, 128), "float32", False, True),
+                     ((4, 128, 128), (4, 128, 128), "float32", False, True),
+                     ((2, 128, 128), (2, 128, 128), "float32", False, True),
+                     ((6, 128, 128), (6, 128, 128), "float32", False, True),
+                     ((24, 128, 128), (24, 128, 128), "float32", False, True),
+                     ((32, 128, 128), (32, 128, 128), 'float32', False, True)]
+    if input_shape not in support_shape:
+        raise RuntimeError("input_shape %s is not supported" % str(input_shape))
+
+    # if not transpose_a and transpose_b:
+    batch, m, k = x1_shape
+
+    input1_shape = _get_flattern_shape(x1_shape)
+    input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm)
+    input2_shape = _get_flattern_shape(x2_shape)
+    input2 = tik_instance.Tensor(dtype, input2_shape, name="input2", scope=tik.scope_gm)
+
+    output_shape = x1_shape
+    res_shape = _get_flattern_shape(output_shape)
+    res = tik_instance.Tensor(dtype, res_shape, name="res", scope=tik.scope_gm)
+
+    if input_shape == ((36, 128, 128), (36, 128, 128), "float32", False, True):
+        with tik_instance.for_range(0, 18, block_num=18) as block_idx:
+            with tik_instance.for_range(0, 2) as cc0:
+                with tik_instance.for_range(0, 128, thread_num=2) as cc1:
+                    input1_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
+                    input2_index = block_idx * 32768 + cc0 * 16384
+                    res_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
+                    _inner_matmul_new(tik_instance, dtype,
+                                      input1, input1_index,
+                                      input2, input2_index,
+                                      res, res_index)
+
+    process_input_shape_640(input_shape, tik_instance, dtype, input1, input2, res)
+
+    if input_shape == ((18, 128, 128), (18, 128, 128), "float32", False, True):
+        with tik_instance.for_range(0, 18, block_num=18) as block_idx:
+            with tik_instance.for_range(0, 128, thread_num=2) as cc0:
+                input1_index = block_idx * 16384 + cc0 * 128
+                input2_index = block_idx * 16384
+                res_index = block_idx * 16384 + cc0 * 128
+                _inner_matmul_new(tik_instance, dtype,
+                                  input1, input1_index,
+                                  input2, input2_index,
+                                  res, res_index)
+
+    process_input_shape_1152(input_shape, tik_instance, dtype, input1, input2, res)
+
     if input_shape == ((1, 64, 64), (1, 64, 64), "float32", False, True):
         with tik_instance.for_range(0, 32, block_num=32) as block_idx:
             with tik_instance.for_range(0, 2, thread_num=2) as cc0:
@@ -233,8 +246,10 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr
     input_shape_list = [((1, 128, 128), (1, 128, 128), "float32", False, True),
                        ((2, 128, 128), (2, 128, 128), "float32", False, True),
                        ((4, 128, 128), (4, 128, 128), "float32", False, True),
+                       ((6, 128, 128), (6, 128, 128), "float32", False, True),
                        ((8, 128, 128), (8, 128, 128), "float32", False, True),
                        ((16, 128, 128), (16, 128, 128), "float32", False, True),
+                       ((24, 128, 128), (24, 128, 128), "float32", False, True),
                        ((32, 128, 128), (32, 128, 128), 'float32', False, True)
                        ]
     if input_shape in input_shape_list:
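
Not part of the patch: a minimal usage sketch of the refactored op builder, assuming the TBE-style dict arguments that CusBatchMatMul already reads via get("shape") and get("dtype"). It only illustrates that one of the shapes newly added to support_shape in this change, (6, 128, 128) float32 with transpose_b=True, should now pass the shape check instead of raising RuntimeError.

# Illustrative sketch only -- not part of the diff above; argument dicts follow the
# existing convention in CusBatchMatMul (keys "shape" and "dtype").
x1 = {"shape": (6, 128, 128), "dtype": "float32"}
x2 = {"shape": (6, 128, 128), "dtype": "float32"}
out = {"shape": (6, 128, 128), "dtype": "float32"}
# Builds the TIK kernel for a newly supported shape; an unsupported shape would raise
# RuntimeError("input_shape ... is not supported").
CusBatchMatMul(x1, x2, out, transpose_a=False, transpose_b=True, kernel_name="batchmatmul")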