You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

conv_bn1.py 15 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """operator dsl function: conv_bn1"""
  17. from functools import reduce
  18. import akg.topi
  19. import akg.tvm
  20. import akg
  21. import akg.lang.cce
  22. from akg.ops.math import cast
  23. from akg.ops.nn.conv import conv_core
  24. from akg.ops.nn.conv import conv_set_dim_func
  25. from akg.utils import validation_check as vc_util
  26. conv_bn1_set_dim_map = {
  27. str(((1, 1024, 14, 14), (2048, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  28. ([14, 2048, 64, 96, 128], {"bypass": 1}),
  29. str(((1, 1024, 14, 14), (256, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  30. ([14, 256, 208, 64, 112], {"bypass": 1}),
  31. str(((1, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  32. ([13, 144, 48, 48, 128, 13], {"bypass": 0}),
  33. str(((1, 128, 28, 28), (128, 128, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
  34. ([30, 128, 240, 48, 64, 30], {"bypass": 0}),
  35. str(((1, 128, 28, 28), (512, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  36. ([8, 256, 112, 16, 48, 28], {"bypass": 0}),
  37. str(((1, 2048, 7, 7), (512, 2048, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  38. ([7, 160, 48, 48, 96, 7], {"bypass": 0}),
  39. str(((1, 256, 14, 14), (1024, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  40. ([14, 256, 48, 64, 256, 14], {"bypass": 0}),
  41. str(((1, 256, 14, 14), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
  42. ([14, 192, 64, 128, 160, 16], {"bypass": 0}),
  43. str(((1, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  44. ([7, 128, 112, 48, 16, 55], {"bypass": 0}),
  45. str(((1, 256, 56, 56), (64, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  46. ([6, 64, 224, 32, 16, 56], {"bypass": 0}),
  47. str(((1, 3, 224, 224), (64, 3, 7, 7), (2, 3, 2, 3), (2, 2), (1, 1), False)):
  48. ([97, 64, 128, 128, 64, 229], {"bypass": 0}),
  49. str(((1, 512, 28, 28), (128, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  50. ([28, 64, 48, 304, 32, 28], {"bypass": 0}),
  51. str(((1, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  52. ([27, 192, 64, 48, 160, 27], {"bypass": 0}),
  53. str(((1, 512, 7, 7), (2048, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  54. ([7, 128, 48, 176, 80, 7], {"bypass": 0}),
  55. str(((1, 512, 7, 7), (512, 512, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
  56. ([9, 512, 64, 128, 96, 9], {"bypass": 1}),
  57. str(((1, 64, 56, 56), (256, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  58. ([56, 256, 392, 16, 32], {"bypass": 1}),
  59. str(((1, 64, 56, 56), (64, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  60. ([32, 64, 224, 32, 64, 56], {"bypass": 0}),
  61. str(((1, 64, 56, 56), (64, 64, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
  62. ([10, 64, 224, 48, 48, 58], {"bypass": 1}),
  63. str(((1, 256, 56, 56), (512, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  64. ([7, 320, 112, 160, 48, 55], {"bypass": 0}),
  65. str(((1, 512, 28, 28), (1024, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  66. ([13, 496, 96, 176, 144, 27], {"bypass": 0}),
  67. str(((1, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  68. ([12, 128, 112, 64, 128, 56], {"bypass": 0}),
  69. str(((1, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  70. ([28, 176, 224, 112, 80, 28], {"bypass": 0}),
  71. str(((1, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  72. ([7, 384, 96, 48, 224, 14], {"bypass": 0}),
  73. str(((1, 128, 56, 56), (128, 128, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1), False)):
  74. ([37, 128, 224, 96, 96, 57], {"bypass": 0}),
  75. str(((1, 256, 28, 28), (256, 256, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1), False)):
  76. ([29, 256, 80, 224, 144, 29], {"bypass": 1}),
  77. str(((1, 512, 14, 14), (512, 512, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1), False)):
  78. ([15, 512, 64, 64, 272, 15], {"bypass": 1}),
  79. str(((2, 1024, 14, 14), (2048, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  80. ([13, 112, 48, 176, 80, 13], {"bypass": 0}),
  81. str(((2, 1024, 14, 14), (256, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  82. ([14, 128, 48, 48, 64, 14], {"bypass": 0}),
  83. str(((2, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  84. ([13, 144, 48, 48, 128, 13], {"bypass": 0}),
  85. str(((2, 128, 28, 28), (128, 128, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
  86. ([30, 128, 240, 48, 64, 30], {"bypass": 0}),
  87. str(((2, 128, 28, 28), (512, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  88. ([8, 256, 112, 16, 48, 28], {"bypass": 0}),
  89. str(((2, 2048, 7, 7), (512, 2048, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  90. ([7, 160, 48, 48, 96, 7], {"bypass": 0}),
  91. str(((2, 256, 14, 14), (1024, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  92. ([14, 256, 48, 64, 256, 14], {"bypass": 0}),
  93. str(((2, 256, 14, 14), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
  94. ([14, 192, 64, 128, 160, 16], {"bypass": 0}),
  95. str(((2, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  96. ([7, 128, 112, 48, 16, 55], {"bypass": 0}),
  97. str(((2, 256, 56, 56), (64, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  98. ([6, 64, 224, 32, 16, 56], {"bypass": 0}),
  99. str(((2, 3, 224, 224), (64, 3, 7, 7), (2, 3, 2, 3), (2, 2), (1, 1), False)):
  100. ([97, 64, 128, 128, 64, 229], {"bypass": 0}),
  101. str(((2, 512, 28, 28), (128, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  102. ([28, 64, 48, 304, 32, 28], {"bypass": 0}),
  103. str(((2, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  104. ([27, 192, 64, 48, 160, 27], {"bypass": 0}),
  105. str(((2, 512, 7, 7), (2048, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  106. ([7, 128, 48, 176, 80, 7], {"bypass": 0}),
  107. str(((2, 512, 7, 7), (512, 512, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
  108. ([9, 512, 64, 128, 96, 9], {"bypass": 1}),
  109. str(((2, 64, 56, 56), (256, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  110. ([56, 256, 784, 16, 32], {"bypass": 1}),
  111. str(((2, 64, 56, 56), (64, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  112. ([32, 64, 224, 32, 64, 56], {"bypass": 0}),
  113. str(((2, 64, 56, 56), (64, 64, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
  114. ([10, 64, 224, 48, 48, 58], {"bypass": 1}),
  115. str(((2, 256, 56, 56), (512, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  116. ([7, 320, 112, 160, 48, 55], {"bypass": 0}),
  117. str(((2, 512, 28, 28), (1024, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  118. ([13, 496, 96, 176, 144, 27], {"bypass": 0}),
  119. str(((2, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  120. ([12, 128, 224, 64, 128, 56], {"bypass": 0}),
  121. str(((2, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  122. ([28, 176, 224, 112, 80, 28], {"bypass": 0}),
  123. str(((2, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  124. ([7, 384, 96, 48, 224, 14], {"bypass": 0}),
  125. str(((2, 128, 56, 56), (128, 128, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1), False)):
  126. ([37, 128, 224, 96, 96, 57], {"bypass": 0}),
  127. str(((2, 256, 28, 28), (256, 256, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1), False)):
  128. ([29, 256, 80, 224, 144, 29], {"bypass": 1}),
  129. str(((2, 512, 14, 14), (512, 512, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1), False)):
  130. ([15, 512, 64, 64, 272, 15], {"bypass": 1}),
  131. str(((32, 1024, 14, 14), (2048, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  132. ([13, 64, 48, 128, 64, 13], {"bypass": 0}),
  133. str(((32, 1024, 14, 14), (256, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  134. ([14, 128, 48, 48, 64, 14], {"bypass": 0}),
  135. str(((32, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  136. ([13, 144, 48, 48, 128, 13], {"bypass": 0}),
  137. str(((32, 128, 28, 28), (128, 128, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
  138. ([30, 128, 240, 48, 64, 30], {"bypass": 0}),
  139. str(((32, 128, 28, 28), (512, 128, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  140. ([8, 256, 224, 80, 48, 28], {"bypass": 0}),
  141. str(((32, 2048, 7, 7), (512, 2048, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  142. ([7, 128, 48, 112, 112, 7], {"bypass": 0}),
  143. str(((32, 256, 14, 14), (1024, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  144. ([14, 256, 48, 64, 256, 14], {"bypass": 0}),
  145. str(((32, 256, 14, 14), (256, 256, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
  146. ([16, 96, 80, 96, 96, 16], {"bypass": 0}),
  147. str(((32, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  148. ([23, 128, 112, 240, 48, 55], {"bypass": 1}),
  149. str(((32, 256, 56, 56), (64, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  150. ([26, 64, 448, 32, 16, 56], {"bypass": 0}),
  151. str(((32, 3, 224, 224), (64, 3, 7, 7), (2, 3, 2, 3), (2, 2), (1, 1), False)):
  152. ([61, 64, 224, 48, 64, 229], {"bypass": 0}),
  153. str(((32, 512, 28, 28), (128, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  154. ([8, 96, 224, 48, 16, 28], {"bypass": 0}),
  155. str(((32, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  156. ([27, 48, 48, 160, 48, 27], {"bypass": 0}),
  157. str(((32, 512, 7, 7), (2048, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  158. ([7, 864, 48, 288, 16, 7], {"bypass": 0}),
  159. str(((32, 512, 7, 7), (512, 512, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
  160. ([9, 512, 64, 128, 96, 9], {"bypass": 1}),
  161. str(((32, 64, 56, 56), (256, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  162. ([8, 64, 448, 64, 32, 56], {"bypass": 0}),
  163. str(((32, 64, 56, 56), (64, 64, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  164. ([14, 64, 112, 64, 16, 56], {"bypass": 0}),
  165. str(((32, 64, 56, 56), (64, 64, 3, 3), (1, 1, 1, 1), (1, 1), (1, 1), False)):
  166. ([6, 64, 224, 64, 64, 58], {"bypass": 0}),
  167. str(((32, 256, 56, 56), (512, 256, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  168. ([7, 320, 112, 160, 48, 55], {"bypass": 0}),
  169. str(((32, 512, 28, 28), (1024, 512, 1, 1), (0, 0, 0, 0), (2, 2), (1, 1), False)):
  170. ([13, 496, 96, 176, 144, 27], {"bypass": 0}),
  171. str(((32, 256, 56, 56), (128, 256, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  172. ([2, 96, 112, 48, 96, 56], {"bypass": 0}),
  173. str(((32, 512, 28, 28), (256, 512, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  174. ([20, 128, 16, 448, 64, 28], {"bypass": 0}),
  175. str(((32, 1024, 14, 14), (512, 1024, 1, 1), (0, 0, 0, 0), (1, 1), (1, 1), False)):
  176. ([7, 384, 48, 48, 224, 14], {"bypass": 0}),
  177. str(((32, 128, 56, 56), (128, 128, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1), False)):
  178. ([9, 128, 64, 64, 128, 57], {"bypass": 0}),
  179. str(((32, 256, 28, 28), (256, 256, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1), False)):
  180. ([29, 256, 80, 224, 144, 29], {"bypass": 1}),
  181. str(((32, 512, 14, 14), (512, 512, 3, 3), (0, 1, 0, 1), (2, 2), (1, 1), False)):
  182. ([15, 512, 64, 64, 272, 15], {"bypass": 1}),
  183. # alexnet
  184. str(((32, 3, 227, 227), (96, 3, 11, 11), (0, 0, 0, 0), (4, 4), (1, 1), False)):
  185. ([63, 96, 208, 32, 96, 227], {"bypass": 0}),
  186. str(((32, 96, 27, 27), (256, 96, 5, 5), (2, 2, 2, 2), (1, 1), (1, 1), False)):
  187. ([21, 160, 176, 32, 96, 31], {"bypass": 0})
  188. }
  189. @vc_util.check_input_type((list, tuple), (list, tuple), (list, tuple), (list, tuple), (list, tuple), (list, tuple),
  190. (bool, type(None)), (dict, type(None)))
  191. def conv_bn1(data, fmap_shape, filter_shape, pad, stride, dilation, use_bias=False, attrs=None):
  192. """
  193. Computes sums of 5-D convolutions and use convolution's fp32 result to compute first part of Fused_batch_norm.
  194. Fused_batch_norm's first part:
  195. \f[
  196. m = N \times H \times W \\
  197. \\mu_{tmp} = \\sum_{n, h, w}{\frac{x}{m}} \\
  198. \\sigma^2_{tmp} = \\sum_{n, h, w}{\frac{x^2}{m}}
  199. \f]
  200. Args:
  201. data (list[tvm.tensor.Tensor]): the size is 3 if use_bias else the size is 2;
  202. data[0] Tensor of type float16 ,shape 5D (fN, fC // C0, C0, fH, fW)
  203. data[1] Tensor of type float16 ,shape 4D (wC // C0 * wH * wW, wN // C0, C0, C0)
  204. data[2] Tensor of type float16 ,shape 5D (1, wN // C0, 1, 1, 16)
  205. fmap_shape (list[int]): [fN, fC, fH, fW]
  206. filter_shape (list[int]): [wN, wC, wH, wW]
  207. pad (list[int]): [pad_top, pad_bottom, pad_left, pad_right]
  208. stride (list[int]): [stride_h, stride_w]
  209. dilation (list[int]): [dilation_h, dilation_w]
  210. use_bias (bool): bool var.
  211. attrs (dict): dict with keys for example: conv_tile,bypass
  212. Returns:
  213. tvm.tensor.Tensor of same type as data, shape is 5D(oN, oC // C0, oH, oW, C0)
  214. """
  215. if use_bias:
  216. raise ValueError("do not support bias yet !!!")
  217. block_size = 16
  218. dim_info, conv_tile, bypass, _ = conv_set_dim_func(fmap_shape, filter_shape, pad, stride, dilation, use_bias,
  219. block_size, attrs, conv_bn1_set_dim_map)
  220. if attrs is None:
  221. attrs = {"conv_tile": conv_tile, "bypass": bypass}
  222. else:
  223. attrs['conv_tile'] = conv_tile
  224. attrs['bypass'] = bypass
  225. conv_res_32 = conv_core(data, fmap_shape, filter_shape, pad, stride, dilation, use_bias, attrs)
  226. conv_res_16 = cast.cast(conv_res_32, "float16")
  227. axes = [3, 2, 0]
  228. conv_res_32_shape = [x.value for x in conv_res_32.shape]
  229. num = reduce(lambda i, j: i * j, [conv_res_32_shape[i] for i in axes])
  230. avg_num = round(float(1) / float(num), 12)
  231. res_sum = akg.topi.sum(conv_res_32, axes, keepdims=True)
  232. mean = akg.lang.cce.vmuls(res_sum, avg_num)
  233. res_square = akg.tvm.compute(conv_res_32.shape, lambda *i: conv_res_32[i] * conv_res_32[i], name="res_square")
  234. square_sum = akg.topi.sum(res_square, axes, keepdims=True)
  235. var_part = akg.lang.cce.vmuls(square_sum, avg_num)
  236. # need pragma_force_rmselfdep to enable multicore using atomic add
  237. # because default pragma_rmselfdep=1 will disable multicore of reduce axes
  238. attrs = {"dim": dim_info, "pragma_reschedule": 1, "enable_bisect_optimize": 0,
  239. "pragma_rmselfdep": 0, "pragma_force_rmselfdep": 1}
  240. return conv_res_16, var_part, mean, attrs

AKG(Auto Kernel Generator)对深度神经网络中的算子进行优化,并提供特定模式下的算子自动融合功能。AKG与MindSpore的图算融合功能协同工作,可提升在不同硬件后端上运行网络的性能。