You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

conv_backprop_input.py 22 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """operator dsl function: conv_backprop_input"""
  17. import logging
  18. import akg.tvm
  19. import akg
  20. import akg.lang.cce
  21. from akg import dim
  22. from akg.utils import validation_check as vc_util
  23. from akg.ops.math import cast
# Hand-tuned tiling table for conv_backprop_input.
# Key: the string produced by gen_key(fmap_shape, filter_shape, pad_, stride_, dilation_),
# i.e. str of ((N, C, H, W) fmap, (N, C, H, W) filter, pad (top, bottom, left, right),
# (stride_h, stride_w), (dilation_h, dilation_w)).
# Value: tile sizes consumed by conv_backprop_input_compute as
# [tile_h, tile_co, tile_m, tile_k, tile_n] with an optional trailing tile_w.
conv_backprop_input_tiling_args = {
    str(((1, 1024, 14, 14), (2048, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [2, 32, 64, 96, 128],
    str(((1, 1024, 14, 14), (256, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 256, 208, 64, 112],
    str(((1, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [4, 128, 48, 352, 16, 14],
    str(((1, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [14, 512, 49, 32, 512],
    str(((1, 128, 28, 28), (128, 128, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [28, 128, 128, 144, 128],
    str(((1, 128, 28, 28), (512, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [28, 128, 784, 16, 32],
    str(((1, 128, 56, 56), (128, 128, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1))): [6, 32, 112, 160, 32, 58],
    str(((1, 16, 224, 224), (64, 16, 7, 7), (3, 3, 3, 3), (2, 2), (1, 1))): [10, 16, 16, 49 * 16, 16, 10],
    str(((1, 2048, 7, 7), (512, 2048, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [7, 512, 49, 32, 512],
    str(((1, 256, 13, 13), (384, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [13, 64, 80, 48, 16, 15],
    str(((1, 256, 14, 14), (1024, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 256, 112, 32, 240],
    str(((1, 256, 14, 14), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [14, 128, 196, 144, 128],
    str(((1, 256, 28, 28), (256, 256, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1))): [9, 16, 48, 448, 16, 30],
    str(((1, 256, 28, 28), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [14, 128, 196, 144, 128],
    str(((1, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [8, 128, 240, 128, 128, 56],
    str(((1, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [7, 128, 252, 64, 128],
    str(((1, 256, 56, 56), (512, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [3, 32, 32, 32, 32],
    str(((1, 256, 56, 56), (64, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [16, 64, 280, 16, 64],
    str(((1, 3, 224, 224), (64, 3, 7, 7), (3, 3, 3, 3), (2, 2), (1, 1))): [10, 16, 16, 49 * 16, 16, 10],
    str(((1, 384, 13, 13), (256, 384, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [12, 192, 16, 240, 96, 15],
    str(((1, 384, 13, 13), (384, 384, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [9, 128, 96, 176, 80, 15],
    str(((1, 512, 14, 14), (512, 512, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1))): [10, 16, 80, 64, 16, 16],
    str(((1, 512, 28, 28), (1024, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [2, 64, 112, 32, 512],
    str(((1, 512, 28, 28), (128, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 128, 448, 16, 64],
    str(((1, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [10, 256, 128, 32, 256, 28],
    str(((1, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [7, 256, 98, 64, 256],
    str(((1, 512, 7, 7), (2048, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [7, 128, 49, 256, 128],
    str(((1, 512, 7, 7), (512, 512, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [7, 64, 49, 144, 64],
    str(((1, 6, 14, 14), (16, 6, 5, 5), (0, 0, 0, 0), (1, 1), (1, 1))): [18, 16, 64, 240, 16, 18],
    str(((1, 64, 56, 56), (256, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 256, 784, 16, 32],
    str(((1, 64, 56, 56), (64, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [56, 64, 784, 16, 32],
    str(((1, 64, 56, 56), (64, 64, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [56, 64, 128, 144, 64],
    str(((1, 96, 28, 28), (256, 96, 5, 5), (2, 2, 2, 2), (1, 1), (1, 1))): [14, 48, 32, 384, 48, 32],
    str(((32, 1024, 14, 14), (2048, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [2, 224, 32, 32, 144, 14],
    str(((32, 1024, 14, 14), (256, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 224, 192, 64, 48, 14],
    str(((32, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [7, 352, 96, 80, 176, 14],
    str(((32, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [14, 512, 49, 32, 512],
    str(((32, 128, 28, 28), (128, 128, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [18, 64, 208, 144, 64, 30],
    str(((32, 128, 28, 28), (512, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [4, 384, 112, 16, 336, 28],
    str(((32, 128, 56, 56), (128, 128, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1))): [6, 112, 112, 144, 112, 58],
    str(((32, 16, 224, 224), (64, 16, 7, 7), (3, 3, 3, 3), (2, 2), (1, 1))): [10, 16, 16, 49 * 16, 16, 10],
    str(((32, 2048, 7, 7), (512, 2048, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [7, 32, 48, 272, 32, 7],
    str(((32, 256, 13, 13), (384, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [13, 64, 80, 48, 16, 15],
    str(((32, 256, 14, 14), (1024, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [2, 416, 32, 752, 16, 14],
    str(((32, 256, 14, 14), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [16, 112, 112, 144, 112, 14],
    str(((32, 256, 28, 28), (256, 256, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1))): [6, 144, 112, 144, 112, 30],
    str(((32, 256, 28, 28), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [14, 128, 196, 144, 128],
    str(((32, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [4, 128, 224, 64, 112, 56],
    str(((32, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [8, 128, 224, 96, 48, 56],
    str(((32, 256, 56, 56), (512, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [2, 288, 112, 144, 32, 56],
    str(((32, 256, 56, 56), (64, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [8, 64, 448, 64, 32, 56],
    str(((32, 3, 224, 224), (64, 3, 7, 7), (2, 3, 2, 3), (2, 2), (1, 1))): [13, 16, 16, 49 * 16, 16, 13],
    str(((32, 3, 224, 224), (64, 3, 7, 7), (3, 3, 3, 3), (2, 2), (1, 1))): [10, 16, 16, 49 * 16, 16, 10],
    str(((32, 3, 32, 32), (6, 3, 5, 5), (0, 0, 0, 0), (1, 1), (1, 1))): [16, 16, 16, 16, 16, 16],
    str(((32, 384, 13, 13), (256, 384, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [12, 192, 16, 240, 96, 15],
    str(((32, 384, 13, 13), (384, 384, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [9, 128, 96, 176, 80, 15],
    str(((32, 512, 14, 14), (512, 512, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1))): [10, 96, 112, 144, 48, 16],
    str(((32, 512, 28, 28), (1024, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [2, 336, 64, 80, 208, 28],
    str(((32, 512, 28, 28), (128, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [4, 64, 112, 80, 64, 28],
    str(((32, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [8, 224, 112, 64, 96, 28],
    str(((32, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [7, 192, 96, 48, 192, 28],
    str(((32, 512, 7, 7), (2048, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [7, 128, 49, 256, 128],
    str(((32, 512, 7, 7), (512, 512, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [9, 80, 64, 144, 80, 9],
    str(((32, 6, 14, 14), (16, 6, 5, 5), (0, 0, 0, 0), (1, 1), (1, 1))): [18, 16, 64, 240, 16, 18],
    str(((32, 64, 56, 56), (256, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [4, 224, 112, 224, 80, 56],
    str(((32, 64, 56, 56), (64, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [10, 64, 336, 16, 16, 56],
    str(((32, 64, 56, 56), (64, 64, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [15, 64, 112, 144, 64, 58],
    str(((32, 96, 27, 27), (256, 96, 5, 5), (2, 2, 2, 2), (1, 1), (1, 1))): [7, 32, 80, 48, 32, 31],
    str(((32, 96, 28, 28), (256, 96, 5, 5), (2, 2, 2, 2), (1, 1), (1, 1))): [14, 48, 32, 384, 48, 32],
}
# Tiling table used when the weight tensor arrives as float32 and is cast to
# float16 inside conv_backprop_input_compute; same key/value layout as
# conv_backprop_input_tiling_args.
cast_tiling_args = {
    str(((1, 1024, 14, 14), (2048, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [2, 16, 64, 96, 128],
    str(((1, 1024, 14, 14), (256, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 256 // 2, 208, 64, 112],
    str(((1, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [14, 16, 50, 32, 512],
    str(((1, 128, 28, 28), (128, 128, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [28, 128 // 4, 128, 144, 128],
    str(((1, 128, 28, 28), (512, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [28, 128 // 2, 784, 16, 32],
    str(((1, 2048, 7, 7), (512, 2048, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [7, 512 // 8, 49, 32, 512],
    str(((1, 256, 14, 14), (1024, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 256 // 8, 112, 32, 240],
    str(((1, 256, 14, 14), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [14, 128 // 8, 196, 144, 128],
    str(((1, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [7, 128, 252, 64, 128],
    str(((1, 256, 56, 56), (64, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [16, 64, 280, 16, 64],
    str(((1, 3, 224, 224), (64, 3, 7, 7), (3, 3, 3, 3), (2, 2), (1, 1))): [10, 16, 16, 49 * 16, 16, 10],
    str(((1, 16, 224, 224), (64, 16, 7, 7), (3, 3, 3, 3), (2, 2), (1, 1))): [10, 16, 16, 49 * 16, 16, 10],
    str(((1, 512, 28, 28), (128, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 128, 448, 16, 64],
    str(((1, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [7, 256 // 4, 98, 64, 256],
    str(((1, 512, 7, 7), (2048, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [7, 128 // 8, 49, 256, 128],
    str(((1, 512, 7, 7), (512, 512, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [7, 64 // 4, 49, 144, 64],
    str(((1, 64, 56, 56), (256, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [14, 256 // 8, 784, 16, 32],
    str(((1, 64, 56, 56), (64, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1))): [56, 64, 784, 16, 32],
    str(((1, 64, 56, 56), (64, 64, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1))): [56, 64, 128, 144, 64],
    str(((1, 256, 56, 56), (512, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [3, 16, 32, 32, 32],
    str(((1, 512, 28, 28), (1024, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1))): [2, 16, 112, 32, 512],
}
  118. def gen_key(fmap_shape, filter_shape, pad_, stride_, dilation_):
  119. """generate key."""
  120. key = str((tuple(fmap_shape), tuple(filter_shape), tuple(pad_), tuple(stride_), tuple(dilation_)))
  121. return key
def conv_backprop_input_compute(data, output_shape, filter_shape, input_shape, pad_, stride_,
                                block_size=16, attrs=None, key=None):
    """
    Core computation of conv_backprop_input.

    Expresses dx as a stride-1 forward convolution: dy is expanded with
    zero-insertion by the stride factors (when stride > 1), the weight is
    flipped in H/W with its N/C roles swapped, and the padding is adjusted
    accordingly.

    Args:
        data: pair (dy, weight); a float32 weight is cast to float16.
        output_shape: NCHW shape of dy (forward-conv output).
        filter_shape: NCHW shape of the forward weight.
        input_shape: NCHW shape of dx (forward-conv input).
        pad_: forward padding (top, bottom, left, right).
        stride_: forward (stride_h, stride_w); both must be equal.
        block_size: cube unit C0.
        attrs: optional dict; attrs['conv_tile'] overrides the tiling tables.
        key: tiling-table lookup key (see gen_key).

    Returns:
        (result tensor, dict with "dim" tiling spec and pragma flags).

    Raises:
        ValueError: unequal strides, or channel counts not divisible by block_size.
    """
    _, in_c, w_h, w_w = filter_shape
    # stride (stride_h, stride_w)
    stride_h, stride_w = stride_
    if stride_h != stride_w:
        raise ValueError("stride_h must be equal to stride_w.")
    # output shape (NCHW -> NC1HWC0)
    in_nn, in_cc, in_hh, in_ww = output_shape
    if in_c % block_size != 0:
        raise ValueError("in_c must be divided by block_size.")
    input_shape_nc1hwc0 = (in_nn, in_cc // block_size, in_hh, in_ww, block_size)
    in_nn, _, in_hh, in_ww, _ = input_shape_nc1hwc0
    # Shape of dy after zero-insertion by the stride factors.
    input_trans_shape_nc1hwc0 = (in_nn, in_cc // block_size, in_hh * stride_h, in_ww * stride_w, block_size)
    in_n, in_c1, in_h, in_w, _ = input_trans_shape_nc1hwc0
    # kernel shape (NCHW -> NC1HWC0 -> Fractal)
    k_n, k_c, k_h, k_w = filter_shape
    if k_c % block_size != 0:
        raise ValueError("k_c must be divided by block_size.")
    kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    k_n, k_c1, k_h, k_w, k_c0 = kernel_shape_nc1hwc0
    kernel_shape_trans = (k_n // block_size * k_h * k_w, k_c // block_size, block_size, block_size)
    # For the backward pass the weight's N and C roles are swapped.
    k_c1 = k_n // block_size
    k_n = k_c
    _, _, input_h, input_w = input_shape
    # padding ((padding_h, padding_w) -> (padding_top, padding_bottom, padding_left, padding_right))
    padding = (pad_[0], pad_[1], pad_[2], pad_[3])
    pad_t, pad_b, pad_l, pad_r = padding
    # Adjusted padding of the equivalent stride-1 convolution:
    # padHT -> padHT'
    p_top = k_h - pad_t - 1
    # padHB -> padHB'
    p_bottom = input_h + pad_t - stride_h * ((input_h + pad_t + pad_b - k_h) // stride_h + 1)
    # padWL -> padWL'
    p_left = k_w - pad_l - 1
    # padWR -> padWR'
    p_right = input_w + pad_l - stride_w * ((input_w + pad_l + pad_r - k_w) // stride_w + 1)
    # The equivalent convolution always runs with stride 1.
    s_h = 1
    s_w = 1
    # NC1HWCO
    a_value = data[0]
    # A float32 weight is cast to float16 and uses the dedicated tiling table.
    if data[1].dtype == 'float32':
        b_value = cast.cast(data[1], 'float16')
        tiling_args = cast_tiling_args
    else:
        b_value = data[1]
        tiling_args = conv_backprop_input_tiling_args
    # Create reduction variables
    kc1 = akg.tvm.reduce_axis((0, k_c1), name='kc1')
    kh = akg.tvm.reduce_axis((0, k_h), name='kh')
    kw = akg.tvm.reduce_axis((0, k_w), name='kw')
    kc0 = akg.tvm.reduce_axis((0, k_c0), name='kc0')
    # Tiling: explicit attrs override the built-in tables; unknown keys fall
    # back to auto tiling.
    use_auto_tiling = False
    if attrs is not None and 'conv_tile' in attrs and len(attrs['conv_tile']) >= 5:
        tile_value = attrs['conv_tile']
    elif key in tiling_args:
        tile_value = tiling_args[key]
    else:
        use_auto_tiling = True
    out_h = (in_h + p_top + p_bottom - k_h) // (s_h) + 1
    out_w = (in_w + p_left + p_right - k_w) // (s_w) + 1
    out_shape_nc1hwc0 = (in_n, k_n // block_size, out_h, out_w, block_size)
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0
    # set dim
    info = dim.Dim()
    index_ = 0
    if not use_auto_tiling:
        # tile_value layout: [tile_h, tile_co, tile_m, tile_k, tile_n, (optional) tile_w]
        tile_hh = tile_value[0]
        if tile_hh == input_h:
            tile_hh += pad_t + pad_b
        tile_coco = tile_value[1]
        tile_coco = (tile_coco + block_size - 1) // block_size * block_size  # round up to C0
        tile_mm = tile_value[2]
        tile_mm = (tile_mm + block_size - 1) // block_size * block_size
        tile_kk = tile_value[3]
        # K tiles must cover whole (C0 x kh x kw) reduction slices.
        if not tile_kk % (block_size * w_h * w_w) == 0:
            logging.warning("Warning: tile_k must be a multiple of (block_size * w_h * w_w)")
        tile_kk = (tile_kk + block_size * w_h * w_w - 1) // (block_size * w_h * w_w) * (block_size * w_h * w_w)
        tile_nn = tile_value[4]
        tile_nn = (tile_nn + block_size - 1) // block_size * block_size
        tile_ww = input_w
        if len(tile_value) >= 6 and tile_value[5] > 0:
            tile_ww = tile_value[5]
        if tile_ww == input_w:
            tile_ww += pad_l + pad_r
        # A full-height/width tile absorbs the adjusted padding as well.
        if tile_hh == in_h:
            tile_hh += p_top + p_bottom
        tile_out_h = (tile_hh - k_h) // s_h + 1
        if tile_ww == in_w:
            tile_ww += p_left + p_right
        tile_out_w = (tile_ww - k_w) // s_w + 1
        if tile_coco > 0:
            c1_cut = tile_coco // block_size
        else:
            c1_cut = out_c1
        # Register dims only for axes with extent > 1.
        if out_n > 1:
            info.setdim(index=index_, axis=0, tilel1=1, tilel0=0)  # n
        if out_c1 > 1:
            info.setdim(index=index_, axis=1, tilel1=c1_cut, tilel0=0)  # c1
        if out_h > 1:
            info.setdim(index=index_, axis="H", tilel1=tile_out_h, tilel0=0)  # h
        if out_w > 1:
            info.setdim(index=index_, axis="W", tilel1=tile_out_w, tilel0=0)  # w
        if out_c0 > 1:
            info.setdim(index=index_, axis=4, tilel1=out_c0, tilel0=0)  # c0
        if in_c1 > 1:
            info.setdim(index=index_, axis=5, tilel1=in_c1, tilel0=0)  # kc1
        if k_h > 1:
            info.setdim(index=index_, axis=5, tilel1=k_h, tilel0=0)  # kh
        if k_w > 1:
            info.setdim(index=index_, axis=5, tilel1=k_w, tilel0=0)  # kw
        info = str(info)
    else:
        info = ""
    # Compute the convolution below
    output_name = "output0"
    # Fractal re-index of the weight: flip kh/kw and swap the co/ci roles.
    # weight_trans [ ko, no, ni, ki ]
    # weight_trans [ co_1, kh, kw, ci_1, ci_0, co_0 ]
    # kw = ko % k_w
    # kh = ko // k_w % k_h
    # co_1 = ko // k_w // k_h
    # ci_1 = no
    # -->
    # weight [ ci_1, kh', kw', co_1, co_0, ci_0 ]
    # weight [ no, k_h - ko // k_w % k_h - 1, k_w - ko % k_w - 1, ko // k_w // k_h, co_0, ci_0 ]
    b_trans = akg.tvm.compute(kernel_shape_trans,
                              lambda ko, no, ni, ki: b_value[((no * k_h + k_h - 1 - ko // k_w % k_h)
                                                             * k_w + k_w - 1 - ko % k_w), ko // (k_h * k_w), ki, ni],
                              name='B_trans')
    if ((stride_h > 1) or (stride_w > 1)):
        @akg.tvm.hybrid.script
        def data_trans_hybrid(output, inputs, const_zero):
            """Implements data_trans ( B[n, c1, h * strideH, w * strideW, c0] = A[n, c1, h, w, c0] )."""
            stride_h = output.shape[2] // inputs.shape[2]
            stride_w = output.shape[3] // inputs.shape[3]
            b = allocate(output.shape, output.dtype, 'local')
            for n in range(output.shape[0]):
                for c1 in range(output.shape[1]):
                    for h in range(output.shape[2]):
                        for w in range(output.shape[3]):
                            for c0 in range(output.shape[4]):
                                # Zero everywhere; keep dy values only at stride multiples.
                                b[n, c1, h, w, c0] = const_zero
                                if h % stride_h == 0 and w % stride_w == 0:
                                    b[n, c1, h, w, c0] = inputs[n, c1, h // stride_h, w // stride_w, c0]
            return b

        a_trans_init = akg.tvm.placeholder(input_trans_shape_nc1hwc0, dtype="float16", name='a_trans')
        const_zero = akg.tvm.const(0, 'float16')
        a_trans = data_trans_hybrid(a_trans_init, a_value, const_zero)
    else:
        # stride 1: no zero-insertion needed.
        a_trans = a_value
    # Pragmas consumed by the backend conv scheduler.
    conv_attrs = {
        "pragma_conv_kernel_n": k_n,
        "pragma_conv_kernel_h": k_h,
        "pragma_conv_kernel_w": k_w,
        "pragma_conv_padding_top": p_top,
        "pragma_conv_padding_bottom": p_bottom,
        "pragma_conv_padding_left": p_left,
        "pragma_conv_padding_right": p_right,
        "pragma_conv_bypass_l1": 0,
        "pragma_conv_backprop_input": 1,
        "pragma_conv_stride_h": s_h,
        "pragma_conv_stride_w": s_w,
        "pragma_conv_dilation_h": 1,
        "pragma_conv_dilation_w": 1,
        "pragma_conv_fm_n": in_n,
        "pragma_conv_fm_c": in_c,
        "pragma_conv_fm_h": in_h,
        "pragma_conv_fm_w": in_w,
        "feature": a_trans.op.name,
        "filter": b_value.op.name,
        "bias": 'None',
        "res": output_name}
    if not use_auto_tiling:
        conv_attrs["pragma_conv_h_cut"] = (tile_out_h - 1) * s_h + k_h
        conv_attrs["pragma_conv_w_cut"] = (tile_out_w - 1) * s_w + k_w
        conv_attrs["pragma_conv_co_cut"] = c1_cut * k_c0
        conv_attrs["pragma_conv_m_cut"] = tile_mm
        conv_attrs["pragma_conv_k_cut"] = tile_kk
        conv_attrs["pragma_conv_n_cut"] = tile_nn
    # Stride-1 convolution of the zero-inserted dy with the flipped weight;
    # out-of-range taps read as zero (explicit padding via if_then_else).
    res_c = akg.tvm.compute(out_shape_nc1hwc0,
                            lambda n, c1, h, w, c0: akg.lang.cce.mmad(
                                (akg.tvm.if_then_else(akg.tvm.any((h * s_h + kh) < p_top,
                                                                  (h * s_h + kh) > (in_h + p_top - 1),
                                                                  (w * s_w + kw) < p_left,
                                                                  (w * s_w + kw) > (in_w + p_left - 1)),
                                                      akg.tvm.const(0.0, 'float16'),
                                                      a_trans[n, kc1, (h * s_h + kh - p_top),
                                                              (w * s_w + kw - p_left), kc0])
                                 * b_trans[(kc1 * k_h + kh) * k_w + kw, c1, c0, kc0]).astype("float32"),
                                axis=[kc1, kh, kw, kc0]), name=output_name,
                            attrs=conv_attrs)
    # Cast the float32 accumulation back to float16.
    res_c = cast.cast(res_c, "float16")
    return res_c, {"dim": info, "pragma_reschedule": 1, "pragma_rmselfdep": 0}
  315. @vc_util.check_input_type((list, tuple), (list, tuple), (list, tuple), (list, tuple), (list, tuple), (list, tuple),
  316. (dict, type(None)))
  317. def conv_backprop_input(data, fmap_shape, filter_shape, pad_, stride_, dilation_, attrs=None):
  318. """
  319. Computes dx according "conv forward".
  320. Args:
  321. data (list[tvm.tensor.Tensor]): a list with length 2.
  322. data[0](consider as dy) Tensor of type float16 ,shape 5D(out_n, out_c//C0, out_h, out_w,C0)
  323. data[1](consider as w) Tensor of type float16 ,shape 4D(wC//C0*wH*wW, wN//C0, C0,C0)
  324. fmap_shape (list[int]): [fN, fC, fH, fW]
  325. filter_shape (list[int]): [wN, wC, wH, wW]
  326. pad_ (list[int]): [pad_left, pad_right, pad_top, pad_bottom]
  327. stride_ (list[int]): [stride_h, stride_w]
  328. dilation_ (list[int]): [dilation_h, dilation_w]
  329. attrs (dict): a dict with keys like conv_tile,bypass and etc.
  330. Returns:
  331. tvm.tensor.Tensor.
  332. configs.
  333. """
  334. if len(data) != 2:
  335. raise IndexError("data contains output tensor and filter tensor")
  336. vc_util.convolution_format_check(fmap_shape, filter_shape, pad_, stride_, dilation_)
  337. block_size = 16
  338. in_n, in_c, in_h, in_w = fmap_shape
  339. cout, _, w_h, w_w = filter_shape
  340. in_c = (in_c + block_size - 1) // block_size * block_size
  341. cout = (cout + block_size - 1) // block_size * block_size
  342. pad_top, pad_bottom, pad_left, pad_right = pad_
  343. stride_h, stride_w = stride_
  344. dilation_h, dilation_w = dilation_
  345. if dilation_h != 1 or dilation_w != 1:
  346. raise ValueError("The value od elements in dilation_ must be 1.")
  347. out_n = in_n
  348. out_c = cout
  349. out_h = (in_h + pad_top + pad_bottom - w_h) // stride_h + 1
  350. out_w = (in_w + pad_left + pad_right - w_w) // stride_w + 1
  351. x_shape = (out_n, out_c, out_h, out_w)
  352. w_shape = (cout, in_c, w_h, w_w)
  353. key = gen_key(fmap_shape, filter_shape, pad_, stride_, dilation_)
  354. res_c, configs = conv_backprop_input_compute(data, x_shape, w_shape, fmap_shape, pad_, stride_,
  355. block_size=block_size, attrs=attrs, key=key)
  356. return res_c, configs