|
- #!/usr/bin/env python3
- # coding: utf-8
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """operator dsl function: conv_input_ad"""
- import akg.tvm
- import akg.topi
- import akg
- from akg.ops.nn import conv_backprop_input
- from akg.ops.nn import conv as conv_forward
- from akg.utils.format_transform import tvm_array_to_list
- from akg.utils import validation_check as vc_util
-
-
def expr_to_int(in_expr):
    """Convert an iterable of expressions to a list of plain int values.

    Args:
        in_expr: iterable of expression objects, each exposing a ``.value``
            attribute (e.g. tvm IntImm).

    Returns:
        list: the ``.value`` of each element, in the original order.
    """
    return [expr.value for expr in in_expr]
-
-
@akg.tvm.register_func("akg.autodiff.conv_input_ad_tensor")
def conv_input_ad_tensor(data, fmap_shape, filter_shape, pad_, stride_, dilation_, attrs=None):
    """Wrapper of the convolution *input* backprop func (computes dx, not dw).

    Registered with tvm so the autodiff engine can invoke it by name.
    Converts the tvm Array/expr arguments into plain Python containers
    before delegating to ``conv_backprop_input``.

    Args:
        data (tvm.Array): input tensors forwarded to conv_backprop_input.
        fmap_shape: expr sequence, forward feature-map shape.
        filter_shape: expr sequence, forward filter shape.
        pad_: expr sequence, paddings.
        stride_: expr sequence, strides.
        dilation_: expr sequence, dilations.
        attrs (dict, optional): extra build attributes, passed through.

    Returns:
        tvm.tensor.Tensor: the input-gradient (dx) tensor; the configs
        returned by conv_backprop_input are discarded here.
    """
    data_list = tvm_array_to_list(data)
    # Normalize all shape/param expressions to plain ints in one pass.
    fmap_shape, filter_shape, pad_, stride_, dilation_ = (
        expr_to_int(arg) for arg in (fmap_shape, filter_shape, pad_, stride_, dilation_))
    c, _ = conv_backprop_input.conv_backprop_input(data_list, fmap_shape, filter_shape,
                                                   pad_, stride_, dilation_, attrs=attrs)
    return c
-
-
def conv_input_ad_config(data, fmap_shape, filter_shape, pad_, stride_, dilation_, attrs=None):
    """Return only the build/tiling configs of the conv input-gradient op.

    Delegates to ``conv_backprop_input`` and discards the computed tensor,
    keeping just the second element of its result.
    """
    result = conv_backprop_input.conv_backprop_input(
        data, fmap_shape, filter_shape, pad_, stride_, dilation_, attrs=attrs)
    return result[1]
-
-
@vc_util.check_input_type((list, tuple), (list, tuple), (list, tuple),
                          (list, tuple), (list, tuple), (list, tuple), (dict, type(None)))
def conv_input_ad(input_ad_inputs, fmap_shape, filter_shape, pad_, stride_, dilation_, attrs=None):
    """
    Compute dx according to "conv forward".

    Args:
        input_ad_inputs (list[tvm.tensor.Tensor]): a list with length 2.
            input_ad_inputs[0] (consider as dy): float16 Tensor,
                5D shape (out_n, out_c//C0, out_h, out_w, C0).
            input_ad_inputs[1] (consider as w): float16 Tensor,
                4D shape (wC//C0*wH*wW, wN//C0, C0, C0).
        fmap_shape (list): [fN, fC, fH, fW]
        filter_shape (list): [wN, wC, wH, wW]
        pad_ (list): [pad_left, pad_right, pad_top, pad_bottom]
        stride_ (list): [stride_h, stride_w]
        dilation_ (list): [dilation_h, dilation_w]
        attrs (dict): a dict with keys like conv_tile, bypass and etc.

    Returns:
        tvm.tensor.Tensor, configs.
    """
    dy, weight = input_ad_inputs

    # Round the channel dim up to a multiple of the cube block (16) and
    # build the 5D layout shape of the forward input placeholder.
    batch, channel, height, width = fmap_shape
    block_size = 16
    channel = (channel + block_size - 1) // block_size * block_size
    x_shape_5d = (batch, channel // block_size, height, width, block_size)

    x_placeholder = akg.tvm.placeholder(x_shape_5d, weight.dtype, "input_X")
    filter_placeholder = akg.tvm.placeholder(filter_shape, weight.dtype, "input_filter")

    # Rebuild the forward conv graph so autodiff can differentiate it w.r.t. x.
    forward_out, _ = conv_forward.conv([x_placeholder, weight], fmap_shape, filter_shape,
                                       pad_, stride_, dilation_, use_bias=False, attrs=attrs)

    ad_attrs = {"ad_conv_enable": 1, "ad_conv_reuse_conv": 0}
    grads = list(akg.differentiate(forward_out, [x_placeholder], dy, ad_attrs,
                                   [dy, weight, filter_placeholder]))
    configs = conv_input_ad_config([dy, weight], fmap_shape, filter_shape,
                                   pad_, stride_, dilation_, attrs=attrs)

    return grads[0], configs
|