|
- #!/usr/bin/env python3
- # coding: utf-8
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """operator dsl function: conv"""
- import akg.tvm
- import akg
- import akg.lang.cce
- from akg import dim
-
- from akg.ops.math import cast
- from akg.utils import validation_check as vc_util
- from akg.utils.format_transform import get_shape
- from akg.utils.dynamic_shape import set_poly_upper_bound_for_tensor
-
# Placeholder ("fake") parameter values used when fully-dynamic tiling is
# enabled (attrs["dynamic"]): the real kernel size, padding, stride and tile
# sizes become runtime parameters, and these stand-ins are fed to the
# scheduler instead (see conv_core and gen_dynamic_dim below).  The real
# values are passed alongside via the "pragma_conv_real_*" attributes.
k_h_fake = 11        # fake kernel height
k_w_fake = 31        # fake kernel width
p_top_fake = 9       # fake padding top
p_bottom_fake = 8    # fake padding bottom
p_left_fake = 23     # fake padding left
p_right_fake = 21    # fake padding right
s_h_fake = 7         # fake stride height
s_w_fake = 17        # fake stride width
c1_cut_fake = 67     # fake Co tile, in C1 (16-channel) units
tile_out_h_fake = 47 # fake output-height tile
tile_out_w_fake = 37 # fake output-width tile
m_cut_fake = 53 * 16 # fake GEMM M tile (multiples of block_size=16)
k_cut_fake = 59 * 16 # fake GEMM K tile
n_cut_fake = 61 * 16 # fake GEMM N tile
-
# Tuned tiling table for known conv workloads.
# Key: str(((fmap N, C, H, W), (filter N, C, H, W),
#           (pad top, bottom, left, right), (stride h, w),
#           (dilation h, w), use_bias)).
# Value: either a tuple
#   ([tile_h, tile_co, tile_m, tile_k, tile_n[, tile_w]], {"bypass": 0|1})
# or a bare tile list (bypass then keeps its default, see conv_set_dim_func).
conv_set_dim_map = {
    str(((1, 1024, 14, 14), (2048, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), True)):
        ([14, 2048, 64, 96, 128], {"bypass": 1}),
    str(((1, 1024, 14, 14), (256, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), True)):
        ([14, 256, 208, 64, 112], {"bypass": 1}),
    str(((1, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), True)):
        ([14, 512, 49, 32, 512], {"bypass": 1}),
    str(((1, 128, 28, 28), (128, 128, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
        ([28, 128, 400, 32, 128], {"bypass": 1}),
    str(((1, 128, 28, 28), (128, 128, 3, 3), (1, 1, 1, 1), (1, 1), (2, 2), True)):
        ([28, 128, 400, 32, 128], {"bypass": 1}),
    str(((1, 128, 28, 28), (512, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), True)):
        ([28, 512, 784, 16, 32], {"bypass": 1}),
    str(((1, 128, 28, 32), (128, 128, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), True)):
        ([28, 128, 16, 72 * 16, 16], {"bypass": 1}),
    str(((1, 128, 34, 34), (128, 128, 3, 3), (1, 1, 1, 1), (1, 1), (2, 2), True)):
        ([36, 128, 64, 32, 64], {"bypass": 1}),
    str(((1, 128, 36, 36), (128, 128, 3, 3), (0, 0, 0, 0), (1, 1), (2, 2), False)):
        ([36, 128, 64, 32, 64], {"bypass": 1}),
    str(((1, 16, 16, 16), (64, 16, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), True)):
        [16, 4 * 16, 16 * 16, 3 * 16, 4 * 16, 16 + 2],
    str(((1, 2048, 7, 7), (512, 2048, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), True)):
        ([7, 512, 49, 32, 512], {"bypass": 1}),
    str(((1, 256, 14, 14), (1024, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), True)):
        ([14, 944, 112, 32, 240], {"bypass": 1}),
    str(((1, 256, 14, 14), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
        ([14, 256, 196, 64, 256], {"bypass": 1}),
    str(((1, 256, 14, 14), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (2, 2), True)):
        ([14, 256, 196, 64, 256], {"bypass": 1}),
    str(((1, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), True)):
        ([7, 128, 252, 64, 128], {"bypass": 1}),
    str(((1, 256, 56, 56), (512, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), True)):
        ([7, 512, 196, 64, 256], {"bypass": 1}),
    str(((1, 256, 56, 56), (64, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), True)):
        ([16, 64, 280, 16, 64], {"bypass": 1}),
    str(((1, 3, 224, 224), (64, 3, 7, 7), (2, 3, 2, 3), (2, 2), (1, 1), True)):
        ([117, 64, 448, 32, 64], {"bypass": 1}),
    str(((1, 3, 224, 224), (64, 3, 7, 7), (3, 3, 3, 3), (2, 2), (1, 1), True)):
        ([65, 64, 448, 32, 64], {"bypass": 1}),
    str(((1, 512, 14, 14), (512, 512, 3, 3), (1, 1, 1, 1), (2, 2), (1, 1), False)):
        ([14, 512, 64, 48, 128, 16], {"bypass": 1}),
    str(((1, 512, 28, 28), (1024, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), True)):
        ([13, 1024, 112, 32, 512], {"bypass": 1}),
    str(((1, 512, 28, 28), (128, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), True)):
        ([14, 128, 448, 16, 64], {"bypass": 1}),
    str(((1, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), True)):
        ([11, 256, 98, 64, 256], {"bypass": 1}),
    str(((1, 512, 7, 7), (2048, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), True)):
        ([7, 2048, 49, 16, 512], {"bypass": 1}),
    str(((1, 512, 7, 7), (512, 512, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
        ([7, 512, 49, 32, 512], {"bypass": 1}),
    str(((1, 512, 7, 7), (512, 512, 3, 3), (1, 1, 1, 1), (1, 1), (2, 2), True)):
        ([7, 512, 49, 32, 512], {"bypass": 1}),
    str(((1, 64, 112, 112), (64, 64, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
        ([56, 64, 784, 16, 32], {"bypass": 1}),
    str(((1, 64, 56, 56), (256, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), True)):
        ([56, 256, 784, 16, 32], {"bypass": 1}),
    str(((1, 64, 56, 56), (64, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), True)):
        ([56, 64, 784, 16, 32], {"bypass": 1}),
    str(((1, 64, 56, 56), (64, 64, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
        ([56, 64, 336, 16, 64], {"bypass": 1}),
    str(((2, 1024, 48, 72), (2048, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([5, 128, 288, 16, 112, 72], {"bypass": 0}),
    str(((2, 1024, 48, 72), (256, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([4, 128, 256, 96, 64, 72], {"bypass": 0}),
    str(((2, 1024, 48, 72), (512, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([3, 512, 160, 160, 96, 72], {"bypass": 1}),
    str(((2, 128, 192, 288), (256, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([3, 128, 288, 48, 48, 288], {"bypass": 0}),
    str(((2, 128, 192, 288), (64, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([3, 64, 176, 80, 64, 288], {"bypass": 1}),
    str(((2, 128, 48, 72), (512, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([8, 256, 256, 16, 48, 72], {"bypass": 0}),
    str(((2, 128, 96, 144), (128, 128, 3, 3), (1, 1, 1, 1), (2, 2), (1, 1), False)):
        ([17, 16, 256, 96, 16, 145], {"bypass": 0}),
    str(((2, 128, 96, 144), (128, 128, 3, 3), (2, 2, 2, 2), (1, 1), (1, 1), False)):
        ([17, 32, 320, 48, 32, 148], {"bypass": 0}),
    str(((2, 128, 96, 144), (512, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([6, 256, 160, 16, 192, 144], {"bypass": 0}),
    str(((2, 1280, 48, 72), (256, 1280, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([4, 64, 48, 160, 64, 72], {"bypass": 0}),
    str(((2, 16, 768, 1152), (64, 16, 3, 3), (1, 1, 1, 1), (2, 2), (1, 1), False)):
        ([25, 16, 256, 48, 16, 1153], {"bypass": 0}),
    str(((2, 2048, 1, 1), (256, 2048, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([1, 64, 16, 176, 64, 1], {"bypass": 0}),
    str(((2, 2048, 48, 72), (256, 2048, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([3, 256, 48, 432, 64, 72], {"bypass": 1}),
    str(((2, 2048, 48, 72), (512, 2048, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([2, 512, 144, 64, 80, 72], {"bypass": 1}),
    str(((2, 256, 192, 288), (64, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([3, 64, 256, 128, 64, 288], {"bypass": 0}),
    str(((2, 256, 48, 72), (1024, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([6, 64, 144, 176, 16, 72], {"bypass": 0}),
    str(((2, 256, 48, 72), (256, 256, 3, 3), (2, 2, 2, 2), (1, 1), (1, 1), False)):
        ([19, 16, 240, 96, 16, 76], {"bypass": 0}),
    str(((2, 256, 48, 72), (3, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([6, 16, 16, 208, 16, 72], {"bypass": 0}),
    str(((2, 256, 96, 144), (128, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([96, 128, 128, 64, 64, 16], {"bypass": 0}),
    str(((2, 256, 96, 144), (512, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([8, 512, 336, 48, 112, 144], {"bypass": 0}),
    str(((2, 512, 48, 72), (1024, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([4, 256, 112, 32, 176, 72], {"bypass": 0}),
    str(((2, 512, 48, 72), (2048, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([4, 64, 288, 48, 32, 72], {"bypass": 0}),
    str(((2, 512, 48, 72), (256, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([4, 64, 96, 16, 32, 72], {"bypass": 0}),
    str(((2, 512, 96, 144), (128, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([4, 128, 272, 32, 96, 144], {"bypass": 0}),
    str(((2, 64, 192, 288), (256, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([6, 64, 192, 16, 64, 288], {"bypass": 0}),
    str(((2, 64, 192, 288), (64, 64, 3, 3), (1, 1, 1, 1), (2, 2), (1, 1), False)):
        ([193, 16, 352, 80, 16, 33], {"bypass": 0}),
    str(((2, 64, 192, 288), (64, 64, 3, 3), (2, 2, 2, 2), (1, 1), (1, 1), False)):
        ([9, 16, 1968, 16, 16, 292], {"bypass": 0}),
    str(((2, 64, 384, 576), (128, 64, 3, 3), (2, 2, 2, 2), (1, 1), (1, 1), False)):
        ([5, 16, 1600, 16, 16, 580], {"bypass": 0}),
    str(((2, 64, 384, 576), (64, 64, 3, 3), (2, 2, 2, 2), (1, 1), (1, 1), False)):
        ([388, 32, 448, 32, 32, 18], {"bypass": 0}),
    str(((2, 64, 96, 144), (256, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([96, 256, 224, 32, 96, 16], {"bypass": 1}),
    str(((32, 1024, 14, 14), (2048, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
        ([13, 2048, 64, 48, 336, 13], {"bypass": 1}),
    str(((32, 1024, 14, 14), (256, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([14, 256, 112, 112, 128, 14], {"bypass": 1}),
    str(((32, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([14, 272, 160, 32, 176, 14], {"bypass": 0}),
    str(((32, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
        ([13, 512, 64, 80, 112, 13], {"bypass": 1}),
    str(((32, 128, 28, 28), (128, 128, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
        ([18, 128, 336, 32, 80, 30], {"bypass": 0}),
    str(((32, 128, 28, 28), (512, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([4, 288, 112, 96, 272, 28], {"bypass": 0}),
    str(((32, 128, 56, 56), (128, 128, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1), False)):
        ([57, 128, 288, 48, 128, 41], {"bypass": 0}),
    str(((32, 16, 33, 33), (64, 16, 3, 3), (0, 0, 0, 0), (2, 2), (1, 1), False)):
        [128, 128, 64, 128, 64],
    str(((32, 16, 34, 34), (64, 16, 3, 3), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        [128, 128, 64, 128, 64],
    str(((32, 2048, 7, 7), (512, 2048, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([7, 512, 64, 80, 176, 7], {"bypass": 1}),
    str(((32, 256, 14, 14), (1024, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([14, 352, 176, 128, 192, 14], {"bypass": 0}),
    str(((32, 256, 14, 14), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
        ([14, 192, 160, 96, 160, 16], {"bypass": 0}),
    str(((32, 256, 28, 28), (256, 256, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1), False)):
        ([17, 256, 112, 112, 160, 29], {"bypass": 1}),
    str(((32, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([10, 128, 336, 32, 128, 56], {"bypass": 0}),
    str(((32, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
        ([15, 128, 16, 256, 80, 55], {"bypass": 1}),
    str(((32, 256, 56, 56), (512, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
        ([13, 384, 32, 32, 384, 55], {"bypass": 0}),
    str(((32, 256, 56, 56), (64, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([12, 64, 336, 16, 64, 56], {"bypass": 0}),
    str(((32, 3, 224, 224), (64, 3, 7, 7), (2, 3, 2, 3), (2, 2), (1, 1), False)):
        ([17, 64, 336, 48, 64, 229], {"bypass": 0}),
    str(((32, 3, 224, 224), (64, 3, 7, 7), (3, 3, 3, 3), (2, 2), (1, 1), False)):
        ([149, 64, 96, 35, 64, 117], {"bypass": 0}),
    str(((32, 512, 14, 14), (512, 512, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1), False)):
        ([15, 512, 64, 32, 368, 15], {"bypass": 1}),
    str(((32, 512, 28, 28), (1024, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
        ([27, 1024, 160, 112, 96, 27], {"bypass": 1}),
    str(((32, 512, 28, 28), (128, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([6, 128, 160, 16, 128, 28], {"bypass": 0}),
    str(((32, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([28, 128, 144, 160, 128, 28], {"bypass": 0}),
    str(((32, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
        ([27, 256, 160, 144, 32, 27], {"bypass": 1}),
    str(((32, 512, 7, 7), (2048, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([7, 2048, 64, 48, 336, 7], {"bypass": 1}),
    str(((32, 512, 7, 7), (512, 512, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
        ([9, 512, 64, 256, 64, 9], {"bypass": 1}),
    str(((32, 64, 56, 56), (256, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([42, 176, 224, 48, 96, 56], {"bypass": 0}),
    str(((32, 64, 56, 56), (64, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([12, 64, 32, 64, 64, 56], {"bypass": 1}),
    str(((32, 64, 56, 56), (64, 64, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
        ([40, 64, 1120, 16, 48, 58], {"bypass": 0}),
    str(((32, 96, 28, 28), (256, 96, 5, 5), (2, 2, 2, 2), (1, 1), (1, 1), False)):
        ([28, 64, 112, 160, 32, 32], {"bypass": 0}),
    str(((8, 512, 26, 38), (512, 512, 3, 3), (0, 0, 0, 0), (1, 1), (1, 1), False)):
        ([26, 512, 128, 96, 80, 11], {"bypass": 1}),
    str(((32, 3, 227, 227), (96, 3, 11, 11), (0, 0, 0, 0), (4, 4), (1, 1), False)):
        ([63, 96, 208, 32, 96, 227], {"bypass": 0}),
    str(((32, 96, 27, 27), (256, 96, 5, 5), (2, 2, 2, 2), (1, 1), (1, 1), False)):
        ([21, 160, 352, 32, 96, 31], {"bypass": 0})
}
-
-
def conv_set_dim_func(fmap_shape, filter_shape, pad, stride, dilation,
                      use_bias=False, block_size=16, attrs=None, setdim_map=None):
    """Set dim info in attrs by conv_set_dim_map.

    Resolves the tiling configuration for a conv workload — explicit
    attrs["conv_tile"], then the tuned setdim_map table, then a computed
    default — and renders it as a "dim" string for the poly scheduler.

    Args:
        fmap_shape (list[int]): feature map shape [N, C, H, W].
        filter_shape (list[int]): filter shape [N, C, H, W].
        pad: int or 1/4-element sequence -> [top, bottom, left, right].
        stride: int or 1/2-element sequence -> [stride_h, stride_w].
        dilation: int or 1/2-element sequence -> [dilation_h, dilation_w].
        use_bias (bool): part of the lookup key into setdim_map.
        block_size (int): cube block size (C0); 16 on this target.
        attrs (dict or None): may carry "conv_tile", "bypass", "dynamic",
            "partial_dynamic".
        setdim_map (dict or None): tuned tiling table keyed by
            str((fmap, filter, pad, stride, dilation, use_bias)),
            e.g. conv_set_dim_map.

    Returns:
        tuple: (dim string, tiles list, bypass flag, CI1 tile for dynamic shape).

    Raises:
        IndexError: if pad/stride/dilation have an unsupported length.
        ValueError: if the resolved bypass flag is not in {0, 1}.
    """
    # Normalize stride to [stride_h, stride_w].
    if isinstance(stride, int):
        stride = [stride] * 2
    elif isinstance(stride, (list, tuple)) and len(stride) == 1:
        stride = list(stride) * 2
    elif isinstance(stride, (list, tuple)) and len(stride) == 2:
        pass
    else:
        raise IndexError("stride para illegal !!!")

    # Normalize pad to [top, bottom, left, right].
    if isinstance(pad, int):
        pad = [pad] * 4
    elif isinstance(pad, (list, tuple)) and len(pad) == 1:
        pad = list(pad) * 4
    elif isinstance(pad, (list, tuple)) and len(pad) == 4:
        pass
    else:
        raise IndexError("pad para illegal !!!")

    # Normalize dilation to [dilation_h, dilation_w].
    if isinstance(dilation, int):
        dilation = [dilation] * 2
    elif isinstance(dilation, (list, tuple)) and len(dilation) == 1:
        dilation = list(dilation) * 2
    elif isinstance(dilation, (list, tuple)) and len(dilation) == 2:
        pass
    else:
        raise IndexError("dilation para illegal !!!")

    # Lookup key; must match the format of the conv_set_dim_map keys.
    key = [tuple(fmap_shape), tuple(filter_shape), tuple(pad), tuple(stride),
           tuple(dilation), use_bias]
    hash_key = str(tuple(key))

    # input shape (NCHW -> NC1HWC0): round C up to a multiple of block_size.
    in_n, in_c, in_h, in_w = fmap_shape
    in_c = (in_c + block_size - 1) // block_size * block_size

    # kernel shape (NCHW -> NC1HWC0 -> Fractal): round N and C up.
    k_n, k_c, k_h, k_w = filter_shape
    k_c = (k_c + block_size - 1) // block_size * block_size
    k_n = (k_n + block_size - 1) // block_size * block_size

    # padding (padding_top, padding_bottom, padding_left, padding_right)
    p_top, p_bottom, p_left, p_right = pad

    # stride (stride_h, stride_w)
    s_h, s_w = stride

    # dilation (dilation_h, dilation_w)
    d_h, d_w = dilation

    k_w_d = (k_w - 1) * d_w + 1  # effective (dilated) kernel width
    out_w = (in_w + p_left + p_right - k_w_d) // s_w + 1

    all_dynamic = 0  # kh kw pad stride
    partial_dynamic = 0  # fn fc1 fh fw wN wC
    if attrs is None:
        attrs = {}
    if attrs.get("dynamic"):
        all_dynamic = 1
    if attrs.get("partial_dynamic"):
        partial_dynamic = 1
    dynamic = partial_dynamic or all_dynamic

    bypass_list = [0, 1]
    bypass = 0 if dynamic else 1
    dynamic_tiling = 1 if attrs.get("dynamic") else 0  # tile size is a parameter
    use_autotiling = 1 if dynamic and not dynamic_tiling else 0

    dynamic_ci_c1 = 128

    # attrs is guaranteed to be a dict here, so no None check is needed.
    if "conv_tile" in attrs and len(attrs["conv_tile"]) >= 5:
        # Explicit caller-supplied tiling wins over the tuned table.
        use_autotiling = 0
        tiles = attrs["conv_tile"]
        tile_hh, tile_coco, tile_mm, tile_kk, tile_nn = tiles[:5]
        if len(tiles) > 5:
            tile_ww = tiles[5]
            if dynamic and not use_autotiling and len(tiles) == 7:
                dynamic_ci_c1 = tiles[6]
        else:
            tile_ww = (out_w - 1) * s_w + k_w_d
        if "bypass" in attrs:
            bypass = attrs["bypass"]
    elif setdim_map and hash_key in setdim_map:
        # BUGFIX: guard against setdim_map=None (the parameter default),
        # which previously raised TypeError on the `in` test.
        configs = setdim_map[hash_key]
        if isinstance(configs, tuple):
            tiles = configs[0]
            if "bypass" in configs[1]:
                bypass = configs[1]["bypass"]
        else:
            tiles = configs
        if len(tiles) > 5:
            tile_hh, tile_coco, tile_mm, tile_kk, tile_nn, tile_ww = tiles
        else:
            tile_hh, tile_coco, tile_mm, tile_kk, tile_nn = tiles
            tile_ww = (out_w - 1) * s_w + k_w_d
    else:
        # Fallback: smallest output-height cut whose last window still fits
        # within the top-padded input rows; GEMM tiles default to one block.
        win_cut_h = 1
        k_h_d = (k_h - 1) * d_h + 1
        win_h = (in_h + p_top + p_bottom - k_h_d) // s_h + 1
        if not dynamic:
            while win_cut_h <= win_h:
                if (((win_h + win_cut_h - 1) // win_cut_h - 1) * win_cut_h - 1) * s_h + k_h_d <= in_h + p_top:
                    break
                win_cut_h += 1
        tile_hh = (win_cut_h - 1) * s_h + k_h_d
        tile_ww = (out_w - 1) * s_w + k_w_d
        tile_coco = 16
        tile_mm = 16
        tile_kk = 16
        tile_nn = 16
        tiles = [tile_hh, tile_coco, tile_mm, tile_kk, tile_nn, tile_ww]
    if bypass not in bypass_list:
        # BUGFIX: join the list elements — the old code joined the characters
        # of str(bypass_list) — and fixed the "ony" typo.
        raise ValueError("conv_cce only supports %s while bypass is %d"
                         % (",".join(str(b) for b in bypass_list), bypass))

    # Full-height tile absorbs the vertical padding.
    if tile_hh == in_h:
        tile_hh += p_top + p_bottom
    # Round all cube tiles up to multiples of block_size.
    tile_coco = (tile_coco + block_size - 1) // block_size * block_size
    tile_mm = (tile_mm + block_size - 1) // block_size * block_size
    tile_kk = (tile_kk + block_size - 1) // block_size * block_size
    tile_nn = (tile_nn + block_size - 1) // block_size * block_size

    input_shape_nc1hwc0 = (in_n, in_c // block_size, in_h, in_w, block_size)
    in_n, in_c1, in_h, in_w, _ = input_shape_nc1hwc0

    kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    k_n, _, k_h, k_w, _ = kernel_shape_nc1hwc0

    # Output geometry and the per-tile output extents.
    k_h_d = (k_h - 1) * d_h + 1
    k_w_d = (k_w - 1) * d_w + 1
    out_h = (in_h + p_top + p_bottom - k_h_d) // s_h + 1
    tile_out_h = (tile_hh - k_h_d) // s_h + 1
    out_w = (in_w + p_left + p_right - k_w_d) // s_w + 1
    tile_out_w = (tile_ww - k_w_d) // s_w + 1

    out_shape_nc1hwc0 = (in_n, k_n // block_size, out_h, out_w, block_size)
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0

    # Co tile expressed in C1 (block_size-channel) units.
    if tile_coco > 0:
        c1_cut = tile_coco // block_size
    else:
        c1_cut = out_c1

    # set dim
    def gen_static_dim():
        """Dim spec for static shapes: only tile axes with extent > 1."""
        info = dim.Dim()
        if out_n > 1:
            info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
        if out_c1 > 1:
            info.setdim(index=0, axis=0, tilel1=c1_cut, tilel0=0)  # c1
        if out_h > 1:
            info.setdim(index=0, axis="H", tilel1=tile_out_h, tilel0=0)  # h
        if out_w > 1:
            info.setdim(index=0, axis="W", tilel1=tile_out_w, tilel0=0)  # w
        if out_c0 > 1:
            info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)  # c0

        if in_c1 > 1:
            info.setdim(index=0, axis=5, tilel1=in_c1, tilel0=0)  # kc1
        if k_h > 1:
            info.setdim(index=0, axis=5, tilel1=k_h, tilel0=0)  # kh
        if k_w > 1:
            info.setdim(index=0, axis=5, tilel1=k_w, tilel0=0)  # kw
        info.setdim(index=0, axis="KC0", tilel1=block_size, tilel0=0)  # kc0
        return info

    def gen_dynamic_dim():
        """Dim spec for dynamic shapes: fake tiles when fully parametric."""
        info = dim.Dim()
        if dynamic:
            info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
        elif out_n > 1:
            info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n

        if dynamic_tiling:
            info.setdim(index=0, axis=0, tilel1=c1_cut_fake, tilel0=0)  # c1
        elif dynamic or out_c1 > 1:
            info.setdim(index=0, axis=0, tilel1=c1_cut, tilel0=0)  # c1

        if dynamic_tiling:
            info.setdim(index=0, axis="H", tilel1=tile_out_h_fake, tilel0=0)  # h
        elif dynamic or out_h > 1:
            info.setdim(index=0, axis="H", tilel1=tile_out_h, tilel0=0)  # h

        if dynamic_tiling:
            info.setdim(index=0, axis="W", tilel1=tile_out_w_fake, tilel0=0)  # w
        elif dynamic or out_w > 1:
            info.setdim(index=0, axis="W", tilel1=tile_out_w, tilel0=0)  # w

        if dynamic or out_c0 > 1:
            info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)  # c0

        if dynamic and not use_autotiling:
            info.setdim(index=0, axis=5, tilel1=dynamic_ci_c1, tilel0=0)  # kc1
        elif dynamic or in_c1 > 1:
            info.setdim(index=0, axis=5, tilel1=in_c1, tilel0=0)  # kc1

        if dynamic or k_h > 1:
            info.setdim(index=0, axis=5, tilel1=k_h, tilel0=0)  # kh

        if dynamic or k_w > 1:
            info.setdim(index=0, axis=5, tilel1=k_w, tilel0=0)  # kw

        info.setdim(index=0, axis="KC0", tilel1=block_size, tilel0=0)  # kc0
        return info

    if dynamic:
        info = gen_dynamic_dim()
    else:
        info = gen_static_dim()
    tiling = str(info)
    if use_autotiling:
        # Let the auto-tiler decide: emit no dim string and use a fixed CI1
        # upper bound for the dynamic-shape hint.
        tiling = ""
        dynamic_ci_c1 = 215
    return tiling, tiles, bypass, dynamic_ci_c1
-
-
def conv_core(data, fmap_shape, filter_shape, pad, stride, dilation, use_bias=False, attrs=None):
    """core computation for op conv.

    Validates inputs, resolves the tiling configuration (explicit
    attrs["conv_tile"], the tuned conv_set_dim_map table, or a computed
    default), then emits the NC1HWC0 mmad compute with the pragma
    attributes the poly backend expects.

    Args:
        data (list[tvm.tensor.Tensor]): [feature map, filter] plus [bias]
            when use_bias; all float16 (fmap NC1HWC0, filter fractal,
            bias (1, Co // C0, 1, 1, C0)).
        fmap_shape (list[int]): feature map shape [N, C, H, W].
        filter_shape (list[int]): filter shape [N, C, H, W].
        pad: int or 1/4-element sequence -> [top, bottom, left, right].
        stride: int or 1/2-element sequence -> [stride_h, stride_w].
        dilation: int or 1/2-element sequence -> [dilation_h, dilation_w].
        use_bias (bool): whether data[2] is a bias tensor.
        attrs (dict or None): optional "conv_tile", "bypass", "dynamic",
            "partial_dynamic" controls.

    Returns:
        tvm.tensor.Tensor: float32 result, shape (N, Co // 16, Ho, Wo, 16).

    Raises:
        IndexError: wrong tensor count or malformed pad/stride/dilation.
        TypeError: wrong dtype or use_bias not a bool.
        ValueError: out-of-range values or inconsistent tensor shapes.
    """
    # --- tensor count / dtype checks ---------------------------------------
    if use_bias:
        if len(data) != 3:
            raise IndexError("data should contain 3 tensors, i.e. feature map, filter and bias")
        if data[2].dtype != "float16":
            raise TypeError("data type of bias should be float16")
    else:
        if len(data) != 2:
            raise IndexError("data should contain 2 tensors, i.e. feature map and filter")
    if data[0].dtype != "float16":
        raise TypeError("data type of feature map should be float16")
    if data[1].dtype != "float16":
        raise TypeError("data type of filter should be float16")
    if not isinstance(use_bias, bool):
        raise TypeError("use_bias should be set as False or True")

    # --- dynamic-shape mode flags ------------------------------------------
    all_dynamic = 0  # kh kw pad stride
    partial_dynamic = 0  # fn fc1 fh fw wN wC
    if attrs is None:
        attrs = {}
    if attrs.get("dynamic"):
        all_dynamic = 1
    if attrs.get("partial_dynamic"):
        partial_dynamic = 1
    dynamic = partial_dynamic or all_dynamic
    dynamic_tiling = 1 if attrs.get("dynamic") else 0
    use_autotiling = 1 if dynamic and not dynamic_tiling else 0
    block_size = 16

    # Shape checks need concrete .value extents, so static shapes only.
    if not dynamic:
        vc_util.convolution_format_check(fmap_shape, filter_shape, pad, stride, dilation)
        for tmp_data in data:
            shape = [x.value for x in tmp_data.shape]
            vc_util.check_shape(shape)
        vc_util.check_shape(fmap_shape)
        vc_util.check_shape(filter_shape)

    stride_len = 2
    pad_len = 4
    dilation_len = 2
    zero = 0
    max_s = 63   # hardware limit for stride values
    max_d = 255  # hardware limit for pad/dilation values

    # Normalize stride to [stride_h, stride_w].
    if isinstance(stride, int):
        stride = [stride] * stride_len
    elif isinstance(stride, (list, tuple)) and len(stride) == 1:  # only has one element
        stride = list(stride) * stride_len
    elif isinstance(stride, (list, tuple)) and len(stride) == stride_len:
        pass
    else:
        raise IndexError("stride para illegal !!!")

    if not dynamic:
        for val in stride:
            if val <= zero:
                raise ValueError("elements in stride should be greater than Zero !!!")
            if val > max_s:
                raise ValueError("elements in stride should be less than 64 !!!")

    # Normalize pad to [top, bottom, left, right].
    if isinstance(pad, int):
        pad = [pad] * pad_len
    elif isinstance(pad, (list, tuple)) and len(pad) == 1:  # only has one element
        pad = list(pad) * pad_len
    elif isinstance(pad, (list, tuple)) and len(pad) == pad_len:
        pass
    else:
        raise IndexError("pad para illegal !!!")

    if not dynamic:
        for val in pad:
            if val < zero:
                raise ValueError("elements in pad should not be less than Zero !!!")
            if val > max_d:
                raise ValueError("elements in pad should be less than 256 !!!")

    # Normalize dilation to [dilation_h, dilation_w].
    if isinstance(dilation, int):
        dilation = [dilation] * dilation_len
    elif isinstance(dilation, (list, tuple)) and len(dilation) == 1:  # only has one element
        dilation = list(dilation) * dilation_len
    elif isinstance(dilation, (list, tuple)) and len(dilation) == dilation_len:
        pass
    else:
        raise IndexError("dilation para illegal !!!")

    for val in dilation:
        if val <= zero:
            raise ValueError("elements in dilation should be greater than Zero !!!")
        if val > max_d:
            raise ValueError("elements in dilation should be less than 256 !!!")

    if len(stride) != stride_len or len(pad) != pad_len or len(dilation) != dilation_len:
        raise IndexError(" shape of parameters must be as expected")

    block_size_sub_one = block_size - 1
    # input shape (NCHW -> NC1HWC0): round C up to a multiple of block_size.
    in_n, in_c, in_h, in_w = fmap_shape
    in_c = (in_c + block_size_sub_one) // block_size * block_size

    # kernel shape (NCHW -> NC1HWC0 -> Fractal): round N and C up.
    k_n, k_c, k_h, k_w = filter_shape
    k_c = (k_c + block_size_sub_one) // block_size * block_size
    k_n = (k_n + block_size_sub_one) // block_size * block_size

    # padding(padding_top, padding_bottom, padding_left, padding_right)
    p_top, p_bottom, p_left, p_right = pad

    # stride (stride_h, stride_w)
    s_h, s_w = stride

    # Remember the real values before possibly swapping in the fake ones;
    # the real values are forwarded via the "pragma_conv_real_*" attrs.
    k_h_real = k_h
    k_w_real = k_w
    p_top_real = p_top
    p_bottom_real = p_bottom
    p_left_real = p_left
    p_right_real = p_right
    s_h_real = s_h
    s_w_real = s_w

    if dynamic_tiling:
        # Fully-parametric tiling: build the compute with placeholder values.
        k_h = k_h_fake
        k_w = k_w_fake
        p_top = p_top_fake
        p_bottom = p_bottom_fake
        p_left = p_left_fake
        p_right = p_right_fake
        s_h = s_h_fake
        s_w = s_w_fake

    # dilation (dilation_h, dilation_w)
    d_h, d_w = dilation

    # tiling lookup key; must match the conv_set_dim_map format.
    key = []
    key.append(tuple(fmap_shape))
    key.append(tuple(filter_shape))
    key.append(tuple(pad))
    key.append(tuple(stride))
    key.append(tuple(dilation))
    key.append(use_bias)

    hash_key = str(tuple(key))

    k_w_d = (k_w - 1) * d_w + 1  # effective (dilated) kernel width
    out_w = (in_w + p_left + p_right - k_w_d) // (s_w) + 1

    bypass_list = [0, 1]
    bypass = 0 if dynamic else 1

    # (NC1HWCO)
    a_value = data[0]

    # (fractal)
    b_value = data[1]
    setdim_map = conv_set_dim_map

    # --- tiling resolution: explicit attrs > tuned table > default ---------
    conv_tile_num = 5
    if attrs is not None and "conv_tile" in attrs and len(attrs["conv_tile"]) >= conv_tile_num:
        use_autotiling = 0
        tile_hh = attrs["conv_tile"][0]
        tile_coco = attrs["conv_tile"][1]
        tile_mm = attrs["conv_tile"][2]
        tile_kk = attrs["conv_tile"][3]
        tile_nn = attrs["conv_tile"][4]
        if len(attrs["conv_tile"]) > conv_tile_num:
            tile_ww = attrs["conv_tile"][conv_tile_num]
        else:
            tile_ww = (out_w - 1) * s_w + k_w_d
        if "bypass" in attrs:
            bypass = attrs["bypass"]
    elif hash_key in setdim_map:
        configs = setdim_map[hash_key]
        if isinstance(configs, tuple):
            tiles = configs[0]
            if "bypass" in configs[1]:
                bypass = configs[1]["bypass"]
        else:
            tiles = configs
        if len(tiles) > conv_tile_num:
            tile_hh, tile_coco, tile_mm, tile_kk, tile_nn, tile_ww = tiles
        else:
            tile_hh, tile_coco, tile_mm, tile_kk, tile_nn = tiles
            tile_ww = (out_w - 1) * s_w + k_w_d
    else:
        # Default: smallest output-height cut whose last window still fits
        # within the top-padded input rows; GEMM tiles default to one block.
        win_cut_h = 1
        k_h_d = (k_h - 1) * d_h + 1
        win_h = (in_h + p_top + p_bottom - k_h_d) // (s_h) + 1
        if not dynamic:
            while win_cut_h <= win_h:
                if (((win_h + win_cut_h - 1) // win_cut_h - 1) * win_cut_h - 1) * s_h + k_h_d <= in_h + p_top:
                    break
                win_cut_h += 1
        tile_hh = (win_cut_h - 1) * s_h + k_h_d
        tile_ww = (out_w - 1) * s_w + k_w_d
        tile_coco = block_size
        tile_mm = block_size
        tile_kk = block_size
        tile_nn = block_size
    if bypass not in bypass_list:
        # BUGFIX: join the list elements — the old code joined the characters
        # of str(bypass_list), producing "[,0,,, ,1,]".
        raise ValueError("bypass of conv only supports %s" % (",".join(str(b) for b in bypass_list)))

    # Full-height / full-width tiles absorb the padding.
    if tile_hh == in_h:
        tile_hh += p_top + p_bottom

    if tile_ww == in_w:
        tile_ww += p_left + p_right

    # Round all cube tiles up to multiples of block_size.
    tile_coco = (tile_coco + block_size_sub_one) // block_size * block_size
    tile_mm = (tile_mm + block_size_sub_one) // block_size * block_size
    tile_kk = (tile_kk + block_size_sub_one) // block_size * block_size
    tile_nn = (tile_nn + block_size_sub_one) // block_size * block_size

    # --- shape consistency checks ------------------------------------------
    input_shape_nc1hwc0 = get_shape(data[0])
    if not dynamic and input_shape_nc1hwc0 != [in_n, in_c // block_size, in_h, in_w, block_size]:
        raise ValueError("feature map tensor data[0] shape illegal !!!")
    in_n, c1_in, in_h, in_w, _ = input_shape_nc1hwc0

    if not dynamic:
        kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    else:
        kernel_shape_nc1hwc0 = (k_n, c1_in, k_h, k_w, block_size)  # simplify for dynamic case
    k_n, k_c1, k_h, k_w, k_c0 = kernel_shape_nc1hwc0
    kernel_shape_fractal = get_shape(data[1])
    if not dynamic and kernel_shape_fractal != [k_c1 * k_h * k_w, k_n // block_size, block_size, k_c0]:
        raise ValueError("filter tensor data[1] shape illegal !!!")

    if use_bias:
        bias_value = data[2]
        bias_name = bias_value.op.name
        bias_shape = [x.value for x in data[2].shape]
        if bias_shape != [1, k_n // block_size, 1, 1, block_size]:
            raise ValueError("bias tensor data[2] shape illegal !!!")
    else:
        bias_name = "None"
        bias_value = None

    # Create reduction variables (over C1, kernel H/W and C0).
    kc1 = akg.tvm.reduce_axis((0, k_c1), name="kc1")
    kh = akg.tvm.reduce_axis((0, k_h), name="kh")
    kw = akg.tvm.reduce_axis((0, k_w), name="kw")
    kc0 = akg.tvm.reduce_axis((0, k_c0), name="kc0")

    # Output geometry, both with the (possibly fake) and the real parameters.
    k_h_d = (k_h - 1) * d_h + 1
    k_h_d_real = (k_h_real - 1) * d_h + 1
    k_w_d = (k_w - 1) * d_w + 1
    k_w_d_real = (k_w_real - 1) * d_w + 1
    out_h = (in_h + p_top + p_bottom - k_h_d) // (s_h) + 1
    tile_out_h = (tile_hh - k_h_d) // s_h + 1
    tile_out_h_real = (tile_hh - k_h_d_real) // s_h_real + 1
    out_w = (in_w + p_left + p_right - k_w_d) // (s_w) + 1
    tile_out_w = (tile_ww - k_w_d) // s_w + 1
    tile_out_w_real = (tile_ww - k_w_d_real) // s_w_real + 1

    if not dynamic:
        out_shape_nc1hwc0 = (in_n, k_n // block_size, out_h, out_w, block_size)
    else:
        _, c1_out, _, _ = data[1].shape
        out_shape_nc1hwc0 = (in_n, c1_out, out_h, out_w, block_size)
    _, out_c1, out_h, out_w, _ = out_shape_nc1hwc0

    # Co tile expressed in C1 (block_size-channel) units.
    if tile_coco > 0:
        c1_cut = tile_coco // block_size
    else:
        c1_cut = out_c1

    # Compute the convolution
    output_name = "output0"
    conv_attr = {
        "pragma_conv_kernel_n": k_n,
        "pragma_conv_kernel_h": k_h,
        "pragma_conv_kernel_w": k_w,
        "pragma_conv_padding_top": p_top,
        "pragma_conv_padding_bottom": p_bottom,
        "pragma_conv_padding_left": p_left,
        "pragma_conv_padding_right": p_right,
        "pragma_conv_bypass_l1": bypass,
        "pragma_conv_stride_h": s_h,
        "pragma_conv_stride_w": s_w,
        "pragma_conv_dilation_h": d_h,
        "pragma_conv_dilation_w": d_w,
        "pragma_conv_fm_n": in_n,
        "pragma_conv_fm_c": in_c,
        "pragma_conv_fm_h": in_h,
        "pragma_conv_fm_w": in_w,
        "feature": a_value.op.name,
        "filter": b_value.op.name,
        "bias": bias_name,
        "res": output_name}

    if dynamic_tiling:
        # Fake cuts for the parametric build plus the real tile/parameter
        # values so the backend can substitute them at runtime.
        conv_attr["pragma_conv_h_cut"] = (tile_out_h_fake - 1) * s_h + k_h_d
        conv_attr["pragma_conv_w_cut"] = (tile_out_w_fake - 1) * s_w + k_w_d
        conv_attr["pragma_conv_co_cut"] = c1_cut_fake * 16
        conv_attr["pragma_conv_m_cut"] = m_cut_fake
        conv_attr["pragma_conv_k_cut"] = k_cut_fake
        conv_attr["pragma_conv_n_cut"] = n_cut_fake
        conv_attr["pragma_conv_tile_co"] = c1_cut
        conv_attr["pragma_conv_tile_ho"] = tile_out_h_real
        conv_attr["pragma_conv_tile_wo"] = tile_out_w_real
        conv_attr["pragma_conv_tile_mo"] = tile_mm // 16
        conv_attr["pragma_conv_tile_ko"] = tile_kk // 16
        conv_attr["pragma_conv_tile_no"] = tile_nn // 16
        conv_attr["pragma_conv_real_kh"] = k_h_real
        conv_attr["pragma_conv_real_kw"] = k_w_real
        conv_attr["pragma_conv_real_sh"] = s_h_real
        conv_attr["pragma_conv_real_sw"] = s_w_real
        conv_attr["pragma_conv_real_pt"] = p_top_real
        conv_attr["pragma_conv_real_pb"] = p_bottom_real
        conv_attr["pragma_conv_real_pl"] = p_left_real
        conv_attr["pragma_conv_real_pr"] = p_right_real
    elif not use_autotiling:
        conv_attr["pragma_conv_h_cut"] = (tile_out_h - 1) * s_h + k_h_d
        conv_attr["pragma_conv_w_cut"] = (tile_out_w - 1) * s_w + k_w_d
        conv_attr["pragma_conv_co_cut"] = c1_cut * k_c0
        conv_attr["pragma_conv_m_cut"] = tile_mm
        conv_attr["pragma_conv_k_cut"] = tile_kk
        conv_attr["pragma_conv_n_cut"] = tile_nn
    # mmad over (kc1, kh, kw, kc0); out-of-bounds taps read zero (padding).
    c_value = akg.tvm.compute(out_shape_nc1hwc0,
                              lambda n, c1, h, w, c0: akg.lang.cce.mmad(
                                  (akg.tvm.if_then_else(akg.tvm.any((h * s_h + kh) < p_top,
                                                                    (h * s_h + kh) > (in_h + p_top - 1),
                                                                    (w * s_w + kw) < p_left,
                                                                    (w * s_w + kw) > (in_w + p_left - 1)),
                                                        akg.tvm.const(0.0, "float16"),
                                                        a_value[n, kc1, (h * s_h + (kh * d_h) - p_top),
                                                                (w * s_w + (kw * d_w) - p_left), kc0])
                                   * b_value[(kc1 * k_h + kh) * k_w + kw, c1, c0, kc0]).astype("float32"),
                                  axis=[kc1, kh, kw, kc0]), name=output_name,
                              attrs=conv_attr)
    return c_value
-
@vc_util.check_input_type((list, tuple), (list, tuple), (list, tuple), (list, tuple), (list, tuple), (list, tuple),
                          (bool, type(None)), (dict, type(None)), (list, tuple, type(None)))
def conv(data, fmap_shape, filter_shape, pad, stride, dilation, use_bias=False, attrs=None, params=None):
    """
    Computes sums of 5-D convolutions.

    Args:
        data (list[tvm.tensor.Tensor]): the size is 3 if use_bias else the size is 2;
            data[0] Tensor of type float16, shape 5D (fN, fC // C0, fH, fW, C0)
            data[1] Tensor of type float16, shape 4D (wC // C0 * wH * wW, wN // C0, C0, C0)
            data[2] Tensor of type float16, shape 5D (1, wN // C0, 1, 1, 16)
        fmap_shape (list[int]): [fN, fC, fH, fW]
        filter_shape (list[int]): [wN, wC, wH, wW]
        pad (list[int]): [pad_top, pad_bottom, pad_left, pad_right]
        stride (list[int]): [stride_h, stride_w]
        dilation (list[int]): [dilation_h, dilation_w]
        use_bias (bool): bool var.
        attrs (dict): dict with keys for example: conv_tile,bypass
        params: unused; kept for interface compatibility.

    Returns:
        tuple: (tvm.tensor.Tensor of type float16, shape 5D (oN, oC // C0, oH, oW, C0),
        dict of build attributes for the poly backend).
    """
    # Raw float32 conv result, cast back to the fp16 working precision.
    c_value = conv_core(data, fmap_shape, filter_shape, pad, stride, dilation, use_bias, attrs)
    c_value = cast.cast(c_value, "float16")

    if use_bias:
        # Bias is broadcast over N, H and W (one value per output channel).
        bias_value = data[2]
        output_bias_name = "output1"
        cube = akg.tvm.compute(c_value.shape, lambda n, c1, h, w, c0: c_value[n, c1, h, w, c0] +
                               bias_value[0, c1, 0, 0, c0],
                               name=output_bias_name)
    else:
        cube = c_value

    block_size = 16
    dim_info, _, _, dynamic_ci_c1 = conv_set_dim_func(fmap_shape, filter_shape, pad, stride, dilation,
                                                      use_bias, block_size, attrs, conv_set_dim_map)

    # Re-derive the dynamic-shape flags before `attrs` is replaced below.
    all_dynamic = 0  # kh kw pad stride
    partial_dynamic = 0  # fn fc1 fh fw wN wC
    dynamic_tiling_full_dynamic = 1  # kh, kw, pad, stride are parameters if dynamic_tiling is enabled

    if attrs is None:
        attrs = {}
    if attrs.get("dynamic"):
        all_dynamic = 1
    if attrs.get("partial_dynamic"):
        partial_dynamic = 1
    dynamic = partial_dynamic or all_dynamic
    dynamic_tiling = 1 if attrs.get("dynamic") else 0

    # NOTE: the input attrs dict is replaced by the build attrs from here on.
    if not dynamic:
        attrs = {"dim": dim_info, "pragma_reschedule": 1, "pragma_rmselfdep": 0}
    else:
        attrs = {"dim": dim_info,
                 "pragma_reschedule": 1,
                 "pragma_rmselfdep": 0,
                 "enable_fix_loop_extent": 0,
                 "enable_post_poly_loop_partition": 0,
                 "enable_isolate_loop": 0,
                 "enable_isolate_min_max": 1,
                 "enable_conv_analyze_align": 0,
                 "enable_double_buffer": 1,
                 "enable_multicore": 1,
                 "enable_invariant_hoist": 1,
                 "pragma_keep_outer_band_order": 1,
                 "enable_algebra_simplify": 1,
                 "dynamic_shape_conv_full_parametric": dynamic_tiling and dynamic_tiling_full_dynamic,
                 }
    attrs["pragma_outerband_need_split"] = 1
    attrs["pragma_is_conv"] = 1
    if dynamic_tiling:
        # Fixed upper bound (128 + 1) on the CI1 axis for fully-dynamic tiling.
        attrs["dynamic_shape"] = set_poly_upper_bound_for_tensor(data[0], 129, 1)  # pos 1 of data[0] is CI1 axis
    else:
        # Bound derived from the resolved tiling (dynamic_ci_c1 + 1).
        attrs["dynamic_shape"] = set_poly_upper_bound_for_tensor(
            data[0], dynamic_ci_c1 + 1, 1)  # pos 1 of data[0] is CI1 axis
    if dynamic_tiling:
        attrs["pragma_tilesize_is_var"] = 1
        attrs["enable_stride_kernel_op"] = 0

    return cube, attrs
|