You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

kernel_compiler.py 13 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Compile kernel module for operator"""
  15. import os
  16. from typing import NamedTuple
  17. from tests.common.base import TestBase
  18. from akg import build_module
  19. from akg.utils import kernel_exec as utils
  20. from akg.utils import custom_tiling as ct_util
  21. from akg.ops.nn import conv_bn1
  22. from akg.ops.nn import conv, conv_backprop_input, conv_backprop_filter, batchmatmul
  23. from akg.ops.nn import matmul
  24. from tests.common.test_run import batchmatmul_run, matmul_run
  25. from akg.auto_tune.type_definitions import ConvDesc, ConvBackpropDesc, MatmulCubeDesc, ConvConfig, ConvBackpropInputConfig, ConvBackpropFilterConfig, MatmulCubeConfig
  26. def gen_kernel_conv(op_desc: ConvDesc, input_shape, index_table,
  27. config: ConvConfig = None, idx=None, gen_tiling_spaces=False):
  28. """Compile kernel module for conv"""
  29. if index_table is not None:
  30. raise RuntimeError('index_table should be none')
  31. kernel_name = "conv_poly"
  32. if idx is not None:
  33. kernel_name += str(idx)
  34. if config is None:
  35. attrs = {'dim': ""}
  36. else:
  37. tile_hh = config.tile_h
  38. tile_coco = config.tile_co
  39. tile_mm = config.tile_m
  40. tile_kk = config.tile_k
  41. tile_nn = config.tile_n
  42. tile_ww = config.tile_w
  43. tiling_param = [tile_hh, tile_coco, tile_mm, tile_kk, tile_nn, tile_ww]
  44. attrs = {'conv_tile': tiling_param, 'bypass': config.bypass}
  45. if op_desc.use_bias:
  46. shape = [input_shape[0], input_shape[1], input_shape[2]]
  47. else:
  48. shape = [input_shape[0], input_shape[1]]
  49. conv_dtype = 'float16'
  50. return utils.op_build(conv.conv, [shape], [conv_dtype],
  51. op_attrs=[op_desc.fmap_shape, op_desc.filter_shape, op_desc.pad, op_desc.stride,
  52. op_desc.dilation, op_desc.use_bias, attrs],
  53. kernel_name=kernel_name, attrs=attrs, polyhedral=True, tuning=gen_tiling_spaces)
  54. def gen_kernel_conv_bn1(op_desc: ConvDesc, input_shape, index_table, config: ConvConfig = None,
  55. idx=None, gen_tiling_spaces=False):
  56. """Compile kernel module for conv_bn1"""
  57. if index_table is not None:
  58. raise RuntimeError('index_table should be none')
  59. kernel_name = "conv_bn1_poly"
  60. if idx is not None:
  61. kernel_name += str(idx)
  62. if config is None:
  63. attrs = {'dim': ""}
  64. else:
  65. tile_hh = config.tile_h
  66. tile_coco = config.tile_co
  67. tile_mm = config.tile_m
  68. tile_kk = config.tile_k
  69. tile_nn = config.tile_n
  70. tile_ww = config.tile_w
  71. tiling_param = [tile_hh, tile_coco, tile_mm, tile_kk, tile_nn, tile_ww]
  72. attrs = {'conv_tile': tiling_param, 'bypass': config.bypass}
  73. if op_desc.use_bias:
  74. shape = [input_shape[0], input_shape[1], input_shape[2]]
  75. else:
  76. shape = [input_shape[0], input_shape[1]]
  77. conv_dtype = 'float16'
  78. return utils.op_build(conv_bn1.conv_bn1, [shape], [conv_dtype],
  79. op_attrs=[op_desc.fmap_shape, op_desc.filter_shape, op_desc.pad, op_desc.stride,
  80. op_desc.dilation, op_desc.use_bias, attrs],
  81. kernel_name=kernel_name, attrs=attrs, polyhedral=True, tuning=gen_tiling_spaces)
  82. def get_matmul_cube_attrs(op_desc, config):
  83. tiling_param = []
  84. for _ in range(len(op_desc.x_shape) - 2):
  85. tiling_param.append((1, 1))
  86. if config.n_l1 > 0:
  87. tiling_param.append((config.n_l1, config.n_l0))
  88. if config.m_l1 > 0:
  89. tiling_param.append((config.m_l1, config.m_l0))
  90. tiling_param.extend([(16, 16), (16, 16), (config.k_l1, config.k_l0)])
  91. dim_info = ct_util.set_dims(tuple(tiling_param))
  92. attrs = {'dim': dim_info}
  93. tuning_dict = config._asdict()
  94. for key, value in tuning_dict.items():
  95. if key not in ['n_l1', 'n_l0', 'm_l1', 'm_l0', 'k_l1', 'k_l0']:
  96. attrs[key] = value
  97. return attrs
  98. def gen_kernel_matmul_cube(op_desc: MatmulCubeDesc, _, index_table,
  99. config: MatmulCubeConfig = None, idx=None, gen_tiling_spaces=False):
  100. """Compile kernel module for matmul_cube"""
  101. if index_table is not None:
  102. raise RuntimeError('index_table should be none')
  103. kernel_name = "matmul_cube_poly"
  104. if idx is not None:
  105. kernel_name += str(idx)
  106. if config is None:
  107. attrs = {'dim': ""}
  108. else:
  109. attrs = get_matmul_cube_attrs(op_desc, config)
  110. return matmul_run.matmul_compile(op_desc.x_shape, op_desc.y_shape, op_desc.bias, op_desc.left_format,
  111. op_desc.right_format, op_desc.out_format, op_desc.adj_x, op_desc.adj_y,
  112. op_desc.dtype, op_desc.bias_dtype, op_desc.out_dtype, kernel_name,
  113. attrs, tuning=gen_tiling_spaces)
  114. def gen_kernel_conv_backprop_input(op_desc: ConvBackpropDesc, _, index_table, config: ConvBackpropInputConfig = None,
  115. idx=None, gen_tiling_spaces=False):
  116. """Compile kernel module for conv_backprop_input"""
  117. if index_table is not None:
  118. raise RuntimeError('index_table should be none')
  119. kernel_name = "conv_backprop_input_poly"
  120. if idx is not None:
  121. kernel_name += str(idx)
  122. if config is None:
  123. attrs = {'dim': ""}
  124. else:
  125. tile_hh = config.tile_h
  126. tile_coco = config.tile_co
  127. tile_mm = config.tile_m
  128. tile_kk = config.tile_k
  129. tile_nn = config.tile_n
  130. tile_ww = config.tile_w
  131. tiling_param = [tile_hh, tile_coco, tile_mm, tile_kk, tile_nn, tile_ww]
  132. attrs = {'conv_tile': tiling_param}
  133. conv_dtype = 'float16'
  134. block_size = 16
  135. in_n, in_c, in_h, in_w = op_desc.fmap_shape
  136. cout, _, w_h, w_w = op_desc.filter_shape
  137. in_c = (in_c + block_size - 1) // block_size * block_size
  138. cout = (cout + block_size - 1) // block_size * block_size
  139. pad_top, pad_bottom, pad_left, pad_right = op_desc.pad
  140. stride_h, stride_w = op_desc.stride
  141. out_n = in_n
  142. out_c = cout
  143. out_h = (in_h + pad_top + pad_bottom - w_h) // stride_h + 1
  144. out_w = (in_w + pad_left + pad_right - w_w) // stride_w + 1
  145. x_shape = (out_n, out_c, out_h, out_w)
  146. w_shape = (cout, in_c, w_h, w_w)
  147. in_nn, in_cc, in_hh, in_ww = x_shape
  148. input_shape_nc1hwc0 = (in_nn, in_cc // block_size, in_hh, in_ww, block_size)
  149. k_n, k_c, k_h, k_w = w_shape
  150. kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
  151. k_n, _, k_h, k_w, _ = kernel_shape_nc1hwc0
  152. kernel_shape_fractal = (k_c // block_size * k_h * k_w, k_n // block_size, block_size, block_size)
  153. shape = [input_shape_nc1hwc0, kernel_shape_fractal]
  154. return utils.op_build(conv_backprop_input.conv_backprop_input, [shape], [conv_dtype],
  155. op_attrs=[op_desc.fmap_shape, op_desc.filter_shape, op_desc.pad,
  156. op_desc.stride, op_desc.dilation, attrs],
  157. kernel_name=kernel_name, attrs=attrs, polyhedral=True, tuning=gen_tiling_spaces)
  158. def gen_kernel_conv_backprop_filter(op_desc: ConvBackpropDesc, _, index_table, config: ConvBackpropFilterConfig = None,
  159. idx=None, gen_tiling_spaces=False):
  160. """Compile kernel module for conv_backprop_filter"""
  161. if index_table is not None:
  162. raise RuntimeError('index_table should be none')
  163. kernel_name = "conv_backprop_filter_poly"
  164. if idx is not None:
  165. kernel_name += str(idx)
  166. if config is None:
  167. attrs = {'dim': ""}
  168. else:
  169. tile_cici = config.tile_ci
  170. tile_khkh = config.tile_kh
  171. tile_kwkw = config.tile_kw
  172. tile_coco = config.tile_co
  173. tile_bb = config.tile_batch
  174. tile_hh = config.tile_h
  175. tile_ww = config.tile_w
  176. tile_mm = config.tile_m
  177. tile_kk = config.tile_k
  178. tile_nn = config.tile_n
  179. tiling_param = [tile_cici, tile_khkh, tile_kwkw, tile_coco, tile_bb, tile_hh, tile_ww,
  180. tile_mm, tile_kk, tile_nn]
  181. attrs = {'conv_tile': tiling_param}
  182. conv_dtype = 'float16'
  183. block_size = 16
  184. in_n, in_c, in_h, in_w = op_desc.fmap_shape
  185. cout, _, w_h, w_w = op_desc.filter_shape
  186. in_c = (in_c + block_size - 1) // block_size * block_size
  187. cout = (cout + block_size - 1) // block_size * block_size
  188. pad_top, pad_bottom, pad_left, pad_right = op_desc.pad
  189. stride_h, stride_w = op_desc.stride
  190. out_n = in_n
  191. out_c = cout
  192. out_h = (in_h + pad_top + pad_bottom - w_h) // stride_h + 1
  193. out_w = (in_w + pad_left + pad_right - w_w) // stride_w + 1
  194. x_shape = (in_n, in_c, in_h, in_w)
  195. y_shape = (out_n, out_c, out_h, out_w)
  196. in_n, in_c, in_h, in_w = x_shape
  197. input_shape_nc1hwc0 = (in_n, in_c // block_size, in_h, in_w, block_size)
  198. o_n, o_c, o_h, o_w = y_shape
  199. kernel_shape_nc1hwc0 = (o_n, o_c // block_size, o_h, o_w, block_size)
  200. o_n, o_c1, o_h, o_w, o_c0 = kernel_shape_nc1hwc0
  201. mo = (o_h * o_w + block_size - 1) // block_size
  202. mi = block_size
  203. kernel_shape_fractal = (o_n, o_c1, mo, mi, o_c0)
  204. input_shape = [kernel_shape_fractal, input_shape_nc1hwc0]
  205. return utils.op_build(conv_backprop_filter.conv_backprop_filter, [input_shape], [conv_dtype],
  206. op_attrs=[op_desc.fmap_shape, op_desc.filter_shape, op_desc.pad,
  207. op_desc.stride, op_desc.dilation, attrs],
  208. kernel_name=kernel_name, attrs=attrs, polyhedral=True, tuning=gen_tiling_spaces)
  209. def gen_kernel_for_vector(op_desc, _, index_table=None, config: NamedTuple = None, idx=None, gen_tiling_spaces=False):
  210. """Compile kernel module for vector"""
  211. test_base = TestBase()
  212. test_base.params_init(op_desc[0][0:4] + str(idx), os.getcwd())
  213. kernel_name = "poly_"
  214. if idx is not None:
  215. kernel_name += str(idx)
  216. if config is None:
  217. attrs = {'dim': ""}
  218. else:
  219. tiling = [[getattr(config, name), 1] for name in getattr(config, '_fields') if name.startswith('tiling')]
  220. tiling_param = []
  221. for i, element in enumerate(tiling):
  222. tiling_param.append(index_table[i] + element)
  223. dim_info = ct_util.set_dims(tuple(tiling_param))
  224. attrs = {'dim': dim_info}
  225. _, func, args, kwargs = test_base.ana_args(op_desc)
  226. if 'attrs' in kwargs.keys():
  227. kwargs['attrs']['dim'] = attrs['dim']
  228. kwargs['attrs']['tuning'] = gen_tiling_spaces
  229. kwargs['attrs']['kernel_name'] = kernel_name
  230. else:
  231. for _, arg_ in enumerate(args):
  232. if isinstance(arg_, dict):
  233. arg_['dim'] = attrs['dim']
  234. arg_['tuning'] = gen_tiling_spaces
  235. arg_['kernel_name'] = kernel_name
  236. break
  237. try:
  238. if gen_tiling_spaces:
  239. mod, expect, param_for_mod = func(*args, **kwargs)
  240. mod = list(mod)
  241. mod.append(expect)
  242. mod.append(param_for_mod)
  243. else:
  244. mod = func(*args, **kwargs)
  245. except BaseException as e:
  246. print("Compile ERROR message:", e)
  247. print(func)
  248. print("Compile ERROR")
  249. raise Exception("Compile ERROR")
  250. return mod
  251. _compile_kernel_func = {
  252. 'conv': gen_kernel_conv,
  253. 'conv_bn1': gen_kernel_conv_bn1,
  254. 'conv_backprop_input': gen_kernel_conv_backprop_input,
  255. 'conv_backprop_filter': gen_kernel_conv_backprop_filter,
  256. 'matmul': gen_kernel_matmul_cube,
  257. }
  258. def compile_kernel(op_type: str, op_desc: NamedTuple, input_shape=None, index_table=None,
  259. config_param: NamedTuple = None, idx: int = None, gen_tiling_spaces: bool = False):
  260. """Generate kernel module for operator
  261. Parameters
  262. op_type: str
  263. operator name
  264. op_desc: NamedTuple
  265. operator definition parameters
  266. config_param: NameTuple
  267. operator config parameters
  268. idx: int
  269. operator idx(th) kernel
  270. gen_tiling_spaces: bool
  271. parameter passed to utils.op_build, whether to get spaces instead of stmt
  272. ----------
  273. Returns:
  274. kernel if gen_tiling_spaces == False else np.ndarray
  275. """
  276. gen_func = _compile_kernel_func.get(op_type, None)
  277. if gen_func is None:
  278. gen_func = gen_kernel_for_vector
  279. if gen_tiling_spaces:
  280. mod, key, expect, input_for_mod = gen_func(op_desc, input_shape, index_table, config_param,
  281. idx, gen_tiling_spaces)
  282. else:
  283. mod = gen_func(op_desc, input_shape, index_table, config_param, idx, gen_tiling_spaces)
  284. return [build_module.tuning_spaces, key, expect, input_for_mod] if gen_tiling_spaces else mod