#!/usr/bin/env python3 # coding: utf-8 # Copyright 2019 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """operator dsl function: batchmatmul""" from functools import reduce import akg.topi import akg.tvm from akg.tvm.hybrid import script from akg.ops.math import cast from akg.utils import custom_tiling as ct_util from akg.utils import validation_check as vc_util from akg.utils.format_transform import get_shape, get_bytes from akg.utils.math import greatest_common_divisor, least_common_multiple from akg.utils.kernel_exec import product_is_mini from akg.utils import dynamic_shape as ds batchmatmul_set_dim_map = { # 2D str((256, 1024, 4096, "float32", False, True)): ((16, 16), (16, 16), (16, 16)), str((160, 1024, 1024, "float32", False, False)): ((1, 1), (16, 16), (1024, 1024)), str((8192, 1024, 4096, "float32", False, False)): ((1, 1), (1024, 1024), (16, 16)), str((1024, 1024, 8192, "float32", True, False)): ((8, 8), (8, 8), (512, 512)), str((1024, 1024, 2, "float32", False, False)): ((64, 64), (64, 64), (2, 2)), str((1024, 1024, 4096, "float32", False, False)): ((1, 1), (16, 16), (512, 512)), str((2, 1024, 8192, "float32", True, False)): ((2, 2), (16, 16), (512, 512)), str((30522, 1024, 1280, "float32", True, False)): ((3, 3), (64, 64), (128, 128)), str((1024, 4096, 8192, "float32", True, False)): ((32, 32), (32, 32), (32, 32)), str((2, 1024, 64, "float32", True, False)): ((4, 4), (64, 64), (64, 64)), str((160, 30522, 1024, "float32", False, True)): 
((1, 1), (3, 3), (512, 512)), str((1024, 1024, 64, "float32", True, False)): ((16, 16), (16, 16), (64, 64)), str((4096, 1024, 8192, "float32", True, False)): ((16, 16), (64, 64), (16, 16)), str((1280, 1024, 30522, "float32", False, False)): ((1, 1), (512, 512), (3, 3)), str((8192, 1024, 4096, "float32", False, True)): ((1, 1), (16, 16), (1024, 1024)), str((1280, 30522, 1024, "float32", False, True)): ((1, 1), (3, 3), (512, 512)), str((8192, 4096, 1024, "float32", False, False)): ((1, 1), (64, 64), (256, 256)), str((768, 768, 8192, "float16", True, False)): ((16, 16), (16, 16), (64, 64)), str((3072, 768, 8192, "float16", True, False)): ((16, 16), (16, 16), (64, 64)), str((2, 768, 64, "float32", True, False)): ((2, 2), (64, 64), (64, 64)), str((768, 1024, 8192, "float16", False, False)): ((16, 1), (16, 1), (64, 1)), str((33, 64, 16384, "float32", True, False)): ((1, 1), (16, 16), (128, 128)), str((768, 768, 64, "float32", True, False)): ((16, 16), (16, 16), (16, 16)), str((8192, 768, 21128, "float32", False, False)): ((4, 4), (128, 128), (4, 4)), str((2, 768, 8192, "float32", True, False)): ((2, 2), (16, 16), (512, 512)), str((8192, 768, 768, "float16", False, True)): ((1, 1), (16, 16), (768, 768)), str((21128, 768, 8192, "float32", True, False)): ((4, 4), (128, 128), (64, 64)), str((768, 1024, 768, "float32", True, False)): ((2, 2), (128, 128), (128, 128)), str((16384, 16384, 33, "float32", True, False)): ((16, 16), (64, 64), (33, 33)), str((21128, 768, 8192, "float32", False, False)): ((4, 4), (128, 128), (64, 64)), str((1280, 1280, 1024, "float32", False, True)): ((4, 4), (32, 32), (128, 128)), str((1280, 768, 21128, "float32", False, False)): ((1, 1), (768, 768), (8, 8)), str((8192, 768, 768, "float32", False, False)): ((1, 1), (8, 8), (768, 768)), str((20, 768, 32000, "float32", False, False)): ((20, 20), (48, 48), (32, 32)), str((21128, 768, 1280, "float32", True, False)): ((2, 2), (32, 32), (32, 32)), str((768, 3072, 1892, "float32", True, False)): ((16, 16), 
(16, 16), (16, 16)), str((33, 64, 16384, "float32", False, True)): ((2, 2), (32, 32), (32, 32)), str((8192, 3072, 768, "float32", False, True)): ((16, 16), (16, 16), (16, 16)), str((2, 8192, 768, "float32", True, False)): ((16, 16), (16, 16), (16, 16)), str((8192, 768, 3072, "float32", False, False)): ((32, 32), (32, 32), (32, 32)), str((8192, 768, 3072, "float16", False, True)): ((1, 1), (16, 1), (768, 1)), str((8192, 3072, 768, "float16", False, True)): ((1, 1), (16, 1), (768, 1)), str((8192, 768, 768, "float16", False, False)): ((1, 1), (16, 1), (768, 1)), str((768, 3072, 768, "float16", False, False)): ((1, 1), (16, 1), (768, 1)), str((8192, 3072, 768, "float16", False, False)): ((1, 1), (16, 1), (768, 1)), str((8192, 768, 3072, "float16", False, False)): ((1, 1), (16, 1), (768, 1)), str((21128, 768, 5120, "float32", True, False)): ((4, 4), (128, 128), (64, 64)), str((21128, 768, 2560, "float32", True, False)): ((4, 4), (128, 128), (64, 64)), str((1024, 2, 4, "float32", True, False)): ((512, 1), (2, 1), (4, 1)), str((21128, 1024, 21128, "float32", False, False)): ((4, 4), (512, 512), (4, 4)), str((320, 768, 21128, "float32", False, False)): ((40, 40), (128, 128), (4, 4)), str((5120, 1024, 21128, "float32", False, False)): ((64, 64), (128, 128), (4, 4)), str((16384, 4096, 1024, "float32", False, False)): ((4, 4), (64, 64), (64, 64)), str((1024, 4096, 16384, "float32", True, False)): ((64, 64), (64, 64), (4, 4)), str((768, 3072, 8192, "float16", True, False)): ((1, 1), (16, 1), (512, 1)), str((2048, 768, 3072, "float32", False, False)): ((8, 1), (16, 1), (128, 1)), str((1024, 2, 128, "float32", True, False)): ((32, 1), (2, 1), (128, 1)), str((1024, 2, 16, "float32", True, False)): ((256, 1), (2, 1), (16, 1)), str((1024, 2, 32, "float32", True, False)): ((128, 1), (2, 1), (32, 1)), str((1024, 2, 64, "float32", True, False)): ((64, 1), (2, 1), (64, 1)), str((1024, 2, 8, "float32", True, False)): ((512, 1), (2, 1), (8, 1)), str((768, 2, 128, "float32", True, 
False)): ((32, 1), (2, 1), (128, 1)), str((768, 2, 16, "float32", True, False)): ((256, 1), (2, 1), (16, 1)), str((768, 2, 32, "float32", True, False)): ((128, 1), (2, 1), (32, 1)), str((768, 2, 64, "float32", True, False)): ((64, 1), (2, 1), (64, 1)), str((768, 3072, 2048, "float32", True, False)): ((16, 1), (16, 1), (64, 1)), str((3072, 768, 2048, "float32", True, False)): ((16, 1), (16, 1), (64, 1)), str((65536, 1024, 4096, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)), str((10240, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)), str((10240, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)), str((21128, 768, 21128, "float32", True, False)): ((1, 1), (768, 1), (16, 1)), str((2560, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)), str((5120, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)), str((32768, 4096, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)), str((20480, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)), str((128, 1024, 4096, "float32", False, True)): ((1, 1), (16, 1), (16, 1)), str((128, 1024, 4096, "float16", False, True)): ((1, 1), (16, 1), (32, 1)), str((128, 4096, 1024, "float16", False, True)): ((1, 1), (32, 1), (32, 1)), str((20480, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)), str((65536, 4096, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)), str((512, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)), str((256, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)), str((1024, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)), str((2048, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)), str((2, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)), str((65536, 768, 3072, "float32", False, True)): ((1, 1), (192, 1), (96, 1)), str((16384, 1024, 4096, "float32", False, True)): ((1, 1), (128, 1), (128, 1)), str((1024, 
4096, 32768, "float32", True, False)): ((8, 1), (1024, 1), (4, 1)), str((4, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)), str((2, 2, 768, "float32", False, True)): ((1, 1), (2, 1), (768, 1)), str((8, 2, 768, "float32", False, True)): ((1, 1), (2, 1), (768, 1)), str((4, 2, 768, "float32", False, True)): ((1, 1), (2, 1), (768, 1)), # lenet5 str((32, 10, 84, 'float16', False, True)): ((1, 1), (16, 1), (84, 1)), # alexnet str((32, 4096, 9216, 'float32', False, False)): ((1, 1), (16, 1), (1024, 1)), str((32, 10, 4096, 'float32', False, False)): ((1, 1), (16, 1), (1024, 1)), # 3D str((128, 128, 64, 1536, "float32", True, False)): ((4, 4), (8, 8), (16, 16), (32, 32)), str((128, 768, 128, 64, "float32", False, True)): ((4, 4), (8, 8), (16, 16), (64, 64)), str((128, 128, 64, 6144, "float32", True, False)): ((4, 4), (8, 8), (16, 16), (32, 32)), str((128, 128, 64, 16384, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)), str((128, 128, 64, 2048, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)), str((128, 128, 64, 4096, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)), str((128, 128, 64, 8192, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)), str((128, 128, 64, 4096, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (64, 1)), str((128, 128, 64, 12288, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (32, 1)), # 4D str((64, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (12, 12), (8, 8), (8, 8), (32, 32)), str((1, 768, 2, "float32", False, False)): ((768, 1), (1, 1)), str((20, 768, 21128, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((128, 12, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)), str((128, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)), str((1, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)), str((128, 128, 64, 12, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)), str((1, 12, 
128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)), str((20, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((1, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (32, 1)), str((20, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((128, 768, 3072, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((1, 768, 768, "float32", False, False)): ((768, 1), (16, 1)), str((21128, 768, 20, "float32", True, False)): ((8, 1), (768, 1), (8, 1)), str((128, 12, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (128, 1), (64, 1)), str((128, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((40, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (16, 1)), str((2, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)), str((2, 768, 2, "float32", False, False)): ((2, 1), (768, 1), (2, 1)), str((40, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((256, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((2, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)), str((2, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)), str((128, 128, 64, 24, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (12, 1)), str((256, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)), str((128, 24, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)), str((2, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((21128, 768, 40, "float32", True, False)): ((8, 1), (768, 1), (8, 1)), str((40, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (16, 1)), str((256, 768, 3072, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((128, 24, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)), str((128, 48, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 
1)), str((512, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)), str((128, 48, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)), str((512, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((4, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)), str((512, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((4, 768, 2, "float32", False, False)): ((1, 1), (768, 1), (2, 1)), str((80, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((4, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)), str((21128, 768, 80, "float32", True, False)): ((8, 1), (768, 1), (8, 1)), str((80, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((128, 128, 64, 48, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)), str((4, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((4, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)), str((80, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((128, 96, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)), str((1024, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)), str((160, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((8, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)), str((8, 768, 2, "float32", False, False)): ((8, 1), (768, 1), (2, 1)), str((8, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((1024, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((128, 96, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)), str((128, 128, 64, 96, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)), str((8, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)), str((1024, 768, 3072, 
"float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((21128, 768, 160, "float32", True, False)): ((8, 1), (768, 1), (8, 1)), str((160, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((8, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)), str((160, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((16, 768, 2, "float32", False, False)): ((2, 1), (768, 1), (2, 1)), str((320, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((16, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)), str((2048, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)), str((16, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)), str((128, 128, 64, 192, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)), str((128, 192, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)), str((21128, 768, 320, "float32", True, False)): ((8, 1), (768, 1), (8, 1)), str((128, 192, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)), str((320, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((16, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)), str((320, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((2048, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((2048, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((16, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((4096, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((128, 128, 64, 384, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)), str((4096, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)), str((640, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((128, 384, 128, 64, "float32", False, True)): ((1, 1), (4, 1), 
(64, 1), (64, 1)), str((128, 384, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)), str((32, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)), str((32, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)), str((32, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)), str((32, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((4096, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((640, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((32, 768, 2, "float32", False, False)): ((4, 1), (768, 1), (2, 1)), str((21128, 768, 640, "float32", True, False)): ((8, 1), (768, 1), (8, 1)), str((640, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((64, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)), str((8192, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((1280, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((8192, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)), str((128, 768, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)), str((128, 128, 64, 768, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)), str((21128, 768, 1280, "float32", True, False)): ((8, 1), (768, 1), (8, 1)), str((1280, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((64, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)), str((1280, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((64, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (16, 1)), str((64, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)), str((64, 768, 2, "float32", False, False)): ((8, 1), (768, 1), (2, 1)), str((8192, 768, 768, "float32", False, False)): ((2, 1), (768, 1), 
(8, 1)), str((128, 768, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)), str((128, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)), str((16384, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((16384, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((128, 1536, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)), str((128, 768, 2, "float32", False, False)): ((16, 1), (768, 1), (2, 1)), str((2560, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)), str((21128, 768, 2560, "float32", True, False)): ((8, 1), (768, 1), (8, 1)), str((128, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)), str((2560, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((128, 128, 64, 1536, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)), str((128, 1536, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)), str((2560, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((128, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)), str((16384, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)), str((1, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (4, 1), (64, 1), (8, 1)), str((1, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (4, 1)), str((128, 16, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)), str((128, 1024, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)), str((20, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)), str((128, 4096, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)), str((128, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)), str((128, 16, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)), str((20, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), 
(1024, 1)), str((1, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)), str((1, 1024, 2, "float32", False, False)): ((1024, 1), (2, 1)), str((21128, 1024, 20, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)), str((20, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)), str((128, 128, 64, 16, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (8, 1)), str((1, 1024, 1024, "float32", False, False)): ((1024, 1), (8, 1)), str((128, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((1, 1024, 1024, "float32", False, True)): ((8, 1), (1024, 1)), str((20, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)), str((128, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)), str((128, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((1, 2, 1024, "float32", False, True)): ((2, 1), (1024, 1)), str((2, 1024, 1, "float32", True, False)): ((2, 1), (1024, 1)), str((1024, 1024, 128, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)), str((1024, 1024, 1, "float32", True, False)): ((16, 1), (1024, 1)), str((1024, 1024, 20, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)), str((4096, 1024, 128, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)), str((1024, 4096, 128, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)), str((128, 128, 64, 32, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)), str((128, 32, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)), str((256, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)), str((256, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)), str((40, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)), str((2, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)), str((2, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)), str((21128, 1024, 40, "float32", True, False)): ((8, 
1), (1024, 1), (2, 1)), str((2, 1024, 2, "float32", False, False)): ((2, 1), (1024, 1), (2, 1)), str((256, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)), str((40, 1024, 21128, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)), str((2, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)), str((40, 1024, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)), str((128, 32, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)), str((2, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)), str((2, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)), str((256, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((256, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((2, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)), str((256, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)), str((40, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((1024, 1024, 256, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)), str((1024, 1024, 40, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)), str((4096, 1024, 256, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)), str((2, 1024, 2, "float32", True, False)): ((2, 1), (1024, 1), (2, 1)), str((1024, 1024, 2, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)), str((1024, 4096, 256, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)), str((4, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)), str((4, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)), str((80, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)), str((4, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)), str((128, 64, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)), str((128, 128, 64, 64, "float32", True, False)): ((1, 1), (2, 1), (64, 
1), (16, 1)), str((21128, 1024, 80, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)), str((4, 1024, 2, "float32", False, False)): ((4, 1), (1024, 1), (2, 1)), str((128, 64, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)), str((80, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)), str((512, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)), str((512, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)), str((512, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)), str((4, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)), str((80, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)), str((512, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((80, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)), str((4, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)), str((4, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)), str((512, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((512, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)), str((1024, 1024, 512, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)), str((2, 1024, 4, "float32", True, False)): ((2, 1), (1024, 1), (4, 1)), str((4096, 1024, 512, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)), str((1024, 1024, 4, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)), str((1024, 1024, 80, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)), str((1024, 4096, 512, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)), str((4096, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)), str((4096, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((4096, 768, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((8192, 768, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((8192, 3072, 768, "float32", False, True)): 
((1, 1), (16, 1), (768, 1)), str((8192, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)), str((16384, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((16384, 768, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((16384, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)), str((3072, 768, 128, "float32", True, False)): ((4, 1), (768, 1), (4, 1)), str((768, 768, 1, "float32", True, False)): ((16, 1), (768, 1)), str((768, 768, 20, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((768, 3072, 128, "float32", True, False)): ((4, 1), (3072, 1), (1, 1)), str((2, 768, 1, "float32", True, False)): ((2, 1), (768, 1)), str((768, 768, 128, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((768, 768, 2, "float32", True, False)): ((4, 1), (768, 1), (2, 1)), str((768, 768, 40, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((3072, 768, 256, "float32", True, False)): ((4, 1), (768, 1), (4, 1)), str((768, 3072, 256, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)), str((768, 768, 256, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((2, 768, 2, "float32", True, False)): ((2, 1), (768, 1), (2, 1)), str((768, 768, 80, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((3072, 768, 512, "float32", True, False)): ((2, 1), (768, 1), (8, 1)), str((768, 3072, 512, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)), str((768, 768, 4, "float32", True, False)): ((4, 1), (768, 1), (2, 1)), str((768, 768, 512, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((2, 768, 4, "float32", True, False)): ((2, 1), (768, 1), (4, 1)), str((768, 768, 1024, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((768, 3072, 1024, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)), str((3072, 768, 1024, "float32", True, False)): ((2, 1), (768, 1), (8, 1)), str((768, 768, 8, "float32", True, False)): ((4, 1), (768, 1), (2, 1)), str((768, 768, 160, "float32", True, False)): ((8, 1), 
(768, 1), (2, 1)), str((2, 768, 8, "float32", True, False)): ((2, 1), (768, 1), (8, 1)), str((3072, 768, 2048, "float32", True, False)): ((2, 1), (768, 1), (8, 1)), str((768, 3072, 2048, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)), str((768, 768, 2048, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((768, 768, 320, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((768, 768, 16, "float32", True, False)): ((4, 1), (768, 1), (2, 1)), str((2, 768, 16, "float32", True, False)): ((2, 1), (768, 1), (16, 1)), str((768, 768, 32, "float32", True, False)): ((4, 1), (768, 1), (2, 1)), str((2, 768, 32, "float32", True, False)): ((2, 1), (768, 1), (16, 1)), str((3072, 768, 4096, "float32", True, False)): ((2, 1), (768, 1), (8, 1)), str((768, 3072, 4096, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)), str((768, 768, 640, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((768, 768, 4096, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((768, 768, 64, "float32", True, False)): ((4, 1), (768, 1), (2, 1)), str((2, 768, 64, "float32", True, False)): ((2, 1), (768, 1), (16, 1)), str((768, 3072, 8192, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)), str((768, 768, 8192, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((768, 768, 1280, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((3072, 768, 8192, "float32", True, False)): ((2, 1), (768, 1), (8, 1)), str((768, 768, 2560, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((3072, 768, 16384, "float32", True, False)): ((2, 1), (768, 1), (8, 1)), str((2, 768, 128, "float32", True, False)): ((2, 1), (768, 1), (16, 1)), str((768, 768, 16384, "float32", True, False)): ((8, 1), (768, 1), (2, 1)), str((768, 3072, 16384, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)), str((8, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)), str((21128, 1024, 160, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)), str((1024, 
4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)), str((160, 1024, 21128, "float32", False, False)): ((2, 1), (1024, 1), (4, 1)), str((8, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)), str((160, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)), str((8, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)), str((1024, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)), str((160, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)), str((1024, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)), str((8, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)), str((128, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)), str((128, 128, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)), str((128, 128, 64, 128, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)), str((8, 1024, 2, "float32", False, False)): ((1, 1), (1024, 1), (1, 1)), str((21128, 1024, 320, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)), str((16, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)), str((2048, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)), str((16, 1024, 2, "float32", False, False)): ((1, 1), (1024, 1), (1, 1)), str((16, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)), str((16, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)), str((2048, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)), str((320, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)), str((2048, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)), str((320, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)), str((128, 256, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)), str((320, 1024, 21128, "float32", 
False, False)): ((2, 1), (1024, 1), (4, 1)), str((128, 256, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)), str((128, 128, 64, 256, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)), str((16, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)), str((1024, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((8, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)), str((160, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)), str((1024, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)), str((1024, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((8, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)), str((16, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)), str((320, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((16, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)), str((2048, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((2048, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((2048, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)), str((1024, 4096, 1024, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)), str((1024, 1024, 8, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)), str((2, 1024, 8, "float32", True, False)): ((2, 1), (1024, 1), (8, 1)), str((1024, 1024, 1024, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)), str((4096, 1024, 1024, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)), str((1024, 1024, 160, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)), str((4096, 1024, 2048, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)), str((1024, 1024, 320, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)), str((1024, 1024, 16, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)), str((1024, 4096, 2048, "float32", True, False)): ((4, 1), (2048, 1), (2, 
1)), str((2, 1024, 16, "float32", True, False)): ((2, 1), (1024, 1), (16, 1)), str((1024, 1024, 2048, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)), str((32, 1001, 2048, "float16", False, True)): ((1, 1), (77, 1), (256, 1)), str((1001, 2048, 32, "float16", True, False)): ((1, 1), (2048, 1), (4, 1)), str((32, 2048, 1001, "float16", False, False)): ((1, 1), (2048, 1), (4, 1)), str((32, 1001, 2048, "float32", False, True)): ((1, 1), (7, 1), (2048, 1)), str((1001, 2048, 32, "float32", True, False)): ((1, 1), (2048, 1), (4, 1)), str((32, 2048, 1001, "float32", False, False)): ((1, 1), (2048, 1), (4, 1)), str((768, 3072, 131072, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)), str((3072, 768, 131072, "float32", True, False)): ((2, 1), (768, 1), (8, 1)), str((65536, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)), str((131072, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((32768, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)), str((65536, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)), str((131072, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((131072, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)), str((65536, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((32768, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((65536, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)), str((10240, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)), str((21128, 1024, 20480, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)), str((2048, 3072, 768, "float16", False, False)): ((1, 1), (768, 1), (16, 1)), str((2048, 768, 3072, "float16", False, False)): ((2, 1), (768, 1), (8, 1)), str((10240, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((20480, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)), str((21128, 768, 20480, "float32", True, 
False)): ((8, 1), (768, 1), (8, 1)), str((20480, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)), str((32, 10, 2048, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)), str((32, 10, 2048, "float16", False, True)): ((1, 1), (8, 1), (2048, 1)), str((32, 10, 4096, "float16", False, True)): ((1, 1), (2, 1), (4096, 1)), str((768, 768, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)), str((3072, 768, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)), str((768, 3072, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)), str((32, 9216, 4096, "float32", False, False)): ((1, 1), (9216, 1), (1, 1)), str((128, 128, 64, 3072, "float32", True, False)): ((1, 1), (64, 1), (64, 1), (1, 1)), str((21128, 1024, 10240, "float32", True, False)): ((8, 1), (1024, 1), (1, 1)), str((128, 128, 64, 1536, "float16", True, False)): ((1, 1), (1, 1), (16, 1), (1536, 1)), str((3072, 768, 2048, "float16", True, False)): ((1, 1), (16, 1), (1024, 1)), str((768, 3072, 2048, "float16", True, False)): ((1, 1), (16, 1), (1024, 1)), str((4096, 1024, 65536, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)), # auto tiling crash str((1024, 4096, 65536, "float32", True, False)): ((1, 1), (4096, 1), (4, 1)), # auto tiling crash str((16384, 3072, 768, "float16", False, True)): ((1, 1), (16, 1), (768, 1)), # auto tiling crash str((16384, 768, 3072, "float16", False, True)): ((1, 1), (16, 1), (1024, 1)), # auto tiling crash str((1024, 4096, 1024, "float16", False, True)): ((1, 1), (16, 1), (1024, 1)), # auto tiling crash str((131072, 1024, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)), # Alexnet shape str((32, 10, 4096, "float32", False, True)): ((32, 1), (10, 1), (1, 1)), # auto tiling crash str((32, 4096, 4096, "float16", False, False)): ((1, 1), (16, 1), (512, 1)), # auto tiling crash str((32, 9216, 4096, "float16", False, False)): ((1, 1), (16, 1), (512, 1)), # auto tiling crash str((32, 4096, 9216, "float16", False, True)): ((1, 1), (16, 1), 
(1024, 1)), # auto tiling crash str((128, 128, 64, 1536, "float16", True, False)): ((1, 1), (1, 1), (16, 1), (96, 1)), # auto tiling crash str((32, 4096, 4096, "float32", False, False)): ((1, 1), (256, 1), (64, 1)), # performance optimization # Alexnet shape str((768, 3072, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)), str((3072, 768, 4096, "float16", True, False)): ((2, 1), (768, 1), (8, 1)), str((768, 3072, 4096, "float16", True, False)): ((2, 1), (3072, 1), (2, 1)), str((400, 120, 32, "float16", True, False)): ((2, 1), (32, 1), (2, 1)), } CORE_NUM = 2 if product_is_mini() else 32 MINIMAL_FOR_MULTICORE = CORE_NUM * 512 def get_best_align_elem(tensor_size, tensor_dtype): """Get the best tiling factor for alignment axis.""" basic_align_elem = int(ct_util.BLOCK_SIZE / get_bytes(tensor_dtype)) lcm = least_common_multiple(tensor_size, basic_align_elem) gcd = greatest_common_divisor(tensor_size, basic_align_elem) if gcd != 1: return gcd if lcm < tensor_size: return min(tensor_size, lcm) return -1 def get_shape_pos_map(tensor_shape): """Mapping tensor shape to corresponding axis position.""" batch_pos = [i for i in range(len(tensor_shape) - 3) if tensor_shape[i] != 1] mnk = dict() pos_map = {0: "m", 1: "n", 2: "k"} count = -1 for i, _ in enumerate(batch_pos): count += 1 mnk["b%s" % str(i)] = count for i, shp in enumerate(tensor_shape[-3:]): if shp != 1: count += 1 mnk[pos_map[i]] = count return batch_pos, mnk def batchmatmul_tiling_strategy(shape, align_dtype, attrs): """This is an efficient version of tiling strategy for batchmatmul.""" if len(shape) < 3: raise RuntimeError("Shape must be in the form of [(Batch_out, Batch_in,) M, N, K]. 
" "Current length of shape is {}".format(len(shape))) strategy = list() m, n, k = shape[-3:] batch_pos, mnk = get_shape_pos_map(shape) total_size = reduce(lambda x, y: int(x) * int(y), shape) # set minimal tile for n as the basic block size for alignment align_elem = int(ct_util.BLOCK_SIZE / get_bytes(align_dtype)) tile_n = get_best_align_elem(n, align_dtype) # if n is smaller than block size, it is not safe to open multi-core for now if n < align_elem: n_constraint = [ct_util.TileConstraint.FACTOR] used_core = CORE_NUM attrs["enable_multicore"] = 0 else: n_constraint = [ct_util.TileConstraint.MOD, ct_util.TileConstraint.MIN] used_core = 1 if total_size >= MINIMAL_FOR_MULTICORE and used_core < CORE_NUM: # set maximal tile for batch according to multi-core usage for i, p in enumerate(batch_pos): use = min(shape[p], int((CORE_NUM - 1 + used_core) / used_core)) used_core *= use max_b = int(shape[p] / use) strategy.append(ct_util.create_constraint_on_axis(values=max_b, constraints=ct_util.TileConstraint.MAX, band=0, axis=i)[0]) tile_k = get_best_align_elem(k, align_dtype) k_constraint = ct_util.TileConstraint.MIN total_size /= max(1, int(k / tile_k)) # set minimal tile for m according to multi-core usage when there is no expansion tile_m = -1 m_constraint = ct_util.TileConstraint.MAX m_per_block = m max_core = min(CORE_NUM, int(total_size / MINIMAL_FOR_MULTICORE)) if greatest_common_divisor(n, align_elem) != 1 and used_core < max_core: left_core = int((max_core - 1 + used_core) / used_core) core_limit = max(1, int(m / greatest_common_divisor(left_core, m))) nk_in_mem = int(n / max(1, tile_n)) * int(k / tile_k) balance_limit = max(1, int(m / greatest_common_divisor(nk_in_mem, m))) tile_m = min(core_limit, balance_limit) m_per_block = int(m / tile_m) # for large m case, it is more efficient to balance memory bound and calculation bound if m_per_block > int(n / max(1, tile_n)) * int(k / tile_k): tile_m = max(min(m, align_elem), tile_m) k_constraint = 
ct_util.TileConstraint.FACTOR # create constraints based on previous analysis if m != 1: strategy.append(ct_util.create_constraint_on_axis(values=tile_m, constraints=m_constraint, band=0, axis=mnk["m"])[0]) if n != 1: for constraint in n_constraint: strategy.append(ct_util.create_constraint_on_axis(values=tile_n, constraints=constraint, band=0, axis=mnk["n"])[0]) if k != 1: strategy.append(ct_util.create_constraint_on_axis(values=tile_k, constraints=k_constraint, band=0, axis=mnk["k"])[0]) higher_priority_pos = mnk["k"] if k >= n else mnk["n"] strategy.append(ct_util.create_constraint_on_axis(values=0, constraints=ct_util.TileConstraint.SET_PRIORITY, band=0, axis=higher_priority_pos)[0]) strategy.append(ct_util.modify_common_constraints(0.7, ct_util.TileConstraint.SET_MEM_RATIO)) attrs["custom_tiling"] = strategy return attrs def batchmatmul_tiling_strategy_dynamic(shape, output, attrs): """This is an efficient version of tiling strategy for batchmatmul.""" if len(shape) < 3: raise RuntimeError("Shape must be in the form of [(Batch_out, Batch_in,) M, N, K]. 
" "Current length of shape is {}".format(len(shape))) strategy = list() _, mnk = get_shape_pos_map(shape) # create constraints based on previous analysis strategy.append(ct_util.create_constraint_on_axis(values=1, constraints=ct_util.TileConstraint.FACTOR, band=0, axis=mnk["m"])[0]) strategy.append(ct_util.create_constraint_on_axis(values="FULL", constraints=ct_util.TileConstraint.MAX, band=0, axis=mnk["n"])[0]) strategy.append(ct_util.create_constraint_on_axis(values=8, constraints=ct_util.TileConstraint.FACTOR, band=0, axis=mnk["k"])[0]) strategy.append(ct_util.modify_common_constraints(0.7, ct_util.TileConstraint.SET_MEM_RATIO)) attrs["custom_tiling"] = strategy attrs["dynamic_shape"] = ds.set_dynamic_shape_limit_for_tensor(output, 2048, [1,]) return attrs def get_mnk_from_matrix(shape_a_list, shape_b_list, trans_a, trans_b): """Get m, n and k value from input tensor shapes.""" m, k = shape_a_list[-2], shape_a_list[-1] if trans_a: m, k = k, m n = shape_b_list[-2] if trans_b else shape_b_list[-1] return [m, n, k] def batchmatmul_set_dim(a_value, b_value, trans_a, trans_b): """This function is used to set dim info in attrs by set_dim_map.""" shape_a_list = get_shape(a_value) shape_b_list = get_shape(b_value) m, n, k = get_mnk_from_matrix(shape_a_list, shape_b_list, trans_a, trans_b) key = () if len(shape_a_list) > 2: key += tuple(shape_a_list[:-2]) key += (m, n, k, a_value.dtype, trans_a, trans_b) set_dims = ct_util.set_dims_by_key(str(key), batchmatmul_set_dim_map) return set_dims, str(key) def batchmatmul_bias_set_dim(a_value, b_value, bias_value, trans_a, trans_b): """This function is used to set dim info in attrs by set_dim_map of batchmatmul with bias.""" return batchmatmul_set_dim(a_value, b_value, trans_a, trans_b) def batchmatmul_no_bias_set_dim(a_value, b_value, trans_a, trans_b): """This function is used to set dim info in attrs by set_dim_map of batchmatmul without bias.""" return batchmatmul_set_dim(a_value, b_value, trans_a, trans_b) 
@ct_util.reg_set_dim_func(batchmatmul_bias_set_dim)
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, bool, bool)
def batchmatmul_bias(a_value, b_value, bias_value, trans_a, trans_b):
    """
    Multiplies two tensors in batches and adds bias to the output.

    Args:
        a_value (tvm.tensor.Tensor): 2D or 3D or 4D tensor of type float16 or float32 with shape(..., r_A, c_A).
        b_value (tvm.tensor.Tensor): 2D or 3D or 4D tensor of same type as a_value with shape(..., r_B, c_B).
        bias_value (tvm.tensor.Tensor): The bias tensor added to the result of a_value * b_value.
                                        Should be of same type as a_value, broadcast is allowed.
        trans_a (bool): Specifies whether a_value is transposed or not.
        trans_b (bool): Specifies whether b_value is transposed or not.

    Returns:
        tuple: ``(output, attrs)``.
            ``output`` is a tvm.tensor.Tensor of the same type as a_value with
            shape(..., r_C, c_C), where r_C = c_A if trans_a else r_A and
            c_C = r_B if trans_b else c_B.
            ``attrs`` is the dict of build attributes.

    Raises:
        TypeError: If trans_a or trans_b is not a bool.
        ValueError: If a_value is not 2D, 3D or 4D.
    """
    if not isinstance(trans_a, bool):
        raise TypeError("trans_a should be of type Boolean.")
    if not isinstance(trans_b, bool):
        raise TypeError("trans_b should be of type Boolean.")
    vc_util.ops_dtype_check([a_value.dtype, b_value.dtype, bias_value.dtype], vc_util.DtypeForDavinci.ALL_FLOAT)
    vc_util.elemwise_dtype_check(a_value.dtype, b_value.dtype)
    vc_util.elemwise_dtype_check(a_value.dtype, bias_value.dtype)
    vc_util.gemm_format_check(get_shape(a_value), get_shape(b_value), trans_a, trans_b)
    if len(a_value.shape) not in [2, 3, 4]:
        raise ValueError("Batch matmul only support 2D, 3D and 4D now.")

    c_value = batchmatmul(a_value, b_value, trans_a, trans_b)
    # batchmatmul returns (tensor, attrs); only the tensor is needed here.
    if isinstance(c_value, (tuple, list)):
        c_value = c_value[0]
    vc_util.auto_broadcast_check(get_shape(bias_value), get_shape(c_value))
    shape_c_list = get_shape(c_value)
    bias_value = akg.topi.broadcast_to(bias_value, shape_c_list)

    dim_info = batchmatmul_bias_set_dim(a_value, b_value, bias_value, trans_a, trans_b)
    if isinstance(dim_info, (tuple, list)):
        dim_info = dim_info[0]
    attrs = {}
    attrs["enable_compute_in_place"] = True
    if dim_info != "":
        attrs["dim"] = dim_info
    batch = get_shape(a_value)[:-2]
    mnk = get_mnk_from_matrix(get_shape(a_value), get_shape(b_value), trans_a, trans_b)
    attrs = batchmatmul_tiling_strategy(batch + mnk, c_value.dtype, attrs)
    return akg.tvm.compute(bias_value.shape,
                           lambda *indice: c_value(*indice) + bias_value(*indice),
                           name='matmul_bias_output'), attrs


@ct_util.reg_set_dim_func(batchmatmul_no_bias_set_dim)
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, bool, bool)
def batchmatmul(a_value, b_value, trans_a=False, trans_b=False):
    """
    Multiplies two tensors in batches.

    Args:
        a_value (tvm.tensor.Tensor): 2D or 3D or 4D tensor of type float16 or float32 with shape(..., r_A, c_A).
        b_value (tvm.tensor.Tensor): 2D or 3D or 4D tensor of same type as a_value with shape(..., r_B, c_B).
        trans_a (bool): Specifies whether a_value is transposed or not, default value is False.
        trans_b (bool): Specifies whether b_value is transposed or not, default value is False.

    Returns:
        tuple: ``(output, attrs)``.
            ``output`` is a tvm.tensor.Tensor of the same type as a_value with
            shape(..., r_C, c_C), where r_C = c_A if trans_a else r_A and
            c_C = r_B if trans_b else c_B.
            ``attrs`` is the dict of build attributes.

    Raises:
        TypeError: If trans_a or trans_b is not a bool.
        ValueError: If a_value is not 2D, 3D or 4D. Also raised for float32 3D/4D
            inputs when both trans_a and trans_b are True.
    """
    if not isinstance(trans_a, bool):
        raise TypeError("trans_a should be of type Boolean.")
    if not isinstance(trans_b, bool):
        raise TypeError("trans_b should be of type Boolean.")
    vc_util.ops_dtype_check([a_value.dtype, b_value.dtype], vc_util.DtypeForDavinci.ALL_FLOAT)
    vc_util.elemwise_dtype_check(a_value.dtype, b_value.dtype)
    vc_util.gemm_format_check(get_shape(a_value), get_shape(b_value), trans_a, trans_b)
    if len(a_value.shape) not in [2, 3, 4]:
        raise ValueError("Batch matmul only support 2D, 3D and 4D now.")

    dtype = a_value.dtype
    if dtype == 'float16':
        # float16 inputs are multiplied and reduced in float32, then cast back.
        if len(a_value.shape) == 2:
            c_value = vectormatmul_2d_cast(a_value, b_value, trans_a, trans_b, "float32")
        elif len(a_value.shape) == 3:
            c_value = vectormatmul_3d_cast(a_value, b_value, trans_a, trans_b, "float32")
        else:
            c_value = vectormatmul_4d_cast(a_value, b_value, trans_a, trans_b, "float32")
    else:
        if len(a_value.shape) == 2:
            c_value = vectormatmul_2d(a_value, b_value, trans_a, trans_b)
        elif len(a_value.shape) == 3:
            c_value = vectormatmul_3d(a_value, b_value, trans_a, trans_b)
        else:
            c_value = vectormatmul_4d(a_value, b_value, trans_a, trans_b)

    dim_info = batchmatmul_no_bias_set_dim(a_value, b_value, trans_a, trans_b)
    if isinstance(dim_info, (tuple, list)):
        dim_info = dim_info[0]
    attrs = {}
    attrs["enable_compute_in_place"] = True
    if dim_info != "":
        attrs["dim"] = dim_info
    mnk = get_mnk_from_matrix(get_shape(a_value), get_shape(b_value), trans_a, trans_b)
    batch = get_shape(a_value)[:-2]
    is_dynamic = ds.shape_is_dynamic([a_value, b_value])
    if not is_dynamic:
        attrs = batchmatmul_tiling_strategy(batch + mnk, c_value.dtype, attrs)
    else:
        attrs = batchmatmul_tiling_strategy_dynamic(batch + mnk, c_value, attrs)
        attrs["enable_pre_storage_write_simplify"] = True
        attrs["enable_sink_allocate"] = True
        attrs["enable_double_buffer"] = False
    return c_value, attrs


def vectormatmul_3d(a_value, b_value, trans_a, trans_b):
    """Hybrid-script implementation of 3D batchmatmul, computed in the input dtype.

    Args:
        a_value (tvm.tensor.Tensor): Shape (bs, m, k), or (bs, k, m) if trans_a.
        b_value (tvm.tensor.Tensor): Shape (bs, k, n), or (bs, n, k) if trans_b.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.

    Returns:
        tvm.tensor.Tensor: The (bs, m, n) product.

    Raises:
        ValueError: If both trans_a and trans_b are True.
    """
    if trans_a:
        bs, k, m = a_value.shape
    else:
        bs, m, k = a_value.shape
    if trans_b:
        n = b_value.shape[-2]
    else:
        n = b_value.shape[-1]
    dtype = a_value.dtype
    zero = akg.tvm.const(0.0, dtype=dtype)

    @script(capture=locals())
    def matmul_hybrid_f_f(a, b, zero):
        t_1 = allocate((bs, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs, m, n), a.dtype, 'local')
        for i_bs in range(0, bs):
            for i_m in range(0, m):
                for i_k in range(0, k):
                    for i_n in range(0, n):
                        t_1[i_bs, i_m, i_k, i_n] = a[i_bs, i_m, i_k] * b[i_bs, i_k, i_n]
                for i1_n in range(0, n):
                    t_2[i_bs, i_m, i1_n] = zero
                for i1_k in range(0, k):
                    for i1_n in range(0, n):
                        t_2[i_bs, i_m, i1_n] = t_2[i_bs, i_m, i1_n] + t_1[i_bs, i_m, i1_k, i1_n]
        return t_2

    @script(capture=locals())
    def matmul_hybrid_f_t(a, b, zero):
        t_1 = allocate((bs, m, n, k), a.dtype, 'local')
        t_2 = allocate((bs, m, n), a.dtype, 'local')
        for i_bs in range(0, bs):
            for i_m in range(0, m):
                for i_n in range(0, n):
                    t_2[i_bs, i_m, i_n] = zero
                    for i_k in range(0, k):
                        t_1[i_bs, i_m, i_n, i_k] = a[i_bs, i_m, i_k] * b[i_bs, i_n, i_k]
                        t_2[i_bs, i_m, i_n] = t_1[i_bs, i_m, i_n, i_k] + t_2[i_bs, i_m, i_n]
        return t_2

    @script(capture=locals())
    def matmul_hybrid_t_f(a, b, zero):
        t_1 = allocate((bs, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs, m, n), a.dtype, 'local')
        for i_bs in range(0, bs):
            for i_m in range(0, m):
                for i_k in range(0, k):
                    for i_n in range(0, n):
                        t_1[i_bs, i_m, i_k, i_n] = a[i_bs, i_k, i_m] * b[i_bs, i_k, i_n]
                for i1_n in range(0, n):
                    t_2[i_bs, i_m, i1_n] = zero
                for i1_k in range(0, k):
                    for i1_n in range(0, n):
                        t_2[i_bs, i_m, i1_n] = t_2[i_bs, i_m, i1_n] + t_1[i_bs, i_m, i1_k, i1_n]
        return t_2

    if not trans_a and not trans_b:
        c_value = matmul_hybrid_f_f(a_value, b_value, zero)
    elif not trans_a and trans_b:
        c_value = matmul_hybrid_f_t(a_value, b_value, zero)
    elif trans_a and not trans_b:
        c_value = matmul_hybrid_t_f(a_value, b_value, zero)
    else:
        raise ValueError('Not support both transpose yet')
    return c_value


def vectormatmul_4d(a_value, b_value, trans_a, trans_b):
    """Hybrid-script implementation of 4D batchmatmul, computed in the input dtype.

    Args:
        a_value (tvm.tensor.Tensor): Shape (bs1, bs2, m, k), or (bs1, bs2, k, m) if trans_a.
        b_value (tvm.tensor.Tensor): Shape (bs1, bs2, k, n), or (bs1, bs2, n, k) if trans_b.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.

    Returns:
        tvm.tensor.Tensor: The (bs1, bs2, m, n) product.

    Raises:
        ValueError: If both trans_a and trans_b are True.
    """
    if trans_a:
        bs1, bs2, k, m = a_value.shape
    else:
        bs1, bs2, m, k = a_value.shape
    if trans_b:
        n = b_value.shape[-2]
    else:
        n = b_value.shape[-1]
    dtype = a_value.dtype
    zero = akg.tvm.const(0.0, dtype=dtype)

    @script(capture=locals())
    def matmul_hybrid_f_f(a, b, zero):
        t_1 = allocate((bs1, bs2, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs1, bs2, m, n), a.dtype, 'local')
        for i_bs1 in range(0, bs1):
            for i_bs2 in range(0, bs2):
                for i_m in range(0, m):
                    for i_k in range(0, k):
                        for i_n in range(0, n):
                            t_1[i_bs1, i_bs2, i_m, i_k, i_n] = a[i_bs1, i_bs2, i_m, i_k] * b[i_bs1, i_bs2, i_k, i_n]
                    for i1_n in range(0, n):
                        t_2[i_bs1, i_bs2, i_m, i1_n] = zero
                    for i1_k in range(0, k):
                        for i1_n in range(0, n):
                            t_2[i_bs1, i_bs2, i_m, i1_n] = t_2[i_bs1, i_bs2, i_m, i1_n] + \
                                t_1[i_bs1, i_bs2, i_m, i1_k, i1_n]
        return t_2

    @script(capture=locals())
    def matmul_hybrid_f_t(a, b, zero):
        t_1 = allocate((bs1, bs2, m, n, k), a.dtype, 'local')
        t_2 = allocate((bs1, bs2, m, n), a.dtype, 'local')
        for i_bs1 in range(0, bs1):
            for i_bs2 in range(0, bs2):
                for i_m in range(0, m):
                    for i_n in range(0, n):
                        t_2[i_bs1, i_bs2, i_m, i_n] = zero
                        for i_k in range(0, k):
                            t_1[i_bs1, i_bs2, i_m, i_n, i_k] = a[i_bs1, i_bs2, i_m, i_k] * b[i_bs1, i_bs2, i_n, i_k]
                            t_2[i_bs1, i_bs2, i_m, i_n] = t_1[i_bs1, i_bs2, i_m, i_n, i_k] + t_2[i_bs1, i_bs2, i_m, i_n]
        return t_2

    @script(capture=locals())
    def matmul_hybrid_t_f(a, b, zero):
        t_1 = allocate((bs1, bs2, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs1, bs2, m, n), a.dtype, 'local')
        for i_bs1 in range(0, bs1):
            for i_bs2 in range(0, bs2):
                for i_m in range(0, m):
                    for i_k in range(0, k):
                        for i_n in range(0, n):
                            t_1[i_bs1, i_bs2, i_m, i_k, i_n] = a[i_bs1, i_bs2, i_k, i_m] * b[i_bs1, i_bs2, i_k, i_n]
                    for i1_n in range(0, n):
                        t_2[i_bs1, i_bs2, i_m, i1_n] = zero
                    for i1_k in range(0, k):
                        for i1_n in range(0, n):
                            t_2[i_bs1, i_bs2, i_m, i1_n] = t_2[i_bs1, i_bs2, i_m, i1_n] + \
                                t_1[i_bs1, i_bs2, i_m, i1_k, i1_n]
        return t_2

    if not trans_a and not trans_b:
        c_value = matmul_hybrid_f_f(a_value, b_value, zero)
    elif not trans_a and trans_b:
        c_value = matmul_hybrid_f_t(a_value, b_value, zero)
    elif trans_a and not trans_b:
        c_value = matmul_hybrid_t_f(a_value, b_value, zero)
    else:
        raise ValueError('Not support both transpose yet')
    return c_value


def vectormatmul_4d_cast(a_value, b_value, trans_a, trans_b, cast_dtype):
    """DSL implementation with data type cast for 4D batchmatmul.

    Each element is cast to ``cast_dtype`` before multiplication. The reduced
    result is cast back to the input dtype.

    Args:
        a_value (tvm.tensor.Tensor): Shape (b1, b2, m, k), or (b1, b2, k, m) if trans_a.
        b_value (tvm.tensor.Tensor): Shape (b1, b2, k, n), or (b1, b2, n, k) if trans_b.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.
        cast_dtype (str): Data type used for the multiply and reduction.

    Returns:
        tvm.tensor.Tensor: The (b1, b2, m, n) product in a_value's dtype.
    """
    if trans_a:
        b1, b2, k, m = a_value.shape
    else:
        b1, b2, m, k = a_value.shape
    if trans_b:
        n = b_value.shape[-2]
    else:
        n = b_value.shape[-1]
    dtype = a_value.dtype

    def matmul_4d_dsl(a_value, b_value, trans_a, trans_b):
        # ele_mul is laid out as (b1, b2, m, n, k) in every transpose mode.
        if not trans_a and not trans_b:
            ele_mul = akg.tvm.compute((b1, b2, m, n, k),
                                      lambda i_b1, i_b2, i_m, i_n, i_k:
                                      a_value[i_b1, i_b2, i_m, i_k].astype(cast_dtype) *
                                      b_value[i_b1, i_b2, i_k, i_n].astype(cast_dtype),
                                      name="ele_mul")
        elif not trans_a and trans_b:
            ele_mul = akg.tvm.compute((b1, b2, m, n, k),
                                      lambda i_b1, i_b2, i_m, i_n, i_k:
                                      a_value[i_b1, i_b2, i_m, i_k].astype(cast_dtype) *
                                      b_value[i_b1, i_b2, i_n, i_k].astype(cast_dtype),
                                      name="ele_mul")
        elif trans_a and not trans_b:
            ele_mul = akg.tvm.compute((b1, b2, m, n, k),
                                      lambda i_b1, i_b2, i_m, i_n, i_k:
                                      a_value[i_b1, i_b2, i_k, i_m].astype(cast_dtype) *
                                      b_value[i_b1, i_b2, i_k, i_n].astype(cast_dtype),
                                      name="ele_mul")
        else:
            ele_mul = akg.tvm.compute((b1, b2, m, n, k),
                                      lambda i_b1, i_b2, i_m, i_n, i_k:
                                      b_value[i_b1, i_b2, i_n, i_k].astype(cast_dtype) *
                                      a_value[i_b1, i_b2, i_k, i_m].astype(cast_dtype),
                                      name="ele_mul")
        reduce_axis = akg.tvm.reduce_axis((0, k), name='reduce_axis')
        output_shape = (b1, b2, m, n)
        m_c = akg.tvm.compute(output_shape,
                              lambda *i: akg.tvm.sum(ele_mul[i + (reduce_axis,)], axis=reduce_axis),
                              name="matmul_compute")
        return m_c

    c_cast = matmul_4d_dsl(a_value, b_value, trans_a, trans_b)
    # The reduction already yields (b1, b2, m, n) for every transpose mode.
    # The former 2-axis transpose applied when both inputs were transposed did
    # not match this 4-D result, so it was removed.
    return cast.cast(c_cast, dtype)


def vectormatmul_3d_cast(a_value, b_value, trans_a, trans_b, cast_dtype):
    """DSL implementation with data type cast for 3D batchmatmul.

    Each element is cast to ``cast_dtype`` before multiplication. The reduced
    result is cast back to the input dtype.

    Args:
        a_value (tvm.tensor.Tensor): Shape (b, m, k), or (b, k, m) if trans_a.
        b_value (tvm.tensor.Tensor): Shape (b, k, n), or (b, n, k) if trans_b.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.
        cast_dtype (str): Data type used for the multiply and reduction.

    Returns:
        tvm.tensor.Tensor: The (b, m, n) product in a_value's dtype.
    """
    if trans_a:
        b, k, m = a_value.shape
    else:
        b, m, k = a_value.shape
    if trans_b:
        n = b_value.shape[-2]
    else:
        n = b_value.shape[-1]
    dtype = a_value.dtype

    def matmul_3d_dsl(a_value, b_value, trans_a, trans_b):
        # ele_mul is laid out as (b, m, n, k) in every transpose mode.
        if not trans_a and not trans_b:
            ele_mul = akg.tvm.compute((b, m, n, k),
                                      lambda i_b, i_m, i_n, i_k:
                                      a_value[i_b, i_m, i_k].astype(cast_dtype) *
                                      b_value[i_b, i_k, i_n].astype(cast_dtype),
                                      name="ele_mul")
        elif not trans_a and trans_b:
            ele_mul = akg.tvm.compute((b, m, n, k),
                                      lambda i_b, i_m, i_n, i_k:
                                      a_value[i_b, i_m, i_k].astype(cast_dtype) *
                                      b_value[i_b, i_n, i_k].astype(cast_dtype),
                                      name="ele_mul")
        elif trans_a and not trans_b:
            ele_mul = akg.tvm.compute((b, m, n, k),
                                      lambda i_b, i_m, i_n, i_k:
                                      a_value[i_b, i_k, i_m].astype(cast_dtype) *
                                      b_value[i_b, i_k, i_n].astype(cast_dtype),
                                      name="ele_mul")
        else:
            ele_mul = akg.tvm.compute((b, m, n, k),
                                      lambda i_b, i_m, i_n, i_k:
                                      b_value[i_b, i_n, i_k].astype(cast_dtype) *
                                      a_value[i_b, i_k, i_m].astype(cast_dtype),
                                      name="ele_mul")
        reduce_axis = akg.tvm.reduce_axis((0, k), name='reduce_axis')
        output_shape = (b, m, n)
        m_c = akg.tvm.compute(output_shape,
                              lambda *i: akg.tvm.sum(ele_mul[i + (reduce_axis,)], axis=reduce_axis),
                              name="matmul_compute")
        return m_c

    c_cast = matmul_3d_dsl(a_value, b_value, trans_a, trans_b)
    # The reduction already yields (b, m, n) for every transpose mode.
    # The former 2-axis transpose applied when both inputs were transposed did
    # not match this 3-D result, so it was removed.
    return cast.cast(c_cast, dtype)


def vectormatmul_2d_cast(a_value, b_value, trans_a, trans_b, cast_dtype):
    """DSL implementation with data type cast for 2D matmul.

    Each element is cast to ``cast_dtype`` before multiplication. The reduced
    result is cast back to the input dtype.

    Args:
        a_value (tvm.tensor.Tensor): Shape (m, k), or (k, m) if trans_a.
        b_value (tvm.tensor.Tensor): Shape (k, n), or (n, k) if trans_b.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.
        cast_dtype (str): Data type used for the multiply and reduction.

    Returns:
        tvm.tensor.Tensor: The (m, n) product in a_value's dtype.
    """
    if trans_a:
        k, m = a_value.shape
    else:
        m, k = a_value.shape
    if trans_b:
        n = b_value.shape[-2]
    else:
        n = b_value.shape[-1]
    dtype = a_value.dtype

    # Casting whole float16 tensors to float32 up front makes the AutoPoly pass
    # very slow, so each element is cast individually inside the compute.
    def matmul_2d(a_value, b_value, trans_a, trans_b):
        # ele_mul is laid out as (m, n, k) in every transpose mode.
        if not trans_a and not trans_b:
            ele_mul = akg.tvm.compute((m, n, k),
                                      lambda i_m, i_n, i_k:
                                      a_value[i_m, i_k].astype(cast_dtype) *
                                      b_value[i_k, i_n].astype(cast_dtype),
                                      name="ele_mul")
        elif not trans_a and trans_b:
            ele_mul = akg.tvm.compute((m, n, k),
                                      lambda i_m, i_n, i_k:
                                      a_value[i_m, i_k].astype(cast_dtype) *
                                      b_value[i_n, i_k].astype(cast_dtype),
                                      name="ele_mul")
        elif trans_a and not trans_b:
            ele_mul = akg.tvm.compute((m, n, k),
                                      lambda i_m, i_n, i_k:
                                      a_value[i_k, i_m].astype(cast_dtype) *
                                      b_value[i_k, i_n].astype(cast_dtype),
                                      name="ele_mul")
        else:
            ele_mul = akg.tvm.compute((m, n, k),
                                      lambda i_m, i_n, i_k:
                                      b_value[i_n, i_k].astype(cast_dtype) *
                                      a_value[i_k, i_m].astype(cast_dtype),
                                      name="ele_mul")
        reduce_axis = akg.tvm.reduce_axis((0, k), name='reduce_axis')
        output_shape = (m, n)
        m_c = akg.tvm.compute(output_shape,
                              lambda *i: akg.tvm.sum(ele_mul[i + (reduce_axis,)], axis=reduce_axis),
                              name="matmul_compute")
        return m_c

    c_cast = matmul_2d(a_value, b_value, trans_a, trans_b)
    # The reduction already yields the (m, n) result for every transpose mode.
    # The former transpose when both inputs were transposed returned an (n, m) tensor.
    return cast.cast(c_cast, dtype)


def vectormatmul_2d(a_value, b_value, trans_a, trans_b):
    """Hybrid-script implementation of 2D matmul, computed in the input dtype.

    Args:
        a_value (tvm.tensor.Tensor): Shape (m, k), or (k, m) if trans_a.
        b_value (tvm.tensor.Tensor): Shape (k, n), or (n, k) if trans_b.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.

    Returns:
        tvm.tensor.Tensor: The (m, n) product.
    """
    if trans_a:
        k, m = a_value.shape
    else:
        m, k = a_value.shape
    if trans_b:
        n = b_value.shape[-2]
    else:
        n = b_value.shape[-1]
    dtype = a_value.dtype
    zero = akg.tvm.const(0.0, dtype=dtype)

    @script(capture=locals())
    def matmul_hybrid_f_f(a, b, zero, mv, nv, kv):
        t_1 = allocate((mv, kv, nv), a.dtype, 'local')
        t_2 = output_tensor((mv, nv), a.dtype)
        for i_m in range(0, mv):
            for i_k in range(0, kv):
                for i_n in range(0, nv):
                    t_1[i_m, i_k, i_n] = a[i_m, i_k] * b[i_k, i_n]
            for i1_n in range(0, nv):
                t_2[i_m, i1_n] = zero
            for i1_k in range(0, kv):
                for i1_n in range(0, nv):
                    t_2[i_m, i1_n] = t_2[i_m, i1_n] + t_1[i_m, i1_k, i1_n]
        return t_2

    @script(capture=locals())
    def matmul_hybrid_f_t(a, b, zero, mv, nv, kv):
        t_1 = allocate((mv, nv, kv), a.dtype, 'local')
        t_2 = allocate((mv, nv), a.dtype, 'local')
        for i_m in range(0, mv):
            for i_n in range(0, nv):
                t_2[i_m, i_n] = zero
                for i_k in range(0, kv):
                    t_1[i_m, i_n, i_k] = a[i_m, i_k] * b[i_n, i_k]
                    t_2[i_m, i_n] = t_1[i_m, i_n, i_k] + t_2[i_m, i_n]
        return t_2

    @script(capture=locals())
    def matmul_hybrid_t_f(a, b, zero, mv, nv, kv):
        t_1 = allocate((mv, kv, nv), a.dtype, 'local')
        t_2 = allocate((mv, nv), a.dtype, 'local')
        for i_m in range(0, mv):
            for i_k in range(0, kv):
                for i_n in range(0, nv):
                    t_1[i_m, i_k, i_n] = a[i_k, i_m] * b[i_k, i_n]
            for i1_n in range(0, nv):
                t_2[i_m, i1_n] = zero
            for i1_k in range(0, kv):
                for i1_n in range(0, nv):
                    t_2[i_m, i1_n] = t_2[i_m, i1_n] + t_1[i_m, i1_k, i1_n]
        return t_2

    # Computes the product in (n, m) layout; the caller transposes it to (m, n).
    @script(capture=locals())
    def matmul_hybrid_t_t(a, b, zero, mv, nv, kv):
        t_1 = allocate((nv, kv, mv), a.dtype, 'local')
        t_2 = allocate((nv, mv), a.dtype, 'local')
        for i_n in range(0, nv):
            for i_m in range(0, mv):
                for i_k in range(0, kv):
                    t_1[i_n, i_k, i_m] = b[i_n, i_k] * a[i_k, i_m]
            for i1_m in range(0, mv):
                t_2[i_n, i1_m] = zero
            for i1_k in range(0, kv):
                for i2_m in range(0, mv):
                    t_2[i_n, i2_m] = t_2[i_n, i2_m] + t_1[i_n, i1_k, i2_m]
        return t_2

    if not trans_a and not trans_b:
        c_value = matmul_hybrid_f_f(a_value, b_value, zero, m, n, k)
    elif not trans_a and trans_b:
        c_value = matmul_hybrid_f_t(a_value, b_value, zero, m, n, k)
    elif trans_a and not trans_b:
        c_value = matmul_hybrid_t_f(a_value, b_value, zero, m, n, k)
    else:
        c1 = matmul_hybrid_t_t(a_value, b_value, zero, m, n, k)
        c_value = akg.topi.transpose(c1, (1, 0))
    return c_value