You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conv_input_ad.py 4.1 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """operator dsl function: conv_input_ad"""
  17. import akg.tvm
  18. import akg.topi
  19. import akg
  20. from akg.ops.nn import conv_backprop_input
  21. from akg.ops.nn import conv as conv_forward
  22. from akg.utils.format_transform import tvm_array_to_list
  23. from akg.utils import validation_check as vc_util
  24. def expr_to_int(in_expr):
  25. """Converte expr to int type value."""
  26. result = [a.value for a in in_expr]
  27. return result
  28. @akg.tvm.register_func("akg.autodiff.conv_input_ad_tensor")
  29. def conv_input_ad_tensor(data, fmap_shape, filter_shape, pad_, stride_, dilation_, attrs=None):
  30. """wraper of convolution filter backprop func."""
  31. data_list = tvm_array_to_list(data)
  32. fmap_shape = expr_to_int(fmap_shape)
  33. filter_shape = expr_to_int(filter_shape)
  34. pad_ = expr_to_int(pad_)
  35. stride_ = expr_to_int(stride_)
  36. dilation_ = expr_to_int(dilation_)
  37. c, _ = conv_backprop_input.conv_backprop_input(data_list, fmap_shape, filter_shape,
  38. pad_, stride_, dilation_, attrs=attrs)
  39. return c
  40. def conv_input_ad_config(data, fmap_shape, filter_shape, pad_, stride_, dilation_, attrs=None):
  41. """Configuration of convolution filter gradient."""
  42. _, configs = conv_backprop_input.conv_backprop_input(data, fmap_shape, filter_shape,
  43. pad_, stride_, dilation_, attrs=attrs)
  44. return configs
  45. @vc_util.check_input_type((list, tuple), (list, tuple), (list, tuple),
  46. (list, tuple), (list, tuple), (list, tuple), (dict, type(None)))
  47. def conv_input_ad(input_ad_inputs, fmap_shape, filter_shape, pad_, stride_, dilation_, attrs=None):
  48. """
  49. Compute dx according to "conv forward".
  50. Args:
  51. input_ad_inputs (list[tvm.tensor.Tensor]): a list with length 2.
  52. input_ad_inputs[0](consider as dy) Tensor of type float16 ,shape 5D(out_n, out_c//C0, out_h, out_w,C0)
  53. input_ad_inputs[1](consider as w) Tensor of type float16 ,shape 4D(wC//C0*wH*wW, wN//C0, C0,C0)
  54. fmap_shape (list): [fN, fC, fH, fW]
  55. filter_shape (list): [wN, wC, wH, wW]
  56. pad_ (list): [pad_left, pad_right, pad_top, pad_bottom]
  57. stride_ (list): [stride_h, stride_w]
  58. dilation_ (list): [dilation_h, dilation_w]
  59. attrs (dict): a dict with keys like conv_tile, bypass and etc.
  60. Returns:
  61. tvm.tensor.Tensor, configs.
  62. """
  63. backward_dy, forward_w = input_ad_inputs
  64. in_n, in_c, in_h, in_w = fmap_shape
  65. block_size = 16
  66. in_c = (in_c + block_size - 1) // block_size * block_size
  67. x_5d_shape = (in_n, in_c // block_size, in_h, in_w, block_size)
  68. forward_x = akg.tvm.placeholder(x_5d_shape, forward_w.dtype, "input_X")
  69. original_filter_shape = akg.tvm.placeholder(filter_shape, forward_w.dtype, "input_filter")
  70. forward_output, _ = conv_forward.conv([forward_x, forward_w], fmap_shape, filter_shape,
  71. pad_, stride_, dilation_, use_bias=False, attrs=attrs)
  72. ad_attrs = {"ad_conv_enable": 1, "ad_conv_reuse_conv": 0}
  73. jacs = list(akg.differentiate(forward_output, [forward_x], backward_dy, ad_attrs,
  74. [backward_dy, forward_w, original_filter_shape]))
  75. configs = conv_input_ad_config([backward_dy, forward_w], fmap_shape, filter_shape,
  76. pad_, stride_, dilation_, attrs=attrs)
  77. return jacs[0], configs