You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

kernel_compiler.py 13 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Compile kernel module for operator"""
  15. import os
  16. from typing import NamedTuple
  17. from tests.common.base import TestBase
  18. from akg import build_module
  19. from akg.utils import kernel_exec as utils
  20. from akg.utils import custom_tiling as ct_util
  21. from akg.ops.nn import conv_bn1
  22. from akg.ops.nn import conv, conv_backprop_input, conv_backprop_filter, batchmatmul
  23. from akg.ops.nn import matmul
  24. from tests.common.test_run import batchmatmul_run, matmul_run
  25. from akg.auto_tune.type_definitions import ConvDesc, ConvBackpropDesc, MatmulCubeDesc, ConvConfig, ConvBackpropInputConfig, ConvBackpropFilterConfig, MatmulCubeConfig
  26. def gen_kernel_conv(op_desc: ConvDesc, input_shape, index_table,
  27. config: ConvConfig = None, idx=None, gen_tiling_spaces=False):
  28. """Compile kernel module for conv"""
  29. if index_table is not None:
  30. raise RuntimeError('index_table should be none')
  31. kernel_name = "conv_poly"
  32. if idx is not None:
  33. kernel_name += str(idx)
  34. if config is None:
  35. attrs = {'dim': ""}
  36. else:
  37. tile_hh = config.tile_h
  38. tile_coco = config.tile_co
  39. tile_mm = config.tile_m
  40. tile_kk = config.tile_k
  41. tile_nn = config.tile_n
  42. tile_ww = config.tile_w
  43. tiling_param = [tile_hh, tile_coco, tile_mm, tile_kk, tile_nn, tile_ww]
  44. attrs = {'conv_tile': tiling_param, 'bypass': config.bypass}
  45. if op_desc.use_bias:
  46. shape = [input_shape[0], input_shape[1], input_shape[2]]
  47. else:
  48. shape = [input_shape[0], input_shape[1]]
  49. conv_dtype = 'float16'
  50. return utils.op_build(conv.conv, [shape], [conv_dtype],
  51. op_attrs=[op_desc.fmap_shape, op_desc.filter_shape, op_desc.pad, op_desc.stride,
  52. op_desc.dilation, op_desc.use_bias, attrs],
  53. kernel_name=kernel_name, attrs=attrs, polyhedral=True, tuning=gen_tiling_spaces)
  54. def gen_kernel_conv_bn1(op_desc: ConvDesc, input_shape, index_table, config: ConvConfig = None,
  55. idx=None, gen_tiling_spaces=False):
  56. """Compile kernel module for conv_bn1"""
  57. if index_table is not None:
  58. raise RuntimeError('index_table should be none')
  59. kernel_name = "conv_bn1_poly"
  60. if idx is not None:
  61. kernel_name += str(idx)
  62. if config is None:
  63. attrs = {'dim': ""}
  64. else:
  65. tile_hh = config.tile_h
  66. tile_coco = config.tile_co
  67. tile_mm = config.tile_m
  68. tile_kk = config.tile_k
  69. tile_nn = config.tile_n
  70. tile_ww = config.tile_w
  71. tiling_param = [tile_hh, tile_coco, tile_mm, tile_kk, tile_nn, tile_ww]
  72. attrs = {'conv_tile': tiling_param, 'bypass': config.bypass}
  73. if op_desc.use_bias:
  74. shape = [input_shape[0], input_shape[1], input_shape[2]]
  75. else:
  76. shape = [input_shape[0], input_shape[1]]
  77. conv_dtype = 'float16'
  78. return utils.op_build(conv_bn1.conv_bn1, [shape], [conv_dtype],
  79. op_attrs=[op_desc.fmap_shape, op_desc.filter_shape, op_desc.pad, op_desc.stride,
  80. op_desc.dilation, op_desc.use_bias, attrs],
  81. kernel_name=kernel_name, attrs=attrs, polyhedral=True, tuning=gen_tiling_spaces)
  82. def get_matmul_cube_attrs(op_desc, config):
  83. tiling_param = []
  84. for _ in range(len(op_desc.x_shape) - 2):
  85. tiling_param.append((1, 1))
  86. if config.n_l1 > 0:
  87. tiling_param.append((config.n_l1, config.n_l0))
  88. if config.m_l1 > 0:
  89. tiling_param.append((config.m_l1, config.m_l0))
  90. tiling_param.extend([(16, 16), (16, 16), (config.k_l1, config.k_l0)])
  91. dim_info = ct_util.set_dims(tuple(tiling_param))
  92. attrs = {'dim': dim_info, 'bypass': config.bypass}
  93. return attrs
  94. def gen_kernel_matmul_cube(op_desc: MatmulCubeDesc, _, index_table,
  95. config: MatmulCubeConfig = None, idx=None, gen_tiling_spaces=False):
  96. """Compile kernel module for matmul_cube"""
  97. if index_table is not None:
  98. raise RuntimeError('index_table should be none')
  99. kernel_name = "matmul_cube_poly"
  100. if idx is not None:
  101. kernel_name += str(idx)
  102. if config is None:
  103. attrs = {'dim': ""}
  104. else:
  105. attrs = get_matmul_cube_attrs(op_desc, config)
  106. return matmul_run.matmul_compile(op_desc.x_shape, op_desc.y_shape, op_desc.bias, op_desc.left_format,
  107. op_desc.right_format, op_desc.out_format, op_desc.adj_x, op_desc.adj_y,
  108. op_desc.dtype, op_desc.bias_dtype, op_desc.out_dtype, kernel_name,
  109. attrs, tuning=gen_tiling_spaces)
def gen_kernel_conv_backprop_input(op_desc: ConvBackpropDesc, _, index_table, config: ConvBackpropInputConfig = None,
                                   idx=None, gen_tiling_spaces=False):
    """Compile kernel module for conv_backprop_input.

    Derives the forward conv's output geometry from ``op_desc`` and feeds the
    backprop-input op the gradient tensor in (N, C1, H, W, C0) layout plus the
    filter in fractal layout. ``index_table`` must be ``None``; ``config``
    supplies tiling values (otherwise an empty 'dim' attr is used);
    ``gen_tiling_spaces`` is passed through as ``utils.op_build``'s ``tuning``
    flag.
    """
    if index_table is not None:
        raise RuntimeError('index_table should be none')
    kernel_name = "conv_backprop_input_poly"
    if idx is not None:
        kernel_name += str(idx)
    if config is None:
        attrs = {'dim': ""}
    else:
        # Tiling order is fixed: h, co, m, k, n, w.
        tile_hh = config.tile_h
        tile_coco = config.tile_co
        tile_mm = config.tile_m
        tile_kk = config.tile_k
        tile_nn = config.tile_n
        tile_ww = config.tile_w
        tiling_param = [tile_hh, tile_coco, tile_mm, tile_kk, tile_nn, tile_ww]
        attrs = {'conv_tile': tiling_param}
    conv_dtype = 'float16'
    block_size = 16
    # Forward-pass geometry from the op description.
    in_n, in_c, in_h, in_w = op_desc.fmap_shape
    cout, _, w_h, w_w = op_desc.filter_shape
    # Round both channel counts up to a multiple of block_size.
    in_c = (in_c + block_size - 1) // block_size * block_size
    cout = (cout + block_size - 1) // block_size * block_size
    pad_top, pad_bottom, pad_left, pad_right = op_desc.pad
    stride_h, stride_w = op_desc.stride
    # Forward conv output size; this is the *input* (gradient) of backprop.
    out_n = in_n
    out_c = cout
    out_h = (in_h + pad_top + pad_bottom - w_h) // stride_h + 1
    out_w = (in_w + pad_left + pad_right - w_w) // stride_w + 1
    x_shape = (out_n, out_c, out_h, out_w)
    w_shape = (cout, in_c, w_h, w_w)
    # Gradient tensor in 5D (N, C1, H, W, C0) layout, C0 == block_size.
    in_nn, in_cc, in_hh, in_ww = x_shape
    input_shape_nc1hwc0 = (in_nn, in_cc // block_size, in_hh, in_ww, block_size)
    k_n, k_c, k_h, k_w = w_shape
    kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    k_n, _, k_h, k_w, _ = kernel_shape_nc1hwc0
    # Filter flattened to fractal layout:
    # (C1*kh*kw, N/block, block, block).
    kernel_shape_fractal = (k_c // block_size * k_h * k_w, k_n // block_size, block_size, block_size)
    shape = [input_shape_nc1hwc0, kernel_shape_fractal]
    return utils.op_build(conv_backprop_input.conv_backprop_input, [shape], [conv_dtype],
                          op_attrs=[op_desc.fmap_shape, op_desc.filter_shape, op_desc.pad,
                                    op_desc.stride, op_desc.dilation, attrs],
                          kernel_name=kernel_name, attrs=attrs, polyhedral=True, tuning=gen_tiling_spaces)
def gen_kernel_conv_backprop_filter(op_desc: ConvBackpropDesc, _, index_table, config: ConvBackpropFilterConfig = None,
                                    idx=None, gen_tiling_spaces=False):
    """Compile kernel module for conv_backprop_filter.

    Derives the forward conv's output geometry from ``op_desc`` and feeds the
    backprop-filter op the output gradient in fractal layout plus the original
    feature map in (N, C1, H, W, C0) layout. ``index_table`` must be ``None``;
    ``config`` supplies the ten-value tiling list (otherwise an empty 'dim'
    attr is used); ``gen_tiling_spaces`` is passed through as ``tuning``.
    """
    if index_table is not None:
        raise RuntimeError('index_table should be none')
    kernel_name = "conv_backprop_filter_poly"
    if idx is not None:
        kernel_name += str(idx)
    if config is None:
        attrs = {'dim': ""}
    else:
        # Tiling order is fixed: ci, kh, kw, co, batch, h, w, m, k, n.
        tile_cici = config.tile_ci
        tile_khkh = config.tile_kh
        tile_kwkw = config.tile_kw
        tile_coco = config.tile_co
        tile_bb = config.tile_batch
        tile_hh = config.tile_h
        tile_ww = config.tile_w
        tile_mm = config.tile_m
        tile_kk = config.tile_k
        tile_nn = config.tile_n
        tiling_param = [tile_cici, tile_khkh, tile_kwkw, tile_coco, tile_bb, tile_hh, tile_ww,
                        tile_mm, tile_kk, tile_nn]
        attrs = {'conv_tile': tiling_param}
    conv_dtype = 'float16'
    block_size = 16
    # Forward-pass geometry from the op description.
    in_n, in_c, in_h, in_w = op_desc.fmap_shape
    cout, _, w_h, w_w = op_desc.filter_shape
    # Round both channel counts up to a multiple of block_size.
    in_c = (in_c + block_size - 1) // block_size * block_size
    cout = (cout + block_size - 1) // block_size * block_size
    pad_top, pad_bottom, pad_left, pad_right = op_desc.pad
    stride_h, stride_w = op_desc.stride
    # Forward conv output size; its gradient is one input of backprop-filter.
    out_n = in_n
    out_c = cout
    out_h = (in_h + pad_top + pad_bottom - w_h) // stride_h + 1
    out_w = (in_w + pad_left + pad_right - w_w) // stride_w + 1
    x_shape = (in_n, in_c, in_h, in_w)
    y_shape = (out_n, out_c, out_h, out_w)
    # Feature map in 5D (N, C1, H, W, C0) layout, C0 == block_size.
    in_n, in_c, in_h, in_w = x_shape
    input_shape_nc1hwc0 = (in_n, in_c // block_size, in_h, in_w, block_size)
    o_n, o_c, o_h, o_w = y_shape
    kernel_shape_nc1hwc0 = (o_n, o_c // block_size, o_h, o_w, block_size)
    o_n, o_c1, o_h, o_w, o_c0 = kernel_shape_nc1hwc0
    # Output gradient flattened to fractal layout: spatial H*W is split into
    # mo blocks of mi == block_size rows (rounded up).
    mo = (o_h * o_w + block_size - 1) // block_size
    mi = block_size
    kernel_shape_fractal = (o_n, o_c1, mo, mi, o_c0)
    input_shape = [kernel_shape_fractal, input_shape_nc1hwc0]
    return utils.op_build(conv_backprop_filter.conv_backprop_filter, [input_shape], [conv_dtype],
                          op_attrs=[op_desc.fmap_shape, op_desc.filter_shape, op_desc.pad,
                                    op_desc.stride, op_desc.dilation, attrs],
                          kernel_name=kernel_name, attrs=attrs, polyhedral=True, tuning=gen_tiling_spaces)
  205. def gen_kernel_for_vector(op_desc, _, index_table=None, config: NamedTuple = None, idx=None, gen_tiling_spaces=False):
  206. """Compile kernel module for vector"""
  207. test_base = TestBase()
  208. test_base.params_init(op_desc[0][0:4] + str(idx), os.getcwd())
  209. kernel_name = "poly_"
  210. if idx is not None:
  211. kernel_name += str(idx)
  212. if config is None:
  213. attrs = {'dim': ""}
  214. else:
  215. tiling = [[getattr(config, name), 1] for name in getattr(config, '_fields') if name.startswith('tiling')]
  216. tiling_param = []
  217. for i, element in enumerate(tiling):
  218. tiling_param.append(index_table[i] + element)
  219. dim_info = ct_util.set_dims(tuple(tiling_param))
  220. attrs = {'dim': dim_info}
  221. _, func, args, kwargs = test_base.ana_args(op_desc)
  222. if 'attrs' in kwargs.keys():
  223. kwargs['attrs']['dim'] = attrs['dim']
  224. kwargs['attrs']['tuning'] = gen_tiling_spaces
  225. kwargs['attrs']['kernel_name'] = kernel_name
  226. else:
  227. for _, arg_ in enumerate(args):
  228. if isinstance(arg_, dict):
  229. arg_['dim'] = attrs['dim']
  230. arg_['tuning'] = gen_tiling_spaces
  231. arg_['kernel_name'] = kernel_name
  232. break
  233. try:
  234. if gen_tiling_spaces:
  235. mod, expect, param_for_mod = func(*args, **kwargs)
  236. mod = list(mod)
  237. mod.append(expect)
  238. mod.append(param_for_mod)
  239. else:
  240. mod = func(*args, **kwargs)
  241. except BaseException as e:
  242. print("Compile ERROR message:", e)
  243. print(func)
  244. print("Compile ERROR")
  245. raise Exception("Compile ERROR")
  246. return mod
  247. _compile_kernel_func = {
  248. 'conv': gen_kernel_conv,
  249. 'conv_bn1': gen_kernel_conv_bn1,
  250. 'conv_backprop_input': gen_kernel_conv_backprop_input,
  251. 'conv_backprop_filter': gen_kernel_conv_backprop_filter,
  252. 'matmul': gen_kernel_matmul_cube,
  253. }
  254. def compile_kernel(op_type: str, op_desc: NamedTuple, input_shape=None, index_table=None,
  255. config_param: NamedTuple = None, idx: int = None, gen_tiling_spaces: bool = False):
  256. """Generate kernel module for operator
  257. Parameters
  258. op_type: str
  259. operator name
  260. op_desc: NamedTuple
  261. operator definition parameters
  262. config_param: NameTuple
  263. operator config parameters
  264. idx: int
  265. operator idx(th) kernel
  266. gen_tiling_spaces: bool
  267. parameter passed to utils.op_build, whether to get spaces instead of stmt
  268. ----------
  269. Returns:
  270. kernel if gen_tiling_spaces == False else np.ndarray
  271. """
  272. gen_func = _compile_kernel_func.get(op_type, None)
  273. if gen_func is None:
  274. gen_func = gen_kernel_for_vector
  275. if gen_tiling_spaces:
  276. mod, key, expect, input_for_mod = gen_func(op_desc, input_shape, index_table, config_param,
  277. idx, gen_tiling_spaces)
  278. else:
  279. mod = gen_func(op_desc, input_shape, index_table, config_param, idx, gen_tiling_spaces)
  280. return [build_module.tuning_spaces, key, expect, input_for_mod] if gen_tiling_spaces else mod