@@ -92,57 +92,8 @@ def _inner_matmul_new_1_64_32_64(tik_instance, dtype, input1, input1_index, inpu
                           matmul_hybrid_f_t_local_UB, 0, 1, 4, 0, 0)
@op_info_register(cus_batchmatmul_op_info)
def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"):
    """CusBatchMatMul"""
    if util.get_product_version() == util.VERSION_MINI:
        tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
    else:
        tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
    x1_shape = input_x1.get("shape")
    dtype = input_x1.get("dtype").lower()
    x2_shape = input_x2.get("shape")
    if dtype != input_x2.get("dtype").lower():
        raise RuntimeError("dtype of input_x1 and input_x2 must be same, but got %s vs %s" % (
            dtype, input_x2.get("dtype").lower()))
    input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b)
    support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True),
                     ((36, 128, 128), (36, 128, 128), "float32", False, True),
                     ((5, 128, 128), (5, 128, 128), "float32", False, True),
                     ((18, 128, 128), (18, 128, 128), "float32", False, True),
                     ((16, 128, 128), (16, 128, 128), "float32", False, True),
                     ((9, 128, 128), (9, 128, 128), "float32", False, True),
                     ((1, 64, 64), (1, 64, 64), "float32", False, True),
                     ((1, 128, 128), (1, 128, 128), "float32", False, True),
                     ((4, 128, 128), (4, 128, 128), "float32", False, True),
                     ((2, 128, 128), (2, 128, 128), "float32", False, True),
                     ((32, 128, 128), (32, 128, 128), "float32", False, True)]
    if input_shape not in support_shape:
        raise RuntimeError("input_shape %s is not supported" % str(input_shape))
    # if not transpose_a and transpose_b:
    batch, m, k = x1_shape
    input1_shape = _get_flattern_shape(x1_shape)
    input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm)
    input2_shape = _get_flattern_shape(x2_shape)
    input2 = tik_instance.Tensor(dtype, input2_shape, name="input2", scope=tik.scope_gm)
    output_shape = x1_shape
    res_shape = _get_flattern_shape(output_shape)
    res = tik_instance.Tensor(dtype, res_shape, name="res", scope=tik.scope_gm)
    if input_shape == ((36, 128, 128), (36, 128, 128), "float32", False, True):
        with tik_instance.for_range(0, 18, block_num=18) as block_idx:
            with tik_instance.for_range(0, 2) as cc0:
                with tik_instance.for_range(0, 128, thread_num=2) as cc1:
                    input1_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
                    input2_index = block_idx * 32768 + cc0 * 16384
                    res_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
                    _inner_matmul_new(tik_instance, dtype,
                                      input1, input1_index,
                                      input2, input2_index,
                                      res, res_index)
def process_input_shape_640(input_shape, tik_instance, dtype, input1, input2, res):
    """process input shape of 640"""
    if input_shape == ((5, 128, 128), (5, 128, 128), "float32", False, True):
        with tik_instance.for_range(0, 30, block_num=30) as block_idx:
            with tik_instance.for_range(0, 11) as cc1_db:
@@ -189,17 +140,9 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr
                                               thread_idx * 128 + thread_idx2 * 64],
                                           matmul_hybrid_f_t_local_UB, 0, 1, 8, 0, 0)
    if input_shape == ((18, 128, 128), (18, 128, 128), "float32", False, True):
        with tik_instance.for_range(0, 18, block_num=18) as block_idx:
            with tik_instance.for_range(0, 128, thread_num=2) as cc0:
                input1_index = block_idx * 16384 + cc0 * 128
                input2_index = block_idx * 16384
                res_index = block_idx * 16384 + cc0 * 128
                _inner_matmul_new(tik_instance, dtype,
                                  input1, input1_index,
                                  input2, input2_index,
                                  res, res_index)
def process_input_shape_1152(input_shape, tik_instance, dtype, input1, input2, res):
    """process input shape of 1152"""
    if input_shape == ((9, 128, 128), (9, 128, 128), "float32", False, True):
        with tik_instance.for_range(0, 27, block_num=27) as block_idx:
            with tik_instance.for_range(0, 42, thread_num=2) as cc0:
@@ -219,6 +162,76 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr
                                  input2, input2_index,
                                  res, res_index)
@op_info_register(cus_batchmatmul_op_info)
def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"):
    """CusBatchMatMul"""
    if util.get_product_version() == util.VERSION_MINI:
        tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
    else:
        tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
    x1_shape = input_x1.get("shape")
    dtype = input_x1.get("dtype").lower()
    x2_shape = input_x2.get("shape")
    if dtype != input_x2.get("dtype").lower():
        raise RuntimeError("dtype of input_x1 and input_x2 must be same, but got %s vs %s" % (
            dtype, input_x2.get("dtype").lower()))
    input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b)
    support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True),
                     ((36, 128, 128), (36, 128, 128), "float32", False, True),
                     ((5, 128, 128), (5, 128, 128), "float32", False, True),
                     ((18, 128, 128), (18, 128, 128), "float32", False, True),
                     ((16, 128, 128), (16, 128, 128), "float32", False, True),
                     ((9, 128, 128), (9, 128, 128), "float32", False, True),
                     ((1, 64, 64), (1, 64, 64), "float32", False, True),
                     ((1, 128, 128), (1, 128, 128), "float32", False, True),
                     ((4, 128, 128), (4, 128, 128), "float32", False, True),
                     ((2, 128, 128), (2, 128, 128), "float32", False, True),
                     ((6, 128, 128), (6, 128, 128), "float32", False, True),
                     ((24, 128, 128), (24, 128, 128), "float32", False, True),
                     ((32, 128, 128), (32, 128, 128), "float32", False, True)]
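    # The kernel is shape-specialized: each supported case below hard-codes its
    # own core count and tiling, so anything outside this whitelist is rejected.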
    if input_shape not in support_shape:
        raise RuntimeError("input_shape %s is not supported" % str(input_shape))
    # if not transpose_a and transpose_b:
    batch, m, k = x1_shape
    input1_shape = _get_flattern_shape(x1_shape)
    input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm)
    input2_shape = _get_flattern_shape(x2_shape)
    input2 = tik_instance.Tensor(dtype, input2_shape, name="input2", scope=tik.scope_gm)
    output_shape = x1_shape
    res_shape = _get_flattern_shape(output_shape)
    res = tik_instance.Tensor(dtype, res_shape, name="res", scope=tik.scope_gm)
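    # All three GM tensors are flattened to 1-D, so every branch below addresses
    # them with absolute element offsets (a 128 x 128 float32 matrix occupies
    # 128 * 128 = 16384 elements).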
    if input_shape == ((36, 128, 128), (36, 128, 128), "float32", False, True):
        with tik_instance.for_range(0, 18, block_num=18) as block_idx:
            with tik_instance.for_range(0, 2) as cc0:
                with tik_instance.for_range(0, 128, thread_num=2) as cc1:
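                    # 36 batches over 18 cores: each core computes 2 batches
                    # (cc0) row by row (cc1), so the per-core block stride is
                    # 2 * 16384 = 32768 elements.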
                    input1_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
                    input2_index = block_idx * 32768 + cc0 * 16384
                    res_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
                    _inner_matmul_new(tik_instance, dtype,
                                      input1, input1_index,
                                      input2, input2_index,
                                      res, res_index)
    process_input_shape_640(input_shape, tik_instance, dtype, input1, input2, res)
    if input_shape == ((18, 128, 128), (18, 128, 128), "float32", False, True):
        with tik_instance.for_range(0, 18, block_num=18) as block_idx:
            with tik_instance.for_range(0, 128, thread_num=2) as cc0:
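                # One batch per core; cc0 walks the 128 rows of that batch,
                # double-buffered via thread_num=2.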
                input1_index = block_idx * 16384 + cc0 * 128
                input2_index = block_idx * 16384
                res_index = block_idx * 16384 + cc0 * 128
                _inner_matmul_new(tik_instance, dtype,
                                  input1, input1_index,
                                  input2, input2_index,
                                  res, res_index)
    process_input_shape_1152(input_shape, tik_instance, dtype, input1, input2, res)
    if input_shape == ((1, 64, 64), (1, 64, 64), "float32", False, True):
        with tik_instance.for_range(0, 32, block_num=32) as block_idx:
            with tik_instance.for_range(0, 2, thread_num=2) as cc0:
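                # Single 64 x 64 batch split across 32 cores; 32 cores x 2
                # iterations = 64, so each cc0 step likely covers one of the
                # 64 output rows (the loop body is elided by this hunk).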
@@ -233,8 +246,10 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr
    input_shape_list = [((1, 128, 128), (1, 128, 128), "float32", False, True),
                        ((2, 128, 128), (2, 128, 128), "float32", False, True),
                        ((4, 128, 128), (4, 128, 128), "float32", False, True),
                        ((6, 128, 128), (6, 128, 128), "float32", False, True),
                        ((8, 128, 128), (8, 128, 128), "float32", False, True),
                        ((16, 128, 128), (16, 128, 128), "float32", False, True),
                        ((24, 128, 128), (24, 128, 128), "float32", False, True),
                        ((32, 128, 128), (32, 128, 128), "float32", False, True)
                        ]
    if input_shape in input_shape_list:
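        # Shared tiling path for the batch sizes listed above; its body is
        # elided by this hunk.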