|
- #!/usr/bin/env python3
- # coding: utf-8
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """operator dsl function: batchmatmul"""
- from functools import reduce
- import akg.topi
- import akg.tvm
- from akg.tvm.hybrid import script
- from akg.ops.math import cast
- from akg.utils import custom_tiling as ct_util
- from akg.utils import validation_check as vc_util
- from akg.utils.format_transform import get_shape, get_bytes
- from akg.utils.math import greatest_common_divisor, least_common_multiple
- from akg.utils.kernel_exec import product_is_mini
- from akg.utils import dynamic_shape as ds
-
- batchmatmul_set_dim_map = {
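- # 2D
- # NOTE(review): several keys in this literal occur more than once, e.g.
- # (256, 1024, 4096, "float32", False, True) and
- # (1024, 1024, 4096, "float32", False, False). In a Python dict display the
- # last occurrence wins, so the earlier entries for those keys are never used.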
- str((256, 1024, 4096, "float32", False, True)): ((16, 16), (16, 16), (16, 16)),
- str((160, 1024, 1024, "float32", False, False)): ((1, 1), (16, 16), (1024, 1024)),
- str((8192, 1024, 4096, "float32", False, False)): ((1, 1), (1024, 1024), (16, 16)),
- str((1024, 1024, 8192, "float32", True, False)): ((8, 8), (8, 8), (512, 512)),
- str((1024, 1024, 2, "float32", False, False)): ((64, 64), (64, 64), (2, 2)),
- str((1024, 1024, 4096, "float32", False, False)): ((1, 1), (16, 16), (512, 512)),
- str((2, 1024, 8192, "float32", True, False)): ((2, 2), (16, 16), (512, 512)),
- str((30522, 1024, 1280, "float32", True, False)): ((3, 3), (64, 64), (128, 128)),
- str((1024, 4096, 8192, "float32", True, False)): ((32, 32), (32, 32), (32, 32)),
- str((2, 1024, 64, "float32", True, False)): ((4, 4), (64, 64), (64, 64)),
- str((160, 30522, 1024, "float32", False, True)): ((1, 1), (3, 3), (512, 512)),
- str((1024, 1024, 64, "float32", True, False)): ((16, 16), (16, 16), (64, 64)),
- str((4096, 1024, 8192, "float32", True, False)): ((16, 16), (64, 64), (16, 16)),
- str((1280, 1024, 30522, "float32", False, False)): ((1, 1), (512, 512), (3, 3)),
- str((8192, 1024, 4096, "float32", False, True)): ((1, 1), (16, 16), (1024, 1024)),
- str((1280, 30522, 1024, "float32", False, True)): ((1, 1), (3, 3), (512, 512)),
- str((8192, 4096, 1024, "float32", False, False)): ((1, 1), (64, 64), (256, 256)),
- str((768, 768, 8192, "float16", True, False)): ((16, 16), (16, 16), (64, 64)),
- str((3072, 768, 8192, "float16", True, False)): ((16, 16), (16, 16), (64, 64)),
- str((2, 768, 64, "float32", True, False)): ((2, 2), (64, 64), (64, 64)),
- str((768, 1024, 8192, "float16", False, False)): ((16, 1), (16, 1), (64, 1)),
- str((33, 64, 16384, "float32", True, False)): ((1, 1), (16, 16), (128, 128)),
- str((768, 768, 64, "float32", True, False)): ((16, 16), (16, 16), (16, 16)),
- str((8192, 768, 21128, "float32", False, False)): ((4, 4), (128, 128), (4, 4)),
- str((2, 768, 8192, "float32", True, False)): ((2, 2), (16, 16), (512, 512)),
- str((8192, 768, 768, "float16", False, True)): ((1, 1), (16, 16), (768, 768)),
- str((21128, 768, 8192, "float32", True, False)): ((4, 4), (128, 128), (64, 64)),
- str((768, 1024, 768, "float32", True, False)): ((2, 2), (128, 128), (128, 128)),
- str((16384, 16384, 33, "float32", True, False)): ((16, 16), (64, 64), (33, 33)),
- str((21128, 768, 8192, "float32", False, False)): ((4, 4), (128, 128), (64, 64)),
- str((1280, 1280, 1024, "float32", False, True)): ((4, 4), (32, 32), (128, 128)),
- str((1280, 768, 21128, "float32", False, False)): ((1, 1), (768, 768), (8, 8)),
- str((8192, 768, 768, "float32", False, False)): ((1, 1), (8, 8), (768, 768)),
- str((20, 768, 32000, "float32", False, False)): ((20, 20), (48, 48), (32, 32)),
- str((21128, 768, 1280, "float32", True, False)): ((2, 2), (32, 32), (32, 32)),
- str((768, 3072, 1892, "float32", True, False)): ((16, 16), (16, 16), (16, 16)),
- str((33, 64, 16384, "float32", False, True)): ((2, 2), (32, 32), (32, 32)),
- str((8192, 3072, 768, "float32", False, True)): ((16, 16), (16, 16), (16, 16)),
- str((2, 8192, 768, "float32", True, False)): ((16, 16), (16, 16), (16, 16)),
- str((8192, 768, 3072, "float32", False, False)): ((32, 32), (32, 32), (32, 32)),
- str((8192, 768, 3072, "float16", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((8192, 3072, 768, "float16", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((8192, 768, 768, "float16", False, False)): ((1, 1), (16, 1), (768, 1)),
- str((768, 3072, 768, "float16", False, False)): ((1, 1), (16, 1), (768, 1)),
- str((8192, 3072, 768, "float16", False, False)): ((1, 1), (16, 1), (768, 1)),
- str((8192, 768, 3072, "float16", False, False)): ((1, 1), (16, 1), (768, 1)),
- str((21128, 768, 5120, "float32", True, False)): ((4, 4), (128, 128), (64, 64)),
- str((21128, 768, 2560, "float32", True, False)): ((4, 4), (128, 128), (64, 64)),
- str((1024, 2, 4, "float32", True, False)): ((512, 1), (2, 1), (4, 1)),
- str((21128, 1024, 21128, "float32", False, False)): ((4, 4), (512, 512), (4, 4)),
- str((320, 768, 21128, "float32", False, False)): ((40, 40), (128, 128), (4, 4)),
- str((5120, 1024, 21128, "float32", False, False)): ((64, 64), (128, 128), (4, 4)),
- str((16384, 4096, 1024, "float32", False, False)): ((4, 4), (64, 64), (64, 64)),
- str((1024, 4096, 16384, "float32", True, False)): ((64, 64), (64, 64), (4, 4)),
- str((768, 3072, 8192, "float16", True, False)): ((1, 1), (16, 1), (512, 1)),
- str((2048, 768, 3072, "float32", False, False)): ((8, 1), (16, 1), (128, 1)),
- str((1024, 2, 128, "float32", True, False)): ((32, 1), (2, 1), (128, 1)),
- str((1024, 2, 16, "float32", True, False)): ((256, 1), (2, 1), (16, 1)),
- str((1024, 2, 32, "float32", True, False)): ((128, 1), (2, 1), (32, 1)),
- str((1024, 2, 64, "float32", True, False)): ((64, 1), (2, 1), (64, 1)),
- str((1024, 2, 8, "float32", True, False)): ((512, 1), (2, 1), (8, 1)),
- str((768, 2, 128, "float32", True, False)): ((32, 1), (2, 1), (128, 1)),
- str((768, 2, 16, "float32", True, False)): ((256, 1), (2, 1), (16, 1)),
- str((768, 2, 32, "float32", True, False)): ((128, 1), (2, 1), (32, 1)),
- str((768, 2, 64, "float32", True, False)): ((64, 1), (2, 1), (64, 1)),
- str((768, 3072, 2048, "float32", True, False)): ((16, 1), (16, 1), (64, 1)),
- str((3072, 768, 2048, "float32", True, False)): ((16, 1), (16, 1), (64, 1)),
- str((65536, 1024, 4096, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
- str((10240, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
- str((10240, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
- str((21128, 768, 21128, "float32", True, False)): ((1, 1), (768, 1), (16, 1)),
- str((2560, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
- str((5120, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
- str((32768, 4096, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
- str((20480, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
- str((128, 1024, 4096, "float32", False, True)): ((1, 1), (16, 1), (16, 1)),
- str((128, 1024, 4096, "float16", False, True)): ((1, 1), (16, 1), (32, 1)),
- str((128, 4096, 1024, "float16", False, True)): ((1, 1), (32, 1), (32, 1)),
- str((20480, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
- str((65536, 4096, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
- str((512, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
- str((256, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
- str((1024, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
- str((2048, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
- str((2, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
- str((65536, 768, 3072, "float32", False, True)): ((1, 1), (192, 1), (96, 1)),
- str((16384, 1024, 4096, "float32", False, True)): ((1, 1), (128, 1), (128, 1)),
- str((1024, 4096, 32768, "float32", True, False)): ((8, 1), (1024, 1), (4, 1)),
- str((4, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
- str((2, 2, 768, "float32", False, True)): ((1, 1), (2, 1), (768, 1)),
- str((8, 2, 768, "float32", False, True)): ((1, 1), (2, 1), (768, 1)),
- str((4, 2, 768, "float32", False, True)): ((1, 1), (2, 1), (768, 1)),
- # lenet5
- str((32, 10, 84, 'float16', False, True)): ((1, 1), (16, 1), (84, 1)),
- # alexnet
- str((32, 4096, 9216, 'float32', False, False)): ((1, 1), (16, 1), (1024, 1)),
- str((32, 10, 4096, 'float32', False, False)): ((1, 1), (16, 1), (1024, 1)),
-
- # 3D
- str((128, 128, 64, 1536, "float32", True, False)): ((4, 4), (8, 8), (16, 16), (32, 32)),
- str((128, 768, 128, 64, "float32", False, True)): ((4, 4), (8, 8), (16, 16), (64, 64)),
- str((128, 128, 64, 6144, "float32", True, False)): ((4, 4), (8, 8), (16, 16), (32, 32)),
- str((128, 128, 64, 16384, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)),
- str((128, 128, 64, 2048, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)),
- str((128, 128, 64, 4096, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)),
- str((128, 128, 64, 8192, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)),
- str((128, 128, 64, 4096, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (64, 1)),
- str((128, 128, 64, 12288, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (32, 1)),
-
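- # 4D
- # NOTE(review): only the entry directly below sits under this heading; the
- # entries after the following blank line mix 2D, 3D and 4D keys.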
- str((64, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (12, 12), (8, 8), (8, 8), (32, 32)),
-
- str((1, 768, 2, "float32", False, False)): ((768, 1), (1, 1)),
- str((20, 768, 21128, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((128, 12, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
- str((128, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
- str((1, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
- str((128, 128, 64, 12, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
- str((1, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
- str((20, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((1, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (32, 1)),
- str((20, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((128, 768, 3072, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((1, 768, 768, "float32", False, False)): ((768, 1), (16, 1)),
- str((21128, 768, 20, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
- str((128, 12, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (128, 1), (64, 1)),
- str((128, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((40, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (16, 1)),
- str((2, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
- str((2, 768, 2, "float32", False, False)): ((2, 1), (768, 1), (2, 1)),
- str((40, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((256, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((2, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
- str((2, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
- str((128, 128, 64, 24, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (12, 1)),
- str((256, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
- str((128, 24, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
- str((2, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((21128, 768, 40, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
- str((40, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (16, 1)),
- str((256, 768, 3072, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((128, 24, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
- str((128, 48, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
- str((512, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
- str((128, 48, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
- str((512, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((4, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
- str((512, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((4, 768, 2, "float32", False, False)): ((1, 1), (768, 1), (2, 1)),
- str((80, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((4, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
- str((21128, 768, 80, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
- str((80, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((128, 128, 64, 48, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
- str((4, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((4, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
- str((80, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((128, 96, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
- str((1024, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
- str((160, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((8, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
- str((8, 768, 2, "float32", False, False)): ((8, 1), (768, 1), (2, 1)),
- str((8, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((1024, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((128, 96, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
- str((128, 128, 64, 96, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
- str((8, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
- str((1024, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((21128, 768, 160, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
- str((160, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((8, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
- str((160, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((16, 768, 2, "float32", False, False)): ((2, 1), (768, 1), (2, 1)),
- str((320, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((16, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
- str((2048, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
- str((16, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
- str((128, 128, 64, 192, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
- str((128, 192, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
- str((21128, 768, 320, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
- str((128, 192, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
- str((320, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((16, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
- str((320, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((2048, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((2048, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((16, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((4096, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((128, 128, 64, 384, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
- str((4096, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
- str((640, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((128, 384, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
- str((128, 384, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
- str((32, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
- str((32, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
- str((32, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
- str((32, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((4096, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((640, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((32, 768, 2, "float32", False, False)): ((4, 1), (768, 1), (2, 1)),
- str((21128, 768, 640, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
- str((640, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((64, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
- str((8192, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((1280, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((8192, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
- str((128, 768, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
- str((128, 128, 64, 768, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
- str((21128, 768, 1280, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
- str((1280, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((64, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
- str((1280, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((64, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (16, 1)),
- str((64, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
- str((64, 768, 2, "float32", False, False)): ((8, 1), (768, 1), (2, 1)),
- str((8192, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((128, 768, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
- str((128, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
- str((16384, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((16384, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((128, 1536, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
- str((128, 768, 2, "float32", False, False)): ((16, 1), (768, 1), (2, 1)),
- str((2560, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((21128, 768, 2560, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
- str((128, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
- str((2560, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((128, 128, 64, 1536, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
- str((128, 1536, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
- str((2560, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((128, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
- str((16384, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
- str((1, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (4, 1), (64, 1), (8, 1)),
- str((1, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (4, 1)),
- str((128, 16, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
- str((128, 1024, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
- str((20, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
- str((128, 4096, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
- str((128, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
- str((128, 16, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
- str((20, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
- str((1, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
- str((1, 1024, 2, "float32", False, False)): ((1024, 1), (2, 1)),
- str((21128, 1024, 20, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((20, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
- str((128, 128, 64, 16, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (8, 1)),
- str((1, 1024, 1024, "float32", False, False)): ((1024, 1), (8, 1)),
- str((128, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((1, 1024, 1024, "float32", False, True)): ((8, 1), (1024, 1)),
- str((20, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
- str((128, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
- str((128, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((1, 2, 1024, "float32", False, True)): ((2, 1), (1024, 1)),
- str((2, 1024, 1, "float32", True, False)): ((2, 1), (1024, 1)),
- str((1024, 1024, 128, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((1024, 1024, 1, "float32", True, False)): ((16, 1), (1024, 1)),
- str((1024, 1024, 20, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((4096, 1024, 128, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((1024, 4096, 128, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
- str((128, 128, 64, 32, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)),
- str((128, 32, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
- str((256, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)),
- str((256, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)),
- str((40, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
- str((2, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)),
- str((2, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)),
- str((21128, 1024, 40, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((2, 1024, 2, "float32", False, False)): ((2, 1), (1024, 1), (2, 1)),
- str((256, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
- str((40, 1024, 21128, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
- str((2, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
- str((40, 1024, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
- str((128, 32, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
- str((2, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)),
- str((2, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
- str((256, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((256, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((2, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
- str((256, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
- str((40, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((1024, 1024, 256, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((1024, 1024, 40, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((4096, 1024, 256, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((2, 1024, 2, "float32", True, False)): ((2, 1), (1024, 1), (2, 1)),
- str((1024, 1024, 2, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)),
- str((1024, 4096, 256, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
- str((4, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)),
- str((4, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)),
- str((80, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
- str((4, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
- str((128, 64, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
- str((128, 128, 64, 64, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)),
- str((21128, 1024, 80, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((4, 1024, 2, "float32", False, False)): ((4, 1), (1024, 1), (2, 1)),
- str((128, 64, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
- str((80, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
- str((512, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)),
- str((512, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
- str((512, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)),
- str((4, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)),
- str((80, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
- str((512, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((80, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
- str((4, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
- str((4, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
- str((512, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((512, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
- str((1024, 1024, 512, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
- str((2, 1024, 4, "float32", True, False)): ((2, 1), (1024, 1), (4, 1)),
- str((4096, 1024, 512, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((1024, 1024, 4, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)),
- str((1024, 1024, 80, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((1024, 4096, 512, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
- str((4096, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
- str((4096, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((4096, 768, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((8192, 768, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((8192, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((8192, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
- str((16384, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((16384, 768, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((16384, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
- str((3072, 768, 128, "float32", True, False)): ((4, 1), (768, 1), (4, 1)),
- str((768, 768, 1, "float32", True, False)): ((16, 1), (768, 1)),
- str((768, 768, 20, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((768, 3072, 128, "float32", True, False)): ((4, 1), (3072, 1), (1, 1)),
- str((2, 768, 1, "float32", True, False)): ((2, 1), (768, 1)),
- str((768, 768, 128, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((768, 768, 2, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
- str((768, 768, 40, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((3072, 768, 256, "float32", True, False)): ((4, 1), (768, 1), (4, 1)),
- str((768, 3072, 256, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
- str((768, 768, 256, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((2, 768, 2, "float32", True, False)): ((2, 1), (768, 1), (2, 1)),
- str((768, 768, 80, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((3072, 768, 512, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
- str((768, 3072, 512, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
- str((768, 768, 4, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
- str((768, 768, 512, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((2, 768, 4, "float32", True, False)): ((2, 1), (768, 1), (4, 1)),
- str((768, 768, 1024, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((768, 3072, 1024, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
- str((3072, 768, 1024, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
- str((768, 768, 8, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
- str((768, 768, 160, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((2, 768, 8, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
- str((3072, 768, 2048, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
- str((768, 3072, 2048, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
- str((768, 768, 2048, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((768, 768, 320, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((768, 768, 16, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
- str((2, 768, 16, "float32", True, False)): ((2, 1), (768, 1), (16, 1)),
- str((768, 768, 32, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
- str((2, 768, 32, "float32", True, False)): ((2, 1), (768, 1), (16, 1)),
- str((3072, 768, 4096, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
- str((768, 3072, 4096, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
- str((768, 768, 640, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((768, 768, 4096, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((768, 768, 64, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
- str((2, 768, 64, "float32", True, False)): ((2, 1), (768, 1), (16, 1)),
- str((768, 3072, 8192, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
- str((768, 768, 8192, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((768, 768, 1280, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((3072, 768, 8192, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
- str((768, 768, 2560, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((3072, 768, 16384, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
- str((2, 768, 128, "float32", True, False)): ((2, 1), (768, 1), (16, 1)),
- str((768, 768, 16384, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
- str((768, 3072, 16384, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
- str((8, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)),
- str((21128, 1024, 160, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)),
- str((1024, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)),
- str((160, 1024, 21128, "float32", False, False)): ((2, 1), (1024, 1), (4, 1)),
- str((8, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)),
- str((160, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
- str((8, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
- str((1024, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)),
- str((160, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
- str((1024, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
- str((8, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)),
- str((128, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
- str((128, 128, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
- str((128, 128, 64, 128, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)),
- str((8, 1024, 2, "float32", False, False)): ((1, 1), (1024, 1), (1, 1)),
- str((21128, 1024, 320, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)),
- str((16, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
- str((2048, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)),
- str((16, 1024, 2, "float32", False, False)): ((1, 1), (1024, 1), (1, 1)),
- str((16, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)),
- str((16, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)),
- str((2048, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)),
- str((320, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
- str((2048, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
- str((320, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
- str((128, 256, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
- str((320, 1024, 21128, "float32", False, False)): ((2, 1), (1024, 1), (4, 1)),
- str((128, 256, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
- str((128, 128, 64, 256, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)),
- str((16, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)),
- str((1024, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((8, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
- str((160, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
- str((1024, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
- str((1024, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((8, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
- str((16, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
- str((320, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((16, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
- str((2048, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((2048, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((2048, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
- str((1024, 4096, 1024, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
- str((1024, 1024, 8, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)),
- str((2, 1024, 8, "float32", True, False)): ((2, 1), (1024, 1), (8, 1)),
- str((1024, 1024, 1024, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
- str((4096, 1024, 1024, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((1024, 1024, 160, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((4096, 1024, 2048, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((1024, 1024, 320, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
- str((1024, 1024, 16, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)),
- str((1024, 4096, 2048, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
- str((2, 1024, 16, "float32", True, False)): ((2, 1), (1024, 1), (16, 1)),
- str((1024, 1024, 2048, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
- str((32, 1001, 2048, "float16", False, True)): ((1, 1), (77, 1), (256, 1)),
- str((1001, 2048, 32, "float16", True, False)): ((1, 1), (2048, 1), (4, 1)),
- str((32, 2048, 1001, "float16", False, False)): ((1, 1), (2048, 1), (4, 1)),
- str((32, 1001, 2048, "float32", False, True)): ((1, 1), (7, 1), (2048, 1)),
- str((1001, 2048, 32, "float32", True, False)): ((1, 1), (2048, 1), (4, 1)),
- str((32, 2048, 1001, "float32", False, False)): ((1, 1), (2048, 1), (4, 1)),
- str((768, 3072, 131072, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
- str((3072, 768, 131072, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
- str((65536, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
- str((131072, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((32768, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
- str((65536, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
- str((131072, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((131072, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
- str((65536, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((32768, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((65536, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
- str((10240, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
- str((21128, 1024, 20480, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)),
- str((2048, 3072, 768, "float16", False, False)): ((1, 1), (768, 1), (16, 1)),
- str((2048, 768, 3072, "float16", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((10240, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((20480, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
- str((21128, 768, 20480, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
- str((20480, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
- str((32, 10, 2048, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
- str((32, 10, 2048, "float16", False, True)): ((1, 1), (8, 1), (2048, 1)),
- str((32, 10, 4096, "float16", False, True)): ((1, 1), (2, 1), (4096, 1)),
- str((768, 768, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)),
- str((3072, 768, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)),
- str((768, 3072, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)),
- str((32, 9216, 4096, "float32", False, False)): ((1, 1), (9216, 1), (1, 1)),
- str((128, 128, 64, 3072, "float32", True, False)): ((1, 1), (64, 1), (64, 1), (1, 1)),
- str((21128, 1024, 10240, "float32", True, False)): ((8, 1), (1024, 1), (1, 1)),
- str((128, 128, 64, 1536, "float16", True, False)): ((1, 1), (1, 1), (16, 1), (1536, 1)),
- str((3072, 768, 2048, "float16", True, False)): ((1, 1), (16, 1), (1024, 1)),
- str((768, 3072, 2048, "float16", True, False)): ((1, 1), (16, 1), (1024, 1)),
- str((4096, 1024, 65536, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)), # auto tiling crash
- str((1024, 4096, 65536, "float32", True, False)): ((1, 1), (4096, 1), (4, 1)), # auto tiling crash
- str((16384, 3072, 768, "float16", False, True)): ((1, 1), (16, 1), (768, 1)), # auto tiling crash
- str((16384, 768, 3072, "float16", False, True)): ((1, 1), (16, 1), (1024, 1)), # auto tiling crash
- str((1024, 4096, 1024, "float16", False, True)): ((1, 1), (16, 1), (1024, 1)), # auto tiling crash
- str((131072, 1024, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
-
- # Alexnet shape
- str((32, 10, 4096, "float32", False, True)): ((32, 1), (10, 1), (1, 1)), # auto tiling crash
- str((32, 4096, 4096, "float16", False, False)): ((1, 1), (16, 1), (512, 1)), # auto tiling crash
- str((32, 9216, 4096, "float16", False, False)): ((1, 1), (16, 1), (512, 1)), # auto tiling crash
- str((32, 4096, 9216, "float16", False, True)): ((1, 1), (16, 1), (1024, 1)), # auto tiling crash
- str((128, 128, 64, 1536, "float16", True, False)): ((1, 1), (1, 1), (16, 1), (96, 1)), # auto tiling crash
- str((32, 4096, 4096, "float32", False, False)): ((1, 1), (256, 1), (64, 1)), # performance optimization
- # Alexnet shape
-
- str((768, 3072, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)),
-
- str((3072, 768, 4096, "float16", True, False)): ((2, 1), (768, 1), (8, 1)),
- str((768, 3072, 4096, "float16", True, False)): ((2, 1), (3072, 1), (2, 1)),
-
- str((400, 120, 32, "float16", True, False)): ((2, 1), (32, 1), (2, 1)),
- }
-
# Core count assumed by the tiling heuristics: 2 when product_is_mini() reports true, 32 otherwise.
CORE_NUM = 32
if product_is_mini():
    CORE_NUM = 2
# Size threshold (512 per core) that the tiling strategy compares the total element count against
# before it spreads work over several cores.
MINIMAL_FOR_MULTICORE = 512 * CORE_NUM
-
-
def get_best_align_elem(tensor_size, tensor_dtype):
    """
    Pick a tiling factor for an axis that should stay block-aligned.

    Args:
        tensor_size (int): Extent of the axis to tile.
        tensor_dtype (str): Data type; its byte width (get_bytes) determines how many
            elements fit into ct_util.BLOCK_SIZE.

    Returns:
        int, the greatest common divisor of tensor_size and the per-block element count when that
        divisor is not 1; otherwise min(tensor_size, lcm) if the least common multiple is smaller
        than tensor_size; otherwise -1.
    """
    elems_per_block = int(ct_util.BLOCK_SIZE / get_bytes(tensor_dtype))
    common_multiple = least_common_multiple(tensor_size, elems_per_block)
    common_divisor = greatest_common_divisor(tensor_size, elems_per_block)
    if common_divisor == 1:
        # No shared factor with the block size: fall back to the multiple, or give up with -1.
        if common_multiple < tensor_size:
            return min(tensor_size, common_multiple)
        return -1
    return common_divisor
-
-
def get_shape_pos_map(tensor_shape):
    """
    Number the non-unit axes of a [(batch...,) M, N, K] shape in order of appearance.

    Args:
        tensor_shape (list or tuple): Extents whose last three entries are read as m, n and k;
            every entry before them is treated as a batch axis.

    Returns:
        tuple (batch_pos, mnk):
            batch_pos (list): Indices into tensor_shape of the batch axes whose extent is not 1.
            mnk (dict): Maps "b0", "b1", ... (one key per entry of batch_pos) and then "m", "n", "k"
                to consecutive positions starting at 0. An m/n/k axis with extent 1 gets no key.
    """
    batch_pos = [idx for idx in range(len(tensor_shape) - 3) if tensor_shape[idx] != 1]
    # Batch axes take the leading positions 0 .. len(batch_pos) - 1.
    mnk = {"b%s" % str(order): order for order in range(len(batch_pos))}
    next_pos = len(batch_pos)
    for axis_name, extent in zip(("m", "n", "k"), tensor_shape[-3:]):
        if extent != 1:
            mnk[axis_name] = next_pos
            next_pos += 1
    return batch_pos, mnk
-
-
def batchmatmul_tiling_strategy(shape, align_dtype, attrs):
    """
    Build the custom tiling constraints for a batchmatmul with static shape.

    Args:
        shape (list): Extents in the form [(Batch_out, Batch_in,) M, N, K].
        align_dtype (str): Data type used to work out how many elements make up one
            ct_util.BLOCK_SIZE block.
        attrs (dict): Attribute dict, updated in place. "custom_tiling" always receives the list of
            constraints; "enable_multicore" is set to 0 when N is smaller than one block.

    Returns:
        dict, the attrs object that was passed in.

    Raises:
        RuntimeError: If shape has fewer than three entries.
    """
    rank = len(shape)
    if rank < 3:
        raise RuntimeError("Shape must be in the form of [(Batch_out, Batch_in,) M, N, K]. "
                           "Current length of shape is {}".format(rank))

    tile_kind = ct_util.TileConstraint

    def _on_axis(value, kind, axis_pos):
        # Every constraint of this strategy targets band 0; only the first returned entry is kept.
        return ct_util.create_constraint_on_axis(values=value, constraints=kind, band=0, axis=axis_pos)[0]

    strategy = []
    m, n, k = shape[-3:]
    batch_pos, mnk = get_shape_pos_map(shape)
    total_size = reduce(lambda acc, extent: int(acc) * int(extent), shape)

    # the basic block size (in elements) is the minimal tile wanted for n
    align_elem = int(ct_util.BLOCK_SIZE / get_bytes(align_dtype))
    tile_n = get_best_align_elem(n, align_dtype)

    if n < align_elem:
        # n does not fill a block: multi-core is switched off, and all cores are marked as used
        # so that the multi-core branches below are skipped.
        n_constraint = [tile_kind.FACTOR]
        used_core = CORE_NUM
        attrs["enable_multicore"] = 0
    else:
        n_constraint = [tile_kind.MOD, tile_kind.MIN]
        used_core = 1

    if total_size >= MINIMAL_FOR_MULTICORE and used_core < CORE_NUM:
        # cap the batch tiles so that the batch axes can be spread over the remaining cores
        for axis_idx, pos in enumerate(batch_pos):
            cores_here = min(shape[pos], int((CORE_NUM - 1 + used_core) / used_core))
            used_core *= cores_here
            strategy.append(_on_axis(int(shape[pos] / cores_here), tile_kind.MAX, axis_idx))

    tile_k = get_best_align_elem(k, align_dtype)
    k_constraint = tile_kind.MIN
    total_size /= max(1, int(k / tile_k))

    # number of (n tile, k tile) pairs; it feeds both the balance limit and the large-m check
    nk_tiles = int(n / max(1, tile_n)) * int(k / tile_k)

    # choose the tile for m from the multi-core budget when n shares a factor with the block size
    tile_m = -1
    m_per_block = m
    max_core = min(CORE_NUM, int(total_size / MINIMAL_FOR_MULTICORE))
    if greatest_common_divisor(n, align_elem) != 1 and used_core < max_core:
        left_core = int((max_core - 1 + used_core) / used_core)
        core_limit = max(1, int(m / greatest_common_divisor(left_core, m)))
        balance_limit = max(1, int(m / greatest_common_divisor(nk_tiles, m)))
        tile_m = min(core_limit, balance_limit)
        m_per_block = int(m / tile_m)

    # for large m it is more efficient to balance memory bound and calculation bound
    if m_per_block > nk_tiles:
        tile_m = max(min(m, align_elem), tile_m)
        k_constraint = tile_kind.FACTOR

    # emit the constraints; axes with extent 1 have no position in mnk and are skipped
    if m != 1:
        strategy.append(_on_axis(tile_m, tile_kind.MAX, mnk["m"]))
    if n != 1:
        strategy.extend(_on_axis(tile_n, kind, mnk["n"]) for kind in n_constraint)
    if k != 1:
        strategy.append(_on_axis(tile_k, k_constraint, mnk["k"]))
        # NOTE(review): the priority constraint is assumed to belong to the k != 1 branch
        # (otherwise k == n == 1 would look up a missing key) -- confirm against the upstream file.
        priority_axis = mnk["k"] if k >= n else mnk["n"]
        strategy.append(_on_axis(0, tile_kind.SET_PRIORITY, priority_axis))
    strategy.append(ct_util.modify_common_constraints(0.7, tile_kind.SET_MEM_RATIO))
    attrs["custom_tiling"] = strategy
    return attrs
-
def batchmatmul_tiling_strategy_dynamic(shape, output, attrs):
    """
    Build the custom tiling constraints for a batchmatmul with dynamic shape.

    Args:
        shape (list): Extents in the form [(Batch_out, Batch_in,) M, N, K].
        output (tvm.tensor.Tensor): Result tensor handed to ds.set_dynamic_shape_limit_for_tensor.
        attrs (dict): Attribute dict, updated in place with "custom_tiling" and "dynamic_shape".

    Returns:
        dict, the attrs object that was passed in.

    Raises:
        RuntimeError: If shape has fewer than three entries.
    """
    rank = len(shape)
    if rank < 3:
        raise RuntimeError("Shape must be in the form of [(Batch_out, Batch_in,) M, N, K]. "
                           "Current length of shape is {}".format(rank))

    _, mnk = get_shape_pos_map(shape)
    tile_kind = ct_util.TileConstraint
    # (axis name, tile value, constraint kind), emitted in this order
    axis_rules = (
        ("m", 1, tile_kind.FACTOR),
        ("n", "FULL", tile_kind.MAX),
        ("k", 8, tile_kind.FACTOR),
    )
    strategy = []
    for axis_name, tile_value, kind in axis_rules:
        strategy.append(ct_util.create_constraint_on_axis(values=tile_value,
                                                          constraints=kind,
                                                          band=0,
                                                          axis=mnk[axis_name])[0])
    strategy.append(ct_util.modify_common_constraints(0.7, tile_kind.SET_MEM_RATIO))

    attrs["custom_tiling"] = strategy
    attrs["dynamic_shape"] = ds.set_dynamic_shape_limit_for_tensor(output, 2048, [1,])
    return attrs
-
def get_mnk_from_matrix(shape_a_list, shape_b_list, trans_a, trans_b):
    """
    Derive m, n and k from the shapes of the two matmul operands.

    Args:
        shape_a_list (list): Shape of the left operand; its last two entries are (m, k),
            or (k, m) when trans_a is set.
        shape_b_list (list): Shape of the right operand; n is its last entry,
            or its second-to-last entry when trans_b is set.
        trans_a (bool): Whether the left operand is transposed.
        trans_b (bool): Whether the right operand is transposed.

    Returns:
        list, [m, n, k].
    """
    rows_a, cols_a = shape_a_list[-2], shape_a_list[-1]
    if trans_a:
        m, k = cols_a, rows_a
    else:
        m, k = rows_a, cols_a
    n = shape_b_list[-2 if trans_b else -1]
    return [m, n, k]
-
-
def batchmatmul_set_dim(a_value, b_value, trans_a, trans_b):
    """
    Look up the dim info for this batchmatmul in batchmatmul_set_dim_map.

    The lookup key is the string form of (batch dims of a_value..., m, n, k, dtype of a_value,
    trans_a, trans_b).

    Args:
        a_value (tvm.tensor.Tensor): Left operand.
        b_value (tvm.tensor.Tensor): Right operand.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.

    Returns:
        tuple, (result of ct_util.set_dims_by_key for the key, the key string).
    """
    shape_a_list = get_shape(a_value)
    shape_b_list = get_shape(b_value)
    m, n, k = get_mnk_from_matrix(shape_a_list, shape_b_list, trans_a, trans_b)

    # The slice is empty for shapes with at most two axes, so no batch prefix is added then.
    batch_dims = tuple(shape_a_list[:-2])
    key_str = str(batch_dims + (m, n, k, a_value.dtype, trans_a, trans_b))

    return ct_util.set_dims_by_key(key_str, batchmatmul_set_dim_map), key_str
-
-
def batchmatmul_bias_set_dim(a_value, b_value, bias_value, trans_a, trans_b):
    """
    Set-dim function registered for batchmatmul_bias.

    It forwards to batchmatmul_set_dim; bias_value is part of the signature but does not take
    part in the lookup.

    Returns:
        tuple, whatever batchmatmul_set_dim returns for the two operands and transpose flags.
    """
    dims_and_key = batchmatmul_set_dim(a_value, b_value, trans_a, trans_b)
    return dims_and_key
-
-
def batchmatmul_no_bias_set_dim(a_value, b_value, trans_a, trans_b):
    """
    Set-dim function registered for batchmatmul (the variant without bias).

    Returns:
        tuple, whatever batchmatmul_set_dim returns for the two operands and transpose flags.
    """
    dims_and_key = batchmatmul_set_dim(a_value, b_value, trans_a, trans_b)
    return dims_and_key
-
-
@ct_util.reg_set_dim_func(batchmatmul_bias_set_dim)
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, bool, bool)
def batchmatmul_bias(a_value, b_value, bias_value, trans_a, trans_b):
    """
    Multiply two tensors in batches and add a bias to the product.

    Args:
        a_value (tvm.tensor.Tensor): 2D or 3D or 4D tensor of type float16 or float32 with shape(..., r_A, c_A).
        b_value (tvm.tensor.Tensor): 2D or 3D or 4D tensor of same type as a_value with shape(..., r_B, c_B).
        bias_value (tvm.tensor.Tensor): The bias tensor added to the result of a_value * b_value.
                                        Should be of same type as a_value, broadcast is allowed.
        trans_a (bool): Specifies whether a_value is transposed or not.
        trans_b (bool): Specifies whether b_value is transposed or not.

    Returns:
        tvm.tensor.Tensor of same type as a_value with shape(..., r_C, c_C), together with the attrs
        dict that carries the tiling settings.
        r_C = c_A if trans_a else r_A
        c_C = r_B if trans_b else c_B

    Raises:
        TypeError: If trans_a or trans_b is not a bool.
        ValueError: If a_value is not 2D, 3D or 4D.
    """

    if not isinstance(trans_a, bool):
        raise TypeError("trans_a should be of type Boolean.")
    if not isinstance(trans_b, bool):
        raise TypeError("trans_b should be of type Boolean.")

    dtype_a = a_value.dtype
    vc_util.ops_dtype_check([dtype_a, b_value.dtype, bias_value.dtype], vc_util.DtypeForDavinci.ALL_FLOAT)
    vc_util.elemwise_dtype_check(dtype_a, b_value.dtype)
    vc_util.elemwise_dtype_check(dtype_a, bias_value.dtype)
    vc_util.gemm_format_check(get_shape(a_value), get_shape(b_value), trans_a, trans_b)
    if len(a_value.shape) not in (2, 3, 4):
        raise ValueError("Batch matmul only support 2D, 3D and 4D now.")

    # batchmatmul may hand back (tensor, attrs); only the tensor is needed here.
    product = batchmatmul(a_value, b_value, trans_a, trans_b)
    c_value = product[0] if isinstance(product, (tuple, list)) else product

    vc_util.auto_broadcast_check(get_shape(bias_value), get_shape(c_value))

    bias_value = akg.topi.broadcast_to(bias_value, get_shape(c_value))

    # the set-dim helper returns (dim info, key); keep the dim info only
    dim_info = batchmatmul_bias_set_dim(a_value, b_value, bias_value, trans_a, trans_b)
    if isinstance(dim_info, (tuple, list)):
        dim_info = dim_info[0]

    attrs = {"enable_compute_in_place": True}
    if dim_info != "":
        attrs["dim"] = dim_info

    shape_a_list = get_shape(a_value)
    full_shape = shape_a_list[:-2] + get_mnk_from_matrix(shape_a_list, get_shape(b_value), trans_a, trans_b)
    attrs = batchmatmul_tiling_strategy(full_shape, c_value.dtype, attrs)

    output = akg.tvm.compute(bias_value.shape,
                             lambda *indice: c_value(*indice) + bias_value(*indice),
                             name='matmul_bias_output')
    return output, attrs
-
-
@ct_util.reg_set_dim_func(batchmatmul_no_bias_set_dim)
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, bool, bool)
def batchmatmul(a_value, b_value, trans_a=False, trans_b=False):
    """
    Multiply two tensors in batches.

    Args:
        a_value (tvm.tensor.Tensor): 2D or 3D or 4D tensor of type float16 or float32 with shape(..., r_A, c_A).
        b_value (tvm.tensor.Tensor): 2D or 3D or 4D tensor of same type as a_value with shape(..., r_B, c_B).
        trans_a (bool): Specifies whether a_value is transposed or not, default value is False.
        trans_b (bool): Specifies whether b_value is transposed or not, default value is False.

    Returns:
        tvm.tensor.Tensor of same type as a_value with shape(..., r_C, c_C), together with the attrs
        dict that carries the tiling settings.
        r_C = c_A if trans_a else r_A
        c_C = r_B if trans_b else c_B

    Raises:
        TypeError: If trans_a or trans_b is not a bool.
        ValueError: If a_value is not 2D, 3D or 4D.
    """

    if not isinstance(trans_a, bool):
        raise TypeError("trans_a should be of type Boolean.")
    if not isinstance(trans_b, bool):
        raise TypeError("trans_b should be of type Boolean.")
    vc_util.ops_dtype_check([a_value.dtype, b_value.dtype], vc_util.DtypeForDavinci.ALL_FLOAT)
    vc_util.elemwise_dtype_check(a_value.dtype, b_value.dtype)
    vc_util.gemm_format_check(get_shape(a_value), get_shape(b_value), trans_a, trans_b)
    rank = len(a_value.shape)
    if rank not in (2, 3, 4):
        raise ValueError("Batch matmul only support 2D, 3D and 4D now.")

    # float16 inputs go through the implementations that multiply in float32; everything else uses
    # the hybrid-script implementations. Rank 4 is the fallback once 2 and 3 are ruled out.
    if a_value.dtype == 'float16':
        cast_impl = {2: vectormatmul_2d_cast, 3: vectormatmul_3d_cast}.get(rank, vectormatmul_4d_cast)
        c_value = cast_impl(a_value, b_value, trans_a, trans_b, "float32")
    else:
        plain_impl = {2: vectormatmul_2d, 3: vectormatmul_3d}.get(rank, vectormatmul_4d)
        c_value = plain_impl(a_value, b_value, trans_a, trans_b)

    # the set-dim helper returns (dim info, key); keep the dim info only
    dim_info = batchmatmul_no_bias_set_dim(a_value, b_value, trans_a, trans_b)
    if isinstance(dim_info, (tuple, list)):
        dim_info = dim_info[0]
    attrs = {"enable_compute_in_place": True}
    if dim_info != "":
        attrs["dim"] = dim_info

    shape_a_list = get_shape(a_value)
    mnk = get_mnk_from_matrix(shape_a_list, get_shape(b_value), trans_a, trans_b)
    batch = shape_a_list[:-2]
    if ds.shape_is_dynamic([a_value, b_value]):
        attrs = batchmatmul_tiling_strategy_dynamic(batch + mnk, c_value, attrs)
        # NOTE(review): these three switches are assumed to apply to the dynamic-shape branch only
        # -- confirm against the upstream file.
        attrs["enable_pre_storage_write_simplify"] = True
        attrs["enable_sink_allocate"] = True
        attrs["enable_double_buffer"] = False
    else:
        attrs = batchmatmul_tiling_strategy(batch + mnk, c_value.dtype, attrs)

    return c_value, attrs
-
-
def vectormatmul_3d(a_value, b_value, trans_a, trans_b):
    """
    Hybrid-script implementation of 3D batchmatmul.

    Args:
        a_value (tvm.tensor.Tensor): 3D tensor, read as (bs, k, m) when trans_a is set and as
            (bs, m, k) otherwise.
        b_value (tvm.tensor.Tensor): Tensor whose second-to-last axis (trans_b set) or last axis
            (trans_b not set) gives n.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.

    Returns:
        Result of the selected hybrid kernel, allocated with shape (bs, m, n).

    Raises:
        ValueError: If both trans_a and trans_b are set.
    """
    # bs, m, k and n are referenced by name inside the kernels below and reach them through
    # capture=locals(), so they must be bound (under exactly these names) before the kernels are defined.
    bs, rows_a, cols_a = a_value.shape
    if trans_a:
        m, k = cols_a, rows_a
    else:
        m, k = rows_a, cols_a
    n = b_value.shape[-2 if trans_b else -1]

    zero = akg.tvm.const(0.0, dtype=a_value.dtype)

    # a: (bs, m, k), b: (bs, k, n)
    @script(capture=locals())
    def matmul_hybrid_f_f(a, b, zero):
        t_1 = allocate((bs, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs, m, n), a.dtype, 'local')
        for i_bs in range(0, bs):
            for i_m in range(0, m):
                # all partial products of one output row first ...
                for i_k in range(0, k):
                    for i_n in range(0, n):
                        t_1[i_bs, i_m, i_k, i_n] = a[i_bs, i_m, i_k] * b[i_bs, i_k, i_n]
                # ... then clear the row and sum the products over k
                for i1_n in range(0, n):
                    t_2[i_bs, i_m, i1_n] = zero
                for i1_k in range(0, k):
                    for i1_n in range(0, n):
                        t_2[i_bs, i_m, i1_n] = t_2[i_bs, i_m, i1_n] + t_1[i_bs, i_m, i1_k, i1_n]
        return t_2

    # a: (bs, m, k), b: (bs, n, k)
    @script(capture=locals())
    def matmul_hybrid_f_t(a, b, zero):
        t_1 = allocate((bs, m, n, k), a.dtype, 'local')
        t_2 = allocate((bs, m, n), a.dtype, 'local')
        for i_bs in range(0, bs):
            for i_m in range(0, m):
                for i_n in range(0, n):
                    t_2[i_bs, i_m, i_n] = zero
                    for i_k in range(0, k):
                        t_1[i_bs, i_m, i_n, i_k] = a[i_bs, i_m, i_k] * b[i_bs, i_n, i_k]
                        t_2[i_bs, i_m, i_n] = t_1[i_bs, i_m, i_n, i_k] + t_2[i_bs, i_m, i_n]
        return t_2

    # a: (bs, k, m), b: (bs, k, n)
    @script(capture=locals())
    def matmul_hybrid_t_f(a, b, zero):
        t_1 = allocate((bs, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs, m, n), a.dtype, 'local')
        for i_bs in range(0, bs):
            for i_m in range(0, m):
                for i_k in range(0, k):
                    for i_n in range(0, n):
                        t_1[i_bs, i_m, i_k, i_n] = a[i_bs, i_k, i_m] * b[i_bs, i_k, i_n]
                for i1_n in range(0, n):
                    t_2[i_bs, i_m, i1_n] = zero
                for i1_k in range(0, k):
                    for i1_n in range(0, n):
                        t_2[i_bs, i_m, i1_n] = t_2[i_bs, i_m, i1_n] + t_1[i_bs, i_m, i1_k, i1_n]
        return t_2

    if trans_a and trans_b:
        raise ValueError('Not support both transpose yet')
    if trans_a:
        kernel = matmul_hybrid_t_f
    elif trans_b:
        kernel = matmul_hybrid_f_t
    else:
        kernel = matmul_hybrid_f_f

    return kernel(a_value, b_value, zero)
-
-
def vectormatmul_4d(a_value, b_value, trans_a, trans_b):
    """
    Hybrid-script implementation of 4D batchmatmul.

    Args:
        a_value (tvm.tensor.Tensor): 4D tensor, read as (bs1, bs2, k, m) when trans_a is set and as
            (bs1, bs2, m, k) otherwise.
        b_value (tvm.tensor.Tensor): Tensor whose second-to-last axis (trans_b set) or last axis
            (trans_b not set) gives n.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.

    Returns:
        Result of the selected hybrid kernel, allocated with shape (bs1, bs2, m, n).

    Raises:
        ValueError: If both trans_a and trans_b are set.
    """
    # bs1, bs2, m, k and n are referenced by name inside the kernels below and reach them through
    # capture=locals(), so they must be bound (under exactly these names) before the kernels are defined.
    bs1, bs2, rows_a, cols_a = a_value.shape
    if trans_a:
        m, k = cols_a, rows_a
    else:
        m, k = rows_a, cols_a
    n = b_value.shape[-2 if trans_b else -1]

    zero = akg.tvm.const(0.0, dtype=a_value.dtype)

    # a: (bs1, bs2, m, k), b: (bs1, bs2, k, n)
    @script(capture=locals())
    def matmul_hybrid_f_f(a, b, zero):
        t_1 = allocate((bs1, bs2, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs1, bs2, m, n), a.dtype, 'local')
        for i_bs1 in range(0, bs1):
            for i_bs2 in range(0, bs2):
                for i_m in range(0, m):
                    # all partial products of one output row first ...
                    for i_k in range(0, k):
                        for i_n in range(0, n):
                            t_1[i_bs1, i_bs2, i_m, i_k, i_n] = a[i_bs1, i_bs2, i_m, i_k] * b[i_bs1, i_bs2, i_k, i_n]
                    # ... then clear the row and sum the products over k
                    for i1_n in range(0, n):
                        t_2[i_bs1, i_bs2, i_m, i1_n] = zero
                    for i1_k in range(0, k):
                        for i1_n in range(0, n):
                            t_2[i_bs1, i_bs2, i_m, i1_n] = t_2[i_bs1, i_bs2, i_m, i1_n] + \
                                t_1[i_bs1, i_bs2, i_m, i1_k, i1_n]
        return t_2

    # a: (bs1, bs2, m, k), b: (bs1, bs2, n, k)
    @script(capture=locals())
    def matmul_hybrid_f_t(a, b, zero):
        t_1 = allocate((bs1, bs2, m, n, k), a.dtype, 'local')
        t_2 = allocate((bs1, bs2, m, n), a.dtype, 'local')
        for i_bs1 in range(0, bs1):
            for i_bs2 in range(0, bs2):
                for i_m in range(0, m):
                    for i_n in range(0, n):
                        t_2[i_bs1, i_bs2, i_m, i_n] = zero
                        for i_k in range(0, k):
                            t_1[i_bs1, i_bs2, i_m, i_n, i_k] = a[i_bs1, i_bs2, i_m, i_k] * b[i_bs1, i_bs2, i_n, i_k]
                            t_2[i_bs1, i_bs2, i_m, i_n] = t_1[i_bs1, i_bs2, i_m, i_n, i_k] + t_2[i_bs1, i_bs2, i_m, i_n]
        return t_2

    # a: (bs1, bs2, k, m), b: (bs1, bs2, k, n)
    @script(capture=locals())
    def matmul_hybrid_t_f(a, b, zero):
        t_1 = allocate((bs1, bs2, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs1, bs2, m, n), a.dtype, 'local')
        for i_bs1 in range(0, bs1):
            for i_bs2 in range(0, bs2):
                for i_m in range(0, m):
                    for i_k in range(0, k):
                        for i_n in range(0, n):
                            t_1[i_bs1, i_bs2, i_m, i_k, i_n] = a[i_bs1, i_bs2, i_k, i_m] * b[i_bs1, i_bs2, i_k, i_n]
                    for i1_n in range(0, n):
                        t_2[i_bs1, i_bs2, i_m, i1_n] = zero
                    for i1_k in range(0, k):
                        for i1_n in range(0, n):
                            t_2[i_bs1, i_bs2, i_m, i1_n] = t_2[i_bs1, i_bs2, i_m, i1_n] + \
                                t_1[i_bs1, i_bs2, i_m, i1_k, i1_n]
        return t_2

    if trans_a and trans_b:
        raise ValueError('Not support both transpose yet')
    if trans_a:
        kernel = matmul_hybrid_t_f
    elif trans_b:
        kernel = matmul_hybrid_f_t
    else:
        kernel = matmul_hybrid_f_f

    return kernel(a_value, b_value, zero)
-
-
def vectormatmul_4d_cast(a_value, b_value, trans_a, trans_b, cast_dtype):
    """
    DSL implementation of 4D batchmatmul that multiplies and accumulates in cast_dtype.

    Each pair of elements is cast to cast_dtype, multiplied into a (b1, b2, m, n, k) tensor, summed
    over k, and the sum is cast back to the dtype of a_value.

    The element-wise compute indexes the operands according to trans_a / trans_b, so the reduced
    tensor is laid out as (b1, b2, m, n) for every combination of the flags. No transpose is applied
    afterwards: the former (1, 0) transpose in the both-transposed case named two axes for a 4D
    tensor and would have swapped a result that is already in the documented layout.

    Args:
        a_value (tvm.tensor.Tensor): 4D tensor, read as (b1, b2, k, m) when trans_a is set and as
            (b1, b2, m, k) otherwise.
        b_value (tvm.tensor.Tensor): Tensor whose second-to-last axis (trans_b set) or last axis
            (trans_b not set) gives n.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.
        cast_dtype (str): Data type used for the multiplication and the reduction.

    Returns:
        tvm.tensor.Tensor of shape (b1, b2, m, n) with the dtype of a_value.
    """
    if trans_a:
        b1, b2, k, m = a_value.shape
    else:
        b1, b2, m, k = a_value.shape
    n = b_value.shape[-2] if trans_b else b_value.shape[-1]

    def _ele_mul(i_b1, i_b2, i_m, i_n, i_k):
        # Read each operand in its stored layout; the flags are plain Python bools, so the
        # selection happens while the compute is being built.
        lhs = a_value[i_b1, i_b2, i_k, i_m] if trans_a else a_value[i_b1, i_b2, i_m, i_k]
        rhs = b_value[i_b1, i_b2, i_n, i_k] if trans_b else b_value[i_b1, i_b2, i_k, i_n]
        lhs = lhs.astype(cast_dtype)
        rhs = rhs.astype(cast_dtype)
        # Keep the operand order of the original expressions (b * a when both are transposed).
        if trans_a and trans_b:
            return rhs * lhs
        return lhs * rhs

    ele_mul = akg.tvm.compute((b1, b2, m, n, k), _ele_mul, name="ele_mul")
    reduce_axis = akg.tvm.reduce_axis((0, k), name='reduce_axis')
    c_cast = akg.tvm.compute((b1, b2, m, n),
                             lambda *i: akg.tvm.sum(ele_mul[i + (reduce_axis,)], axis=reduce_axis),
                             name="matmul_compute")
    return cast.cast(c_cast, a_value.dtype)
-
-
def vectormatmul_3d_cast(a_value, b_value, trans_a, trans_b, cast_dtype):
    """
    DSL implementation of 3D batchmatmul that multiplies and accumulates in cast_dtype.

    Each pair of elements is cast to cast_dtype, multiplied into a (b, m, n, k) tensor, summed over
    k, and the sum is cast back to the dtype of a_value.

    The element-wise compute indexes the operands according to trans_a / trans_b, so the reduced
    tensor is laid out as (b, m, n) for every combination of the flags. No transpose is applied
    afterwards: the former (1, 0) transpose in the both-transposed case named two axes for a 3D
    tensor and would have swapped a result that is already in the documented layout.

    Args:
        a_value (tvm.tensor.Tensor): 3D tensor, read as (b, k, m) when trans_a is set and as
            (b, m, k) otherwise.
        b_value (tvm.tensor.Tensor): Tensor whose second-to-last axis (trans_b set) or last axis
            (trans_b not set) gives n.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.
        cast_dtype (str): Data type used for the multiplication and the reduction.

    Returns:
        tvm.tensor.Tensor of shape (b, m, n) with the dtype of a_value.
    """
    if trans_a:
        b, k, m = a_value.shape
    else:
        b, m, k = a_value.shape
    n = b_value.shape[-2] if trans_b else b_value.shape[-1]

    def _ele_mul(i_b, i_m, i_n, i_k):
        # Read each operand in its stored layout; the flags are plain Python bools, so the
        # selection happens while the compute is being built.
        lhs = a_value[i_b, i_k, i_m] if trans_a else a_value[i_b, i_m, i_k]
        rhs = b_value[i_b, i_n, i_k] if trans_b else b_value[i_b, i_k, i_n]
        lhs = lhs.astype(cast_dtype)
        rhs = rhs.astype(cast_dtype)
        # Keep the operand order of the original expressions (b * a when both are transposed).
        if trans_a and trans_b:
            return rhs * lhs
        return lhs * rhs

    ele_mul = akg.tvm.compute((b, m, n, k), _ele_mul, name="ele_mul")
    reduce_axis = akg.tvm.reduce_axis((0, k), name='reduce_axis')
    c_cast = akg.tvm.compute((b, m, n),
                             lambda *i: akg.tvm.sum(ele_mul[i + (reduce_axis,)], axis=reduce_axis),
                             name="matmul_compute")
    return cast.cast(c_cast, a_value.dtype)
-
-
def vectormatmul_2d_cast(a_value, b_value, trans_a, trans_b, cast_dtype):
    """
    DSL implementation of 2D matmul that multiplies and accumulates in cast_dtype.

    Each pair of elements is cast to cast_dtype, multiplied into an (m, n, k) tensor, summed over k,
    and the sum is cast back to the dtype of a_value.

    The element-wise compute indexes the operands according to trans_a / trans_b, so the reduced
    tensor is laid out as (m, n) for every combination of the flags. No transpose is applied
    afterwards: the former (1, 0) transpose in the both-transposed case turned the correct (m, n)
    result into (n, m).

    Args:
        a_value (tvm.tensor.Tensor): 2D tensor, read as (k, m) when trans_a is set and as (m, k) otherwise.
        b_value (tvm.tensor.Tensor): Tensor whose second-to-last axis (trans_b set) or last axis
            (trans_b not set) gives n.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.
        cast_dtype (str): Data type used for the multiplication and the reduction.

    Returns:
        tvm.tensor.Tensor of shape (m, n) with the dtype of a_value.
    """
    if trans_a:
        k, m = a_value.shape
    else:
        m, k = a_value.shape
    n = b_value.shape[-2] if trans_b else b_value.shape[-1]

    # Casting the float16 inputs to float32 as whole tensors makes the AutoPoly pass take a long
    # time, so the cast is done per element inside the multiply instead.
    def _ele_mul(i_m, i_n, i_k):
        lhs = a_value[i_k, i_m] if trans_a else a_value[i_m, i_k]
        rhs = b_value[i_n, i_k] if trans_b else b_value[i_k, i_n]
        lhs = lhs.astype(cast_dtype)
        rhs = rhs.astype(cast_dtype)
        # Keep the operand order of the original expressions (b * a when both are transposed).
        if trans_a and trans_b:
            return rhs * lhs
        return lhs * rhs

    ele_mul = akg.tvm.compute((m, n, k), _ele_mul, name="ele_mul")
    reduce_axis = akg.tvm.reduce_axis((0, k), name='reduce_axis')
    c_cast = akg.tvm.compute((m, n),
                             lambda *i: akg.tvm.sum(ele_mul[i + (reduce_axis,)], axis=reduce_axis),
                             name="matmul_compute")
    return cast.cast(c_cast, a_value.dtype)
-
-
def vectormatmul_2d(a_value, b_value, trans_a, trans_b):
    """
    Hybrid-script implementation of 2D matmul.

    Args:
        a_value (tvm.tensor.Tensor): 2D tensor, read as (k, m) when trans_a is set and as (m, k) otherwise.
        b_value (tvm.tensor.Tensor): Tensor whose second-to-last axis (trans_b set) or last axis
            (trans_b not set) gives n.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.

    Returns:
        Result of shape (m, n). When both flags are set, the kernel produces an (n, m) tensor that
        is transposed with akg.topi.transpose before it is returned.
    """
    rows_a, cols_a = a_value.shape
    if trans_a:
        m, k = cols_a, rows_a
    else:
        m, k = rows_a, cols_a
    n = b_value.shape[-2 if trans_b else -1]

    zero = akg.tvm.const(0.0, dtype=a_value.dtype)

    # a: (mv, kv), b: (kv, nv); the result is declared with output_tensor in this kernel
    @script(capture=locals())
    def matmul_hybrid_f_f(a, b, zero, mv, nv, kv):
        t_1 = allocate((mv, kv, nv), a.dtype, 'local')
        t_2 = output_tensor((mv, nv), a.dtype)
        for i_m in range(0, mv):
            # all partial products of one output row first ...
            for i_k in range(0, kv):
                for i_n in range(0, nv):
                    t_1[i_m, i_k, i_n] = a[i_m, i_k] * b[i_k, i_n]
            # ... then clear the row and sum the products over k
            for i1_n in range(0, nv):
                t_2[i_m, i1_n] = zero
            for i1_k in range(0, kv):
                for i1_n in range(0, nv):
                    t_2[i_m, i1_n] = t_2[i_m, i1_n] + t_1[i_m, i1_k, i1_n]
        return t_2

    # a: (mv, kv), b: (nv, kv)
    @script(capture=locals())
    def matmul_hybrid_f_t(a, b, zero, mv, nv, kv):
        t_1 = allocate((mv, nv, kv), a.dtype, 'local')
        t_2 = allocate((mv, nv), a.dtype, 'local')
        for i_m in range(0, mv):
            for i_n in range(0, nv):
                t_2[i_m, i_n] = zero
                for i_k in range(0, kv):
                    t_1[i_m, i_n, i_k] = a[i_m, i_k] * b[i_n, i_k]
                    t_2[i_m, i_n] = t_1[i_m, i_n, i_k] + t_2[i_m, i_n]
        return t_2

    # a: (kv, mv), b: (kv, nv)
    @script(capture=locals())
    def matmul_hybrid_t_f(a, b, zero, mv, nv, kv):
        t_1 = allocate((mv, kv, nv), a.dtype, 'local')
        t_2 = allocate((mv, nv), a.dtype, 'local')
        for i_m in range(0, mv):
            for i_k in range(0, kv):
                for i_n in range(0, nv):
                    t_1[i_m, i_k, i_n] = a[i_k, i_m] * b[i_k, i_n]
            for i1_n in range(0, nv):
                t_2[i_m, i1_n] = zero
            for i1_k in range(0, kv):
                for i1_n in range(0, nv):
                    t_2[i_m, i1_n] = t_2[i_m, i1_n] + t_1[i_m, i1_k, i1_n]
        return t_2

    # a: (kv, mv), b: (nv, kv); computes b * a, i.e. the (nv, mv) transpose of the wanted result
    @script(capture=locals())
    def matmul_hybrid_t_t(a, b, zero, mv, nv, kv):
        t_1 = allocate((nv, kv, mv), a.dtype, 'local')
        t_2 = allocate((nv, mv), a.dtype, 'local')
        for i_n in range(0, nv):
            for i_m in range(0, mv):
                for i_k in range(0, kv):
                    t_1[i_n, i_k, i_m] = b[i_n, i_k] * a[i_k, i_m]
            for i1_m in range(0, mv):
                t_2[i_n, i1_m] = zero
            for i1_k in range(0, kv):
                for i2_m in range(0, mv):
                    t_2[i_n, i2_m] = t_2[i_n, i2_m] + t_1[i_n, i1_k, i2_m]
        return t_2

    if trans_a and trans_b:
        c1 = matmul_hybrid_t_t(a_value, b_value, zero, m, n, k)
        return akg.topi.transpose(c1, (1, 0))
    if trans_a:
        kernel = matmul_hybrid_t_f
    elif trans_b:
        kernel = matmul_hybrid_f_t
    else:
        kernel = matmul_hybrid_f_f

    return kernel(a_value, b_value, zero, m, n, k)
|