#!/usr/bin/env python3
# coding: utf-8
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: conv_backprop_input"""
import logging
import akg.tvm
import akg
import akg.lang.cce
from akg import dim
from akg.utils import validation_check as vc_util
from akg.ops.math import cast

# Hand-tuned tiling tables, keyed by gen_key(fmap_shape, filter_shape, pad_, stride_, dilation_).
# Each value is [tile_h, tile_co, tile_m, tile_k, tile_n] with an optional 6th element tile_w;
# conv_backprop_input_compute reads tile_value[0..5] in exactly that order.
conv_backprop_input_tiling_args = {
    str(((1, 1024, 14, 14), (2048, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [2, 32, 64, 96, 128],
    str(((1, 1024, 14, 14), (256, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 256, 208, 64, 112],
    str(((1, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [4, 128, 48, 352, 16, 14],
    str(((1, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [14, 512, 49, 32, 512],
    str(((1, 128, 28, 28), (128, 128, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [28, 128, 128, 144, 128],
    str(((1, 128, 28, 28), (512, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [28, 128, 784, 16, 32],
    str(((1, 128, 56, 56), (128, 128, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1))): [6, 32, 112, 160, 32, 58],
    str(((1, 16, 224, 224), (64, 16, 7, 7), (3, 3, 3, 3), (2, 2), (1, 1))): [10, 16, 16, 49 * 16, 16, 10],
    str(((1, 2048, 7, 7), (512, 2048, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [7, 512, 49, 32, 512],
    str(((1, 256, 13, 13), (384, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [13, 64, 80, 48, 16, 15],
    str(((1, 256, 14, 14), (1024, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 256, 112, 32, 240],
    str(((1, 256, 14, 14), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [14, 128, 196, 144, 128],
    str(((1, 256, 28, 28), (256, 256, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1))): [9, 16, 48, 448, 16, 30],
    str(((1, 256, 28, 28), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [14, 128, 196, 144, 128],
    str(((1, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [8, 128, 240, 128, 128, 56],
    str(((1, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [7, 128, 252, 64, 128],
    str(((1, 256, 56, 56), (512, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [3, 32, 32, 32, 32],
    str(((1, 256, 56, 56), (64, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [16, 64, 280, 16, 64],
    str(((1, 3, 224, 224), (64, 3, 7, 7), (3, 3, 3, 3), (2, 2), (1, 1))): [10, 16, 16, 49 * 16, 16, 10],
    str(((1, 384, 13, 13), (256, 384, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [12, 192, 16, 240, 96, 15],
    str(((1, 384, 13, 13), (384, 384, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [9, 128, 96, 176, 80, 15],
    str(((1, 512, 14, 14), (512, 512, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1))): [10, 16, 80, 64, 16, 16],
    str(((1, 512, 28, 28), (1024, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [2, 64, 112, 32, 512],
    str(((1, 512, 28, 28), (128, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 128, 448, 16, 64],
    str(((1, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [10, 256, 128, 32, 256, 28],
    str(((1, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [7, 256, 98, 64, 256],
    str(((1, 512, 7, 7), (2048, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [7, 128, 49, 256, 128],
    str(((1, 512, 7, 7), (512, 512, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [7, 64, 49, 144, 64],
    str(((1, 6, 14, 14), (16, 6, 5, 5), (0, 0, 0, 0), (1, 1), (1, 1))): [18, 16, 64, 240, 16, 18],
    str(((1, 64, 56, 56), (256, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 256, 784, 16, 32],
    str(((1, 64, 56, 56), (64, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [56, 64, 784, 16, 32],
    str(((1, 64, 56, 56), (64, 64, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [56, 64, 128, 144, 64],
    str(((1, 96, 28, 28), (256, 96, 5, 5), (2, 2, 2, 2), (1, 1), (1, 1))): [14, 48, 32, 384, 48, 32],
    str(((32, 1024, 14, 14), (2048, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [2, 224, 32, 32, 144, 14],
    str(((32, 1024, 14, 14), (256, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 224, 192, 64, 48, 14],
    str(((32, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [7, 352, 96, 80, 176, 14],
    str(((32, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [14, 512, 49, 32, 512],
    str(((32, 128, 28, 28), (128, 128, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [18, 64, 208, 144, 64, 30],
    str(((32, 128, 28, 28), (512, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [4, 384, 112, 16, 336, 28],
    str(((32, 128, 56, 56), (128, 128, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1))): [6, 112, 112, 144, 112, 58],
    str(((32, 16, 224, 224), (64, 16, 7, 7), (3, 3, 3, 3), (2, 2), (1, 1))): [10, 16, 16, 49 * 16, 16, 10],
    str(((32, 2048, 7, 7), (512, 2048, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [7, 32, 48, 272, 32, 7],
    str(((32, 256, 13, 13), (384, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [13, 64, 80, 48, 16, 15],
    str(((32, 256, 14, 14), (1024, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [2, 416, 32, 752, 16, 14],
    str(((32, 256, 14, 14), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [16, 112, 112, 144, 112, 14],
    str(((32, 256, 28, 28), (256, 256, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1))): [6, 144, 112, 144, 112, 30],
    str(((32, 256, 28, 28), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [14, 128, 196, 144, 128],
    str(((32, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [4, 128, 224, 64, 112, 56],
    str(((32, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [8, 128, 224, 96, 48, 56],
    str(((32, 256, 56, 56), (512, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [2, 288, 112, 144, 32, 56],
    str(((32, 256, 56, 56), (64, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [8, 64, 448, 64, 32, 56],
    str(((32, 3, 224, 224), (64, 3, 7, 7), (2, 3, 2, 3), (2, 2), (1, 1))): [13, 16, 16, 49 * 16, 16, 13],
    str(((32, 3, 224, 224), (64, 3, 7, 7), (3, 3, 3, 3), (2, 2), (1, 1))): [10, 16, 16, 49 * 16, 16, 10],
    str(((32, 3, 32, 32), (6, 3, 5, 5), (0, 0, 0, 0), (1, 1), (1, 1))): [16, 16, 16, 16, 16, 16],
    str(((32, 384, 13, 13), (256, 384, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [12, 192, 16, 240, 96, 15],
    str(((32, 384, 13, 13), (384, 384, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [9, 128, 96, 176, 80, 15],
    str(((32, 512, 14, 14), (512, 512, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1))): [10, 96, 112, 144, 48, 16],
    str(((32, 512, 28, 28), (1024, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [2, 336, 64, 80, 208, 28],
    str(((32, 512, 28, 28), (128, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [4, 64, 112, 80, 64, 28],
    str(((32, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [8, 224, 112, 64, 96, 28],
    str(((32, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [7, 192, 96, 48, 192, 28],
    str(((32, 512, 7, 7), (2048, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [7, 128, 49, 256, 128],
    str(((32, 512, 7, 7), (512, 512, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [9, 80, 64, 144, 80, 9],
    str(((32, 6, 14, 14), (16, 6, 5, 5), (0, 0, 0, 0), (1, 1), (1, 1))): [18, 16, 64, 240, 16, 18],
    str(((32, 64, 56, 56), (256, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [4, 224, 112, 224, 80, 56],
    str(((32, 64, 56, 56), (64, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [10, 64, 336, 16, 16, 56],
    str(((32, 64, 56, 56), (64, 64, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [15, 64, 112, 144, 64, 58],
    str(((32, 96, 27, 27), (256, 96, 5, 5), (2, 2, 2, 2), (1, 1), (1, 1))): [7, 32, 80, 48, 32, 31],
    str(((32, 96, 28, 28), (256, 96, 5, 5), (2, 2, 2, 2), (1, 1), (1, 1))): [14, 48, 32, 384, 48, 32],
}

# Tiling table used when the filter arrives as float32 and is cast to float16
# (conv_backprop_input_compute selects this table on data[1].dtype == 'float32').
# Same key/value layout as conv_backprop_input_tiling_args.
cast_tiling_args = {
    str(((1, 1024, 14, 14), (2048, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [2, 16, 64, 96, 128],
    str(((1, 1024, 14, 14), (256, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 256 // 2, 208, 64, 112],
    str(((1, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [14, 16, 50, 32, 512],
    str(((1, 128, 28, 28), (128, 128, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [28, 128 // 4, 128, 144, 128],
    str(((1, 128, 28, 28), (512, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [28, 128 // 2, 784, 16, 32],
    str(((1, 2048, 7, 7), (512, 2048, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [7, 512 // 8, 49, 32, 512],
    str(((1, 256, 14, 14), (1024, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 256 // 8, 112, 32, 240],
    str(((1, 256, 14, 14), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [14, 128 // 8, 196, 144, 128],
    str(((1, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [7, 128, 252, 64, 128],
    str(((1, 256, 56, 56), (64, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [16, 64, 280, 16, 64],
    str(((1, 3, 224, 224), (64, 3, 7, 7), (3, 3, 3, 3), (2, 2), (1, 1))): [10, 16, 16, 49 * 16, 16, 10],
    str(((1, 16, 224, 224), (64, 16, 7, 7), (3, 3, 3, 3), (2, 2), (1, 1))): [10, 16, 16, 49 * 16, 16, 10],
    str(((1, 512, 28, 28), (128, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 128, 448, 16, 64],
    str(((1, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [7, 256 // 4, 98, 64, 256],
    str(((1, 512, 7, 7), (2048, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [7, 128 // 8, 49, 256, 128],
    str(((1, 512, 7, 7), (512, 512, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [7, 64 // 4, 49, 144, 64],
    str(((1, 64, 56, 56), (256, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 256 // 8, 784, 16, 32],
    str(((1, 64, 56, 56), (64, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [56, 64, 784, 16, 32],
    str(((1, 64, 56, 56), (64, 64, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [56, 64, 128, 144, 64],
    str(((1, 256, 56, 56), (512, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [3, 16, 32, 32, 32],
    str(((1, 512, 28, 28), (1024, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [2, 16, 112, 32, 512],
}


def gen_key(fmap_shape, filter_shape, pad_, stride_, dilation_):
    """Build the string key used to look up the hand-tuned tiling tables.

    The key is the str() of a tuple of tuples, matching the literal keys in
    conv_backprop_input_tiling_args / cast_tiling_args above.
    """
    key = str((tuple(fmap_shape), tuple(filter_shape), tuple(pad_),
               tuple(stride_), tuple(dilation_)))
    return key


def conv_backprop_input_compute(data, output_shape, filter_shape, input_shape, pad_, stride_,
                                block_size=16, attrs=None, key=None):
    """Core computation of conv_backprop_input (dx = dy backprop through conv).

    Args:
        data (list/tuple): data[0] is dy (5D NC1HWC0 tensor), data[1] is the filter
            (fractal layout); a float32 filter is cast to float16 here.
        output_shape (tuple): NCHW shape of dy.
        filter_shape (tuple): NCHW shape of the filter.
        input_shape (tuple): NCHW shape of the forward-conv input (i.e. of dx).
        pad_ (tuple): forward-conv padding (pad_top, pad_bottom, pad_left, pad_right).
        stride_ (tuple): forward-conv strides (stride_h, stride_w); must be equal.
        block_size (int): C0 block size, default 16.
        attrs (dict): optional; attrs['conv_tile'] overrides the tiling tables.
        key (str): lookup key for the tiling tables (see gen_key).

    Returns:
        tuple: (result tensor cast to float16, dict of schedule configs).
    """
    _, in_c, w_h, w_w = filter_shape

    # stride (stride_h, stride_w)
    stride_h, stride_w = stride_
    if stride_h != stride_w:
        raise ValueError("stride_h must be equal to stride_w.")

    # output shape (NCHW -> NC1HWC0)
    in_nn, in_cc, in_hh, in_ww = output_shape
    if in_c % block_size != 0:
        raise ValueError("in_c must be divided by block_size.")
    input_shape_nc1hwc0 = (in_nn, in_cc // block_size, in_hh, in_ww, block_size)
    in_nn, _, in_hh, in_ww, _ = input_shape_nc1hwc0
    # dy is dilated by the forward stride before the transposed convolution
    input_trans_shape_nc1hwc0 = (in_nn, in_cc // block_size, in_hh * stride_h, in_ww * stride_w, block_size)
    in_n, in_c1, in_h, in_w, _ = input_trans_shape_nc1hwc0

    # kernel shape (NCHW -> NC1HWC0 -> Fractal)
    k_n, k_c, k_h, k_w = filter_shape
    if k_c % block_size != 0:
        raise ValueError("k_c must be divided by block_size.")
    kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    k_n, k_c1, k_h, k_w, k_c0 = kernel_shape_nc1hwc0
    kernel_shape_trans = (k_n // block_size * k_h * k_w, k_c // block_size, block_size, block_size)
    # For backprop the roles of N and C swap: the backprop "kernel N" is the
    # forward kernel's C, and vice versa.
    k_c1 = k_n // block_size
    k_n = k_c

    _, _, input_h, input_w = input_shape

    # padding ((padding_h, padding_w) -> (padding_top, padding_bottom, padding_left, padding_right))
    padding = (pad_[0], pad_[1], pad_[2], pad_[3])
    pad_t, pad_b, pad_l, pad_r = padding
    # Translate the forward-conv padding into the equivalent padding of the
    # transposed convolution (padHT -> padHT', etc.).
    p_top = k_h - pad_t - 1
    p_bottom = input_h + pad_t - stride_h * ((input_h + pad_t + pad_b - k_h) // stride_h + 1)
    p_left = k_w - pad_l - 1
    p_right = input_w + pad_l - stride_w * ((input_w + pad_l + pad_r - k_w) // stride_w + 1)
    # The transposed convolution itself always runs with stride 1.
    s_h = 1
    s_w = 1

    # NC1HWCO
    a_value = data[0]
    if data[1].dtype == 'float32':
        b_value = cast.cast(data[1], 'float16')
        tiling_args = cast_tiling_args
    else:
        b_value = data[1]
        tiling_args = conv_backprop_input_tiling_args

    # Create reduction variables
    kc1 = akg.tvm.reduce_axis((0, k_c1), name='kc1')
    kh = akg.tvm.reduce_axis((0, k_h), name='kh')
    kw = akg.tvm.reduce_axis((0, k_w), name='kw')
    kc0 = akg.tvm.reduce_axis((0, k_c0), name='kc0')

    # Tiling selection priority: explicit attrs['conv_tile'] > tuned table > auto tiling.
    use_auto_tiling = False
    if attrs is not None and 'conv_tile' in attrs and len(attrs['conv_tile']) >= 5:
        tile_value = attrs['conv_tile']
    elif key in tiling_args:
        tile_value = tiling_args[key]
    else:
        use_auto_tiling = True

    out_h = (in_h + p_top + p_bottom - k_h) // s_h + 1
    out_w = (in_w + p_left + p_right - k_w) // s_w + 1
    out_shape_nc1hwc0 = (in_n, k_n // block_size, out_h, out_w, block_size)
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0

    # set dim
    info = dim.Dim()
    index_ = 0
    if not use_auto_tiling:
        tile_hh = tile_value[0]
        if tile_hh == input_h:
            tile_hh += pad_t + pad_b
        # Round co / m / n cuts up to multiples of block_size.
        tile_coco = tile_value[1]
        tile_coco = (tile_coco + block_size - 1) // block_size * block_size
        tile_mm = tile_value[2]
        tile_mm = (tile_mm + block_size - 1) // block_size * block_size
        tile_kk = tile_value[3]
        if tile_kk % (block_size * w_h * w_w) != 0:
            logging.warning("Warning: tile_k must be a multiple of (block_size * w_h * w_w)")
            tile_kk = (tile_kk + block_size * w_h * w_w - 1) // (block_size * w_h * w_w) * (block_size * w_h * w_w)
        tile_nn = tile_value[4]
        tile_nn = (tile_nn + block_size - 1) // block_size * block_size
        tile_ww = input_w
        if len(tile_value) >= 6 and tile_value[5] > 0:
            tile_ww = tile_value[5]
        if tile_ww == input_w:
            tile_ww += pad_l + pad_r
        # A full-height/width tile absorbs the transposed-conv padding as well.
        if tile_hh == in_h:
            tile_hh += p_top + p_bottom
        tile_out_h = (tile_hh - k_h) // s_h + 1
        if tile_ww == in_w:
            tile_ww += p_left + p_right
        tile_out_w = (tile_ww - k_w) // s_w + 1

        if tile_coco > 0:
            c1_cut = tile_coco // block_size
        else:
            c1_cut = out_c1

        if out_n > 1:
            info.setdim(index=index_, axis=0, tilel1=1, tilel0=0)  # n
        if out_c1 > 1:
            info.setdim(index=index_, axis=1, tilel1=c1_cut, tilel0=0)  # c1
        if out_h > 1:
            info.setdim(index=index_, axis="H", tilel1=tile_out_h, tilel0=0)  # h
        if out_w > 1:
            info.setdim(index=index_, axis="W", tilel1=tile_out_w, tilel0=0)  # w
        if out_c0 > 1:
            info.setdim(index=index_, axis=4, tilel1=out_c0, tilel0=0)  # c0
        if in_c1 > 1:
            info.setdim(index=index_, axis=5, tilel1=in_c1, tilel0=0)  # kc1
        if k_h > 1:
            info.setdim(index=index_, axis=5, tilel1=k_h, tilel0=0)  # kh
        if k_w > 1:
            info.setdim(index=index_, axis=5, tilel1=k_w, tilel0=0)  # kw
        info = str(info)
    else:
        info = ""

    # Compute the convolution below
    output_name = "output0"

    # weight_trans [ ko, no, ni, ki ]
    # weight_trans [ co_1, kh, kw, ci_1, ci_0, co_0 ]
    # kw = ko % k_w
    # kh = ko // k_w % k_h
    # co_1 = ko // k_w // k_h
    # ci_1 = no
    # -->
    # weight [ ci_1, kh', kw', co_1, co_0, ci_0 ]
    # weight [ no, k_h - ko // k_w % k_h - 1, k_w - ko % k_w - 1, ko // k_w // k_h, co_0, ci_0 ]
    # i.e. the filter is rotated 180 degrees spatially and its N/C roles swapped.
    b_trans = akg.tvm.compute(kernel_shape_trans,
                              lambda ko, no, ni, ki:
                              b_value[((no * k_h + k_h - 1 - ko // k_w % k_h) * k_w + k_w - 1 - ko % k_w),
                                      ko // (k_h * k_w), ki, ni],
                              name='B_trans')

    if (stride_h > 1) or (stride_w > 1):
        @akg.tvm.hybrid.script
        def data_trans_hybrid(output, inputs, const_zero):
            """Implements data_trans ( B[n, c1, h * strideH, w * strideW, c0] = A[n, c1, h, w, c0] )."""
            stride_h = output.shape[2] // inputs.shape[2]
            stride_w = output.shape[3] // inputs.shape[3]
            b = allocate(output.shape, output.dtype, 'local')
            for n in range(output.shape[0]):
                for c1 in range(output.shape[1]):
                    for h in range(output.shape[2]):
                        for w in range(output.shape[3]):
                            for c0 in range(output.shape[4]):
                                b[n, c1, h, w, c0] = const_zero
                                if h % stride_h == 0 and w % stride_w == 0:
                                    b[n, c1, h, w, c0] = inputs[n, c1, h // stride_h, w // stride_w, c0]
            return b

        a_trans_init = akg.tvm.placeholder(input_trans_shape_nc1hwc0, dtype="float16", name='a_trans')
        const_zero = akg.tvm.const(0, 'float16')
        a_trans = data_trans_hybrid(a_trans_init, a_value, const_zero)
    else:
        a_trans = a_value

    conv_attrs = {
        "pragma_conv_kernel_n": k_n,
        "pragma_conv_kernel_h": k_h,
        "pragma_conv_kernel_w": k_w,
        "pragma_conv_padding_top": p_top,
        "pragma_conv_padding_bottom": p_bottom,
        "pragma_conv_padding_left": p_left,
        "pragma_conv_padding_right": p_right,
        "pragma_conv_bypass_l1": 0,
        "pragma_conv_backprop_input": 1,
        "pragma_conv_stride_h": s_h,
        "pragma_conv_stride_w": s_w,
        "pragma_conv_dilation_h": 1,
        "pragma_conv_dilation_w": 1,
        "pragma_conv_fm_n": in_n,
        "pragma_conv_fm_c": in_c,
        "pragma_conv_fm_h": in_h,
        "pragma_conv_fm_w": in_w,
        "feature": a_trans.op.name,
        "filter": b_value.op.name,
        "bias": 'None',
        "res": output_name}
    if not use_auto_tiling:
        conv_attrs["pragma_conv_h_cut"] = (tile_out_h - 1) * s_h + k_h
        conv_attrs["pragma_conv_w_cut"] = (tile_out_w - 1) * s_w + k_w
        conv_attrs["pragma_conv_co_cut"] = c1_cut * k_c0
        conv_attrs["pragma_conv_m_cut"] = tile_mm
        conv_attrs["pragma_conv_k_cut"] = tile_kk
        conv_attrs["pragma_conv_n_cut"] = tile_nn
    # Stride-1 convolution of the (dilated) dy with the rotated filter; padding
    # is handled by zero-filling out-of-range taps inside the reduction.
    res_c = akg.tvm.compute(out_shape_nc1hwc0,
                            lambda n, c1, h, w, c0: akg.lang.cce.mmad(
                                (akg.tvm.if_then_else(
                                    akg.tvm.any((h * s_h + kh) < p_top,
                                                (h * s_h + kh) > (in_h + p_top - 1),
                                                (w * s_w + kw) < p_left,
                                                (w * s_w + kw) > (in_w + p_left - 1)),
                                    akg.tvm.const(0.0, 'float16'),
                                    a_trans[n, kc1, (h * s_h + kh - p_top), (w * s_w + kw - p_left), kc0])
                                 * b_trans[(kc1 * k_h + kh) * k_w + kw, c1, c0, kc0]).astype("float32"),
                                axis=[kc1, kh, kw, kc0]),
                            name=output_name,
                            attrs=conv_attrs)
    # Accumulation is float32; the delivered result is float16.
    res_c = cast.cast(res_c, "float16")

    return res_c, {"dim": info, "pragma_reschedule": 1, "pragma_rmselfdep": 0}


@vc_util.check_input_type((list, tuple), (list, tuple), (list, tuple), (list, tuple), (list, tuple),
                          (list, tuple), (dict, type(None)))
def conv_backprop_input(data, fmap_shape, filter_shape, pad_, stride_, dilation_, attrs=None):
    """
    Computes dx according "conv forward".

    Args:
        data (list[tvm.tensor.Tensor]): a list with length 2.
              data[0](consider as dy) Tensor of type float16 ,shape 5D(out_n, out_c//C0, out_h, out_w,C0)
              data[1](consider as w)  Tensor of type float16 ,shape 4D(wC//C0*wH*wW, wN//C0, C0,C0)
        fmap_shape (list[int]): [fN, fC, fH, fW]
        filter_shape (list[int]): [wN, wC, wH, wW]
        pad_ (list[int]): [pad_top, pad_bottom, pad_left, pad_right]
        stride_ (list[int]): [stride_h, stride_w]
        dilation_ (list[int]): [dilation_h, dilation_w], only (1, 1) is supported
        attrs (dict): a dict with keys like conv_tile,bypass and etc.

    Returns:
        tvm.tensor.Tensor.
        configs.

    Raises:
        IndexError: if data does not hold exactly the dy and filter tensors.
        ValueError: if dilation_ contains a value other than 1.
    """
    if len(data) != 2:
        raise IndexError("data contains output tensor and filter tensor")

    vc_util.convolution_format_check(fmap_shape, filter_shape, pad_, stride_, dilation_)

    block_size = 16
    in_n, in_c, in_h, in_w = fmap_shape
    cout, _, w_h, w_w = filter_shape
    # Channel counts are rounded up to whole C0 blocks.
    in_c = (in_c + block_size - 1) // block_size * block_size
    cout = (cout + block_size - 1) // block_size * block_size

    pad_top, pad_bottom, pad_left, pad_right = pad_
    stride_h, stride_w = stride_
    dilation_h, dilation_w = dilation_
    if dilation_h != 1 or dilation_w != 1:
        raise ValueError("The value of elements in dilation_ must be 1.")

    # Shape of dy, derived from the forward convolution.
    out_n = in_n
    out_c = cout
    out_h = (in_h + pad_top + pad_bottom - w_h) // stride_h + 1
    out_w = (in_w + pad_left + pad_right - w_w) // stride_w + 1

    x_shape = (out_n, out_c, out_h, out_w)
    w_shape = (cout, in_c, w_h, w_w)

    key = gen_key(fmap_shape, filter_shape, pad_, stride_, dilation_)
    res_c, configs = conv_backprop_input_compute(data, x_shape, w_shape, fmap_shape, pad_, stride_,
                                                 block_size=block_size, attrs=attrs, key=key)

    return res_c, configs