You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

space_generators.py 27 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """space generating functions for operators"""
from collections import namedtuple
from functools import partial
from itertools import product
from typing import NamedTuple

from tests.common.test_run import matmul_run
from akg.utils import validation_check as vc_util
from akg.auto_tune.type_definitions import ConvDesc, ConvBackpropDesc, MatmulCubeDesc, ConvConfig, ConvBackpropInputConfig, ConvBackpropFilterConfig, MatmulCubeConfig
from akg.auto_tune.space import ListConfigSpace
from akg.auto_tune.kernel_compiler import compile_kernel
  23. def _get_space_vector(op_type: str, op_desc):
  24. """get config space of vector operator"""
  25. space_res, key, expect, input_for_mod = compile_kernel(op_type, op_desc, None, None, None, 0,
  26. gen_tiling_spaces=True)
  27. if space_res is None:
  28. raise RuntimeError('no space returned')
  29. if 'index' not in space_res or 'tuning_space' not in space_res:
  30. raise RuntimeError('invalid space returned')
  31. index_table = space_res['index']
  32. tiling_spaces = space_res['tuning_space']
  33. if not tiling_spaces:
  34. raise RuntimeError('empty tiling spaces')
  35. dim_names = ['tiling_' + str(i) for i in range(len(tiling_spaces[0]))]
  36. input_type = namedtuple(op_type, dim_names)
  37. space = ListConfigSpace(input_type)
  38. for tiling_space in tiling_spaces:
  39. config = input_type(*tiling_space)
  40. space.add(config)
  41. return index_table, space, key, expect, input_for_mod
def _get_space_conv(op_desc: ConvDesc, tunning_attrs):
    """get config space of convolution.

    Exhaustively enumerates (tile_h, tile_co, tile_m, tile_k, tile_n, tile_w,
    bypass) candidates that fit the per-buffer size budgets and adds each one
    to a ListConfigSpace(ConvConfig).

    Args:
        op_desc (ConvDesc): convolution description (fmap_shape, filter_shape,
            pad, stride, dilation).
        tunning_attrs: accepted for a uniform generator signature but unused
            here (note the historical "tunning" misspelling of the parameter).

    Returns:
        tuple: (None, config_space, str(op_desc), None, None); only the config
        space and its string key are produced for this op type.

    Raises:
        TypeError: if op_desc is not a ConvDesc.
    """
    if not isinstance(op_desc, ConvDesc):
        raise TypeError('op_desc must be ConvDesc')
    stride_ = op_desc.stride
    pad_ = op_desc.pad
    dilation_ = op_desc.dilation
    vc_util.convolution_format_check(op_desc.fmap_shape, op_desc.filter_shape, pad_, stride_, dilation_)
    config_space = ListConfigSpace(ConvConfig)
    # if double buff is not enabled, set it's value to 1
    size_scale = 1
    # Per-buffer byte budgets used by the pruning checks below; the (256 - 8)
    # and trailing // 2 presumably mirror hardware reservations — TODO confirm.
    l1_max_size = (1024 * 1024) // size_scale
    l0a_max_size = (64 * 1024) // size_scale
    l0b_max_size = (64 * 1024) // size_scale
    l0c_max_size = ((256 - 8) * 1024) // size_scale // 2
    _, in_c, in_h, in_w = op_desc.fmap_shape
    k_n, _, k_h, k_w = op_desc.filter_shape
    padding = (pad_[0], pad_[1], pad_[2], pad_[3])
    p_top, p_bottom, p_left, p_right = padding
    s_h, s_w = stride_
    # Round the input channel count up to a multiple of 16.
    in_c = ((in_c - 1) // 16 + 1) * 16
    tile_c = in_c
    tile_co_start = 16
    data_len = 2  # bytes per element; 2 suggests fp16 data — TODO confirm
    # Padded extent, number of output windows, and the largest extent covered
    # by whole windows, for H and then W.
    h_max = in_h + p_top + p_bottom
    win_h = (h_max - k_h) // s_h + 1
    h_max = (h_max - k_h) // s_h * s_h + k_h
    w_max = in_w + p_left + p_right
    win_w = (w_max - k_w) // s_w + 1
    w_max = (w_max - k_w) // s_w * s_w + k_w
    bypass_options = [0, 1]
    for bypass in bypass_options:
        # Walk H tile sizes downward one stride at a time.
        for tile_h in range(h_max, k_h - 1, -s_h):
            size_h = tile_h
            if tile_h == h_max:
                # Full-height tile: W may be tiled freely.
                w_range = range(w_max, k_w - 1, -s_w)
                size_h = in_h
            else:
                # Partial-height tile: keep W untiled.
                w_range = [w_max]
                win_tile_h = (tile_h - k_h) // s_h + 1
                h_tiles = (win_h + win_tile_h - 1) // win_tile_h
                if h_tiles == 2:
                    # Exactly two H tiles: only the larger of head/tail slices
                    # has to be resident at once.
                    size_h = max(tile_h - p_top, in_h + p_top - tile_h + k_h - s_h)
            for tile_w in w_range:
                size_w = tile_w
                if size_w == w_max:
                    size_w = in_w
                else:
                    win_tile_w = (tile_w - k_w) // s_w + 1
                    w_tiles = (win_w + win_tile_w - 1) // win_tile_w
                    if w_tiles == 2:
                        size_w = max(tile_w - p_left, in_w + p_left - tile_w + k_w - s_w)
                # Output-channel tiles, stepped down in blocks of 16.
                k_n_ = ((k_n - 1) // 16 + 1) * 16
                co_range = range(k_n_, tile_co_start - 1, -16)
                for tile_co in co_range:
                    if bypass == 1:
                        # bypass == 1: the filter is not staged through L1, so
                        # the output-channel tile must cover all of k_n.
                        if tile_co != k_n:
                            continue
                        l1_size = data_len * (size_h * size_w * in_c)
                    else:
                        l1_size = data_len * (size_h * size_w * in_c +
                                              tile_co * tile_c * k_h * k_w)
                    if l1_size > l1_max_size:
                        continue
                    tile_co_ = ((tile_co - 1) // 16 + 1) * 16
                    for tile_n in range(tile_co_, 15, -16):
                        # Reduction-axis tile: bounded by both its true extent
                        # and what fits beside tile_n in L0B.
                        k_max = in_c * k_h * k_w
                        k_max_ = ((k_max - 1) // 16 + 1) * 16
                        k_size = l0b_max_size // data_len // tile_n
                        k_size_ = k_size // 16 * 16
                        for tile_k in range(min(k_max_, k_size_), 15, -16):
                            # M axis: output positions produced by this tile,
                            # bounded by L0A (vs tile_k) and L0C (vs tile_n).
                            m_max = (int(((tile_h - k_h) // (s_h)) + 1)) * (int(((tile_w - k_w) // (s_w)) + 1))
                            m_max_ = ((m_max - 1) // 16 + 1) * 16
                            m_size1 = l0a_max_size // data_len // tile_k
                            m_size1_ = m_size1 // 16 * 16
                            m_size2 = l0c_max_size // data_len // tile_n
                            m_size2_ = m_size2 // 16 * 16
                            for tile_m in range(min(m_max_, m_size1_, m_size2_), 15, -16):
                                config_space.add(ConvConfig(tile_h, tile_co, tile_m, tile_k,
                                                            tile_n, tile_w, bypass))
    return None, config_space, op_desc.__str__(), None, None
def _get_space_conv_bn1(op_desc: ConvDesc, tunning_attrs):
    """get config space of convolution fused with the first half of batchnorm.

    Identical search to _get_space_conv except for a smaller L0C budget (an
    extra // 4 below), presumably to leave room for the fused bn1 outputs —
    TODO confirm against the conv_bn1 kernel.

    Args:
        op_desc (ConvDesc): convolution description.
        tunning_attrs: accepted for a uniform generator signature but unused.

    Returns:
        tuple: (None, config_space, str(op_desc), None, None).

    Raises:
        TypeError: if op_desc is not a ConvDesc.
    """
    if not isinstance(op_desc, ConvDesc):
        raise TypeError('op_desc must be ConvDesc')
    stride_ = op_desc.stride
    pad_ = op_desc.pad
    dilation_ = op_desc.dilation
    vc_util.convolution_format_check(op_desc.fmap_shape, op_desc.filter_shape, pad_, stride_, dilation_)
    config_space = ListConfigSpace(ConvConfig)
    # if double buff is not enabled, set it's value to 1
    size_scale = 1
    l1_max_size = (1024 * 1024) // size_scale
    l0a_max_size = (64 * 1024) // size_scale
    l0b_max_size = (64 * 1024) // size_scale
    # Note the extra // 4 relative to _get_space_conv.
    l0c_max_size = ((256 - 8) * 1024) // size_scale // 2 // 4
    _, in_c, in_h, in_w = op_desc.fmap_shape
    k_n, _, k_h, k_w = op_desc.filter_shape
    padding = (pad_[0], pad_[1], pad_[2], pad_[3])
    p_top, p_bottom, p_left, p_right = padding
    s_h, s_w = stride_
    # Round the input channel count up to a multiple of 16.
    in_c = ((in_c - 1) // 16 + 1) * 16
    tile_c = in_c
    tile_co_start = 16
    data_len = 2  # bytes per element; 2 suggests fp16 data — TODO confirm
    h_max = in_h + p_top + p_bottom
    win_h = (h_max - k_h) // s_h + 1
    h_max = (h_max - k_h) // s_h * s_h + k_h
    w_max = in_w + p_left + p_right
    win_w = (w_max - k_w) // s_w + 1
    w_max = (w_max - k_w) // s_w * s_w + k_w
    bypass_options = [0, 1]
    for bypass in bypass_options:
        h_range = range(h_max, k_h - 1, -s_h)
        for tile_h in h_range:
            size_h = tile_h
            if tile_h == h_max:
                # Full-height tile: W may be tiled freely.
                w_range = range(w_max, k_w - 1, -s_w)
                size_h = in_h
            else:
                # Partial-height tile: keep W untiled.
                w_range = [w_max]
                win_tile_h = (tile_h - k_h) // s_h + 1
                h_tiles = (win_h + win_tile_h - 1) // win_tile_h
                if h_tiles == 2:
                    size_h = max(tile_h - p_top, in_h + p_top - tile_h + k_h - s_h)
            for tile_w in w_range:
                size_w = tile_w
                if size_w == w_max:
                    size_w = in_w
                else:
                    win_tile_w = (tile_w - k_w) // s_w + 1
                    w_tiles = (win_w + win_tile_w - 1) // win_tile_w
                    if w_tiles == 2:
                        size_w = max(tile_w - p_left, in_w + p_left - tile_w + k_w - s_w)
                k_n_ = ((k_n - 1) // 16 + 1) * 16
                co_range = range(k_n_, tile_co_start - 1, -16)
                for tile_co in co_range:
                    if bypass == 1:
                        # bypass == 1 requires the whole output-channel range.
                        if tile_co != k_n:
                            continue
                        l1_size = data_len * (size_h * size_w * in_c)
                    else:
                        l1_size = data_len * (size_h * size_w * in_c +
                                              tile_co * tile_c * k_h * k_w)
                    if l1_size > l1_max_size:
                        continue
                    tile_co_ = ((tile_co - 1) // 16 + 1) * 16
                    for tile_n in range(tile_co_, 15, -16):
                        k_max = in_c * k_h * k_w
                        k_max_ = ((k_max - 1) // 16 + 1) * 16
                        k_size = l0b_max_size // data_len // tile_n
                        k_size_ = k_size // 16 * 16
                        for tile_k in range(min(k_max_, k_size_), 15, -16):
                            m_max = (int(((tile_h - k_h) // (s_h)) + 1)) * (int(((tile_w - k_w) // (s_w)) + 1))
                            m_max_ = ((m_max - 1) // 16 + 1) * 16
                            m_size1 = l0a_max_size // data_len // tile_k
                            m_size1_ = m_size1 // 16 * 16
                            m_size2 = l0c_max_size // data_len // tile_n
                            m_size2_ = m_size2 // 16 * 16
                            for tile_m in range(min(m_max_, m_size1_, m_size2_), 15, -16):
                                config_space.add(ConvConfig(tile_h, tile_co, tile_m, tile_k,
                                                            tile_n, tile_w, bypass))
    return None, config_space, op_desc.__str__(), None, None
  205. def _get_space_conv_backprop_input(op_desc: ConvBackpropDesc, tunning_attrs):
  206. """get config space of convolution backprop input"""
  207. if not isinstance(op_desc, ConvBackpropDesc):
  208. raise TypeError('op_desc must be ConvDesc')
  209. stride_ = op_desc.stride
  210. pad_ = op_desc.pad
  211. dilation_ = op_desc.dilation
  212. vc_util.convolution_format_check(op_desc.fmap_shape, op_desc.filter_shape, pad_, stride_, dilation_)
  213. config_space = ListConfigSpace(ConvBackpropInputConfig)
  214. # if double buff is not enabled, set it's value to 1
  215. size_scale = 1
  216. block_size = 16
  217. l1_max_size = (1024 * 1024) // size_scale
  218. l0a_max_size = (64 * 1024) // size_scale
  219. l0b_max_size = (64 * 1024) // size_scale
  220. l0c_max_size = ((256 - 8) * 1024) // size_scale // 2
  221. ub_max_size = l0c_max_size
  222. _, in_c, in_h, in_w = op_desc.fmap_shape
  223. k_n, _, k_h, k_w = op_desc.filter_shape
  224. in_c = (in_c + block_size - 1) // block_size * block_size
  225. k_n = (k_n + block_size - 1) // block_size * block_size
  226. pad_top, pad_bottom, pad_left, pad_right = pad_
  227. stride_h, stride_w = stride_
  228. out_c = k_n
  229. out_h = (in_h + pad_top + pad_bottom - k_h) // stride_h + 1
  230. out_w = (in_w + pad_left + pad_right - k_w) // stride_w + 1
  231. out_h = out_h * stride_h
  232. out_w = out_w * stride_w
  233. p_top = k_h - pad_[0] - 1
  234. p_bottom = in_h + pad_[0] - stride_[0] * ((in_h + pad_[0] + pad_[1] - k_h) // stride_[0] + 1)
  235. p_left = k_w - pad_[2] - 1
  236. p_right = in_w + pad_[2] - stride_[1] * ((in_w + pad_[2] + pad_[3] - k_w) // stride_[1] + 1)
  237. s_h = 1
  238. s_w = 1
  239. tile_c = out_c
  240. tile_co_start = 16
  241. data_len = 2
  242. h_max = out_h + p_top + p_bottom
  243. win_h = (h_max - k_h) // s_h + 1
  244. h_max = (h_max - k_h) // s_h * s_h + k_h
  245. w_max = out_w + p_left + p_right
  246. win_w = (w_max - k_w) // s_w + 1
  247. w_max = (w_max - k_w) // s_w * s_w + k_w
  248. for tile_h in range(h_max, k_h - 1, -s_h):
  249. size_h = tile_h
  250. if tile_h == h_max:
  251. w_range = range(w_max, k_w - 1, -s_w)
  252. size_h = in_h
  253. else:
  254. w_range = [w_max]
  255. win_tile_h = (tile_h - k_h) // s_h + 1
  256. h_tiles = (win_h + win_tile_h - 1) // win_tile_h
  257. if h_tiles == 2:
  258. size_h = max(tile_h - p_top, in_h + p_top - tile_h + k_h - s_h)
  259. for tile_w in w_range:
  260. size_w = tile_w
  261. if size_w == w_max:
  262. size_w = in_w
  263. else:
  264. win_tile_w = (tile_w - k_w) // s_w + 1
  265. w_tiles = (win_w + win_tile_w - 1) // win_tile_w
  266. if w_tiles == 2:
  267. size_w = max(tile_w - p_left, in_w + p_left - tile_w + k_w - s_w)
  268. k_n_ = ((k_n - 1) // 16 + 1) * 16
  269. co_range = range(k_n_, tile_co_start - 1, -16)
  270. for tile_co in co_range:
  271. l1_size = data_len * (size_h * size_w * out_c +
  272. tile_co * tile_c * k_h * k_w)
  273. if l1_size > l1_max_size:
  274. continue
  275. ub_size = data_len * (size_h * size_w * out_c)
  276. if ub_size > ub_max_size:
  277. continue
  278. tile_co_ = ((tile_co - 1) // 16 + 1) * 16
  279. for tile_n in range(tile_co_, 15, -16):
  280. k_max = out_c * k_h * k_w
  281. k_base = 16 * k_h * k_w
  282. k_max_ = ((k_max - 1) // k_base + 1) * k_base
  283. k_size = l0b_max_size // data_len // tile_n
  284. k_size_ = k_size // k_base * k_base
  285. for tile_k in range(min(k_max_, k_size_), k_base - 1, -k_base):
  286. m_max = (int(((tile_h - k_h) // (s_h)) + 1)) * (int(((tile_w - k_w) // (s_w)) + 1))
  287. m_max_ = ((m_max - 1) // 16 + 1) * 16
  288. m_size1 = l0a_max_size // data_len // tile_k
  289. m_size1_ = m_size1 // 16 * 16
  290. m_size2 = l0c_max_size // data_len // tile_n
  291. m_size2_ = m_size2 // 16 * 16
  292. for tile_m in range(min(m_max_, m_size1_, m_size2_), 15, -16):
  293. config_space.add(ConvBackpropInputConfig(tile_h, tile_co, tile_m,
  294. tile_k, tile_n, tile_w))
  295. return None, config_space, op_desc.__str__(), None, None
def _get_space_conv_backprop_filter(op_desc: ConvBackpropDesc, tunning_attrs):
    """get config space of convolution backprop filter.

    Enumerates (tile_ci, tile_kh, tile_kw, tile_co, tile_batch, tile_h,
    tile_w, tile_m, tile_k, tile_n) candidates for the dW computation that
    satisfy the buffer budgets, adding each to a
    ListConfigSpace(ConvBackpropFilterConfig).

    Args:
        op_desc (ConvBackpropDesc): backprop description.
        tunning_attrs: accepted for a uniform generator signature but unused.

    Returns:
        tuple: (None, config_space, str(op_desc), None, None).

    Raises:
        TypeError: if op_desc is not a ConvBackpropDesc.
    """
    if not isinstance(op_desc, ConvBackpropDesc):
        raise TypeError('op_desc must be ConvBackpropDesc')
    stride_ = op_desc.stride
    pad_ = op_desc.pad
    dilation_ = op_desc.dilation
    vc_util.convolution_format_check(op_desc.fmap_shape, op_desc.filter_shape, pad_, stride_, dilation_)
    config_space = ListConfigSpace(ConvBackpropFilterConfig)
    # if double buff is not enabled, set it's value to 1
    size_scale = 1
    block_size = 16
    # Per-buffer byte budgets used by the pruning checks below.
    l1_max_size = (1024 * 1024) // size_scale
    l0a_max_size = (64 * 1024) // size_scale
    l0b_max_size = (64 * 1024) // size_scale
    l0c_max_size = ((256 - 8) * 1024) // size_scale // 2
    in_n, in_c, in_h, in_w = op_desc.fmap_shape
    cout, _, k_h, k_w = op_desc.filter_shape
    k_n = cout
    # Round channel counts up to whole 16-element blocks.
    in_c = (in_c + block_size - 1) // block_size * block_size
    cout = (cout + block_size - 1) // block_size * block_size
    pad_top, pad_bottom, pad_left, pad_right = pad_
    s_h, s_w = stride_
    tile_co_start = 16
    tile_ci_start = 16
    data_len = 2  # bytes per element; 2 suggests fp16 data — TODO confirm
    h_max = in_h + pad_top + pad_bottom
    win_h = (h_max - k_h) // s_h + 1
    h_max = (h_max - k_h) // s_h * s_h + k_h
    w_max = in_w + pad_left + pad_right
    win_w = (w_max - k_w) // s_w + 1
    w_max = (w_max - k_w) // s_w * s_w + k_w
    for tile_h in range(h_max, k_h - 1, -s_h):
        size_h = tile_h
        win_tile_h = (tile_h - k_h) // s_h + 1
        # Only one head for cut H axis
        if win_tile_h * s_h < pad_top:
            continue
        # Only one tail for cut H axis
        if (((win_h + win_tile_h - 1) // win_tile_h - 1) * win_tile_h - 1) * s_h + k_h > in_h + pad_top:
            continue
        if tile_h == h_max:
            # Full-height tile: W may be tiled freely.
            w_range = range(w_max, k_w - 1, -s_w)
            size_h = in_h
        else:
            w_range = [w_max]
        h_tiles = (win_h + win_tile_h - 1) // win_tile_h
        if h_tiles == 2:
            # Exactly two H tiles: only the larger slice must be resident.
            size_h = max(tile_h - pad_top, in_h + pad_top - tile_h + k_h - s_h)
        for tile_w in w_range:
            size_w = tile_w
            win_tile_w = (tile_w - k_w) // s_w + 1
            # Only one head for cut W axis
            if win_tile_w * s_w < pad_left:
                continue
            # Only one tail for cut W axis
            if (((win_w + win_tile_w - 1) // win_tile_w - 1) * win_tile_w - 1) * s_w + k_w > in_w + pad_left:
                continue
            if size_w == w_max:
                size_w = in_w
            else:
                w_tiles = (win_w + win_tile_w - 1) // win_tile_w
                if w_tiles == 2:
                    size_w = max(tile_w - pad_left, in_w + pad_left - tile_w + k_w - s_w)
            # Filter-window and channel tilings.
            for tile_kh in range(k_h, 0, -1):
                for tile_kw in range(k_w, 0, -1):
                    k_n_ = ((k_n - 1) // 16 + 1) * 16
                    co_range = range(k_n_, tile_co_start - 1, -16)
                    for tile_co in co_range:
                        in_c_ = ((in_c - 1) // 16 + 1) * 16
                        ci_range = range(in_c_, tile_ci_start - 1, -16)
                        for tile_ci in ci_range:
                            tile_batch = 1  # batch is never tiled beyond 1 here
                            l1_size = data_len * tile_batch * (tile_co * win_tile_h * win_tile_w +
                                                               tile_ci * size_h * size_w)
                            if l1_size > l1_max_size:
                                continue
                            if (tile_batch != in_n or tile_co != k_n_ or tile_ci != in_c_):
                                # Partial coverage of batch/co/ci: M and N are
                                # fixed by the tile, only K is searched.
                                tile_m = tile_co
                                tile_n = tile_ci * tile_kh * tile_kw
                                l0c_size = data_len * tile_n * tile_m
                                if l0c_size > l0c_max_size:
                                    continue
                                k_max = tile_batch * tile_h * tile_w
                                k_max_ = ((k_max - 1) // 16 + 1) * 16
                                k_size1 = l0a_max_size // data_len // tile_m
                                k_size1_ = k_size1 // 16 * 16
                                k_size2 = l0b_max_size // data_len // tile_n
                                k_size2_ = k_size2 // 16 * 16
                                for tile_k in range(min(k_max_, k_size1_, k_size2_), 15, -16):
                                    config_space.add(ConvBackpropFilterConfig(tile_ci, tile_kh, tile_kw, tile_co,
                                                                              tile_batch, tile_h, tile_w, tile_m,
                                                                              tile_k, tile_n))
                            else:
                                # Full coverage: N, K and M are all searched.
                                for tile_n in range(tile_ci * tile_kh * tile_kw, 15, -16):
                                    k_max = tile_batch * tile_h * tile_w
                                    k_max_ = ((k_max - 1) // 16 + 1) * 16
                                    k_size = l0b_max_size // data_len // tile_n
                                    k_size_ = k_size // 16 * 16
                                    for tile_k in range(min(k_max_, k_size_), 15, -16):
                                        m_max = tile_co
                                        m_max_ = ((m_max - 1) // 16 + 1) * 16
                                        m_size1 = l0a_max_size // data_len // tile_k
                                        m_size1_ = m_size1 // 16 * 16
                                        m_size2 = l0c_max_size // data_len // tile_n
                                        m_size2_ = m_size2 // 16 * 16
                                        for tile_m in range(min(m_max_, m_size1_, m_size2_), 15, -16):
                                            config_space.add(ConvBackpropFilterConfig(tile_ci, tile_kh, tile_kw,
                                                                                      tile_co, tile_batch, tile_h,
                                                                                      tile_w, tile_m, tile_k, tile_n))
    return None, config_space, op_desc.__str__(), None, None
  407. def gen_bool_list(attr_list):
  408. bool_list = []
  409. for _ in attr_list:
  410. if len(bool_list) == 0:
  411. bool_list = [[True], [False]]
  412. else:
  413. tmp_list = []
  414. for attr_option in bool_list:
  415. tmp = attr_option[:]
  416. tmp.append(True)
  417. tmp1 = tmp[:]
  418. tmp.pop()
  419. tmp.append(False)
  420. tmp2 = tmp[:]
  421. tmp_list.append(tmp1)
  422. tmp_list.append(tmp2)
  423. bool_list = tmp_list
  424. return bool_list
def _get_space_matmul_cube(op_desc: MatmulCubeDesc, tuning_attrs):
    """get config space of matmul_cube.

    Enumerates (n_l1, n_l0, m_l1, m_l0, k_l1, k_l0, bypass) tilings (in units
    of 16-element blocks) that divide the padded dimensions evenly and fit the
    buffer budgets; each tiling is optionally crossed with every True/False
    assignment of ``tuning_attrs``.

    Args:
        op_desc (MatmulCubeDesc): matmul description (x_shape, y_shape,
            adj_x/adj_y transpose flags, formats, dtypes, bias).
        tuning_attrs: extra boolean attribute names appended to each config.

    Returns:
        tuple: (None, config_space, key, None, None) where key is the string
        of the converted shapes plus formats/flags/dtypes.

    Raises:
        TypeError: if op_desc is not a MatmulCubeDesc.
    """
    if not isinstance(op_desc, MatmulCubeDesc):
        raise TypeError('op_desc must be MatmulCubeDesc')
    config_attrs = ['n_l1', 'n_l0', 'm_l1', 'm_l0', 'k_l1', 'k_l0', 'bypass']
    config_attrs.extend(tuning_attrs)
    # NOTE: this local namedtuple deliberately shadows the MatmulCubeConfig
    # imported at module level, so extra tuning_attrs fields can be appended.
    MatmulCubeConfig = namedtuple('MatmulCubeConfig', config_attrs)
    config_space = ListConfigSpace(MatmulCubeConfig)
    batch_tuple, m, k, n = matmul_run.extract_dim(op_desc.x_shape, op_desc.y_shape, op_desc.adj_x, op_desc.adj_y)
    # Dimension extents in whole 16-element blocks (rounded up).
    mmax = (m + 15) // 16
    nmax = (n + 15) // 16
    kmax = (k + 15) // 16
    double_buffer = True
    mad_fp32 = True
    l1_max_size = (1024 * 1024)  # L1 MEM 1024KB
    l0a_max_size = (64 * 1024)  # L0A MEM 64KB
    l0b_max_size = (64 * 1024)  # L0B MEM 64KB
    l0c_max_size = (256 * 1024)  # L0C MEM 256KB
    ub_max_size = ((256 - 8) * 1024)  # UB MEM 248KB, 8KB reserved for compiler
    if double_buffer:
        # Double buffering halves every usable budget.
        l1_max_size = l1_max_size // 2
        l0a_max_size = l0a_max_size // 2
        l0b_max_size = l0b_max_size // 2
        l0c_max_size = l0c_max_size // 2
        ub_max_size = ub_max_size // 2
    if mad_fp32:
        # fp32 accumulation doubles the L0C footprint per element.
        l0c_max_size = l0c_max_size // 2
        if op_desc.out_dtype == 'float32':
            ub_max_size = ub_max_size // 2
    # bypass: 0 = none, 1 = right operand bypasses L1, 2 = left operand does;
    # skip bypass modes incompatible with the operand's transpose/format combo.
    bypass_options = [0, 1, 2]
    for bypass in bypass_options:
        if (bypass == 2) and ((op_desc.adj_x == False and op_desc.left_format[0].lower() == 'n') or
                              (op_desc.adj_x == True and op_desc.left_format[0].lower() == 'z')):
            continue
        if (bypass == 1) and ((op_desc.adj_y == False and op_desc.right_format[0].lower() == 'z') or
                              (op_desc.adj_y == True and op_desc.right_format[0].lower() == 'n')):
            continue
        # Only exact divisors of each padded extent are considered.
        for k_l1 in range(1, kmax + 1):
            if kmax % k_l1 != 0:
                continue
            for k_l0 in range(1, k_l1 + 1):
                if k_l1 % k_l0 != 0:
                    continue
                # no need to cut from l1 to l0 for m and n when k is cut
                for m_l1 in range(1, mmax + 1):
                    if mmax % m_l1 != 0:
                        continue
                    m_l0_range = [m_l1] if k_l1 != kmax else range(1, m_l1 + 1)
                    for m_l0 in m_l0_range:
                        if m_l1 % m_l0 != 0:
                            continue
                        for n_l1 in range(1, nmax + 1):
                            if nmax % n_l1 != 0:
                                continue
                            n_l0_range = [n_l1] if k_l1 != kmax else range(1, n_l1 + 1)
                            for n_l0 in n_l0_range:
                                if n_l1 % n_l0 != 0:
                                    continue
                                # Buffer-fit pruning (tile sizes are blocks of
                                # 16x16 elements, hence the 16 * 16 factors).
                                if m_l0 * 16 * k_l0 * 16 > l0a_max_size:
                                    continue
                                if n_l0 * 16 * k_l0 * 16 > l0b_max_size:
                                    continue
                                if m_l0 * 16 * n_l0 * 16 > l0c_max_size:
                                    continue
                                if m_l0 * 16 * n_l0 * 16 > ub_max_size:
                                    continue
                                if bypass == 2:
                                    l1_size = n_l1 * 16 * k_l1 * 16
                                elif bypass == 1:
                                    l1_size = m_l1 * 16 * k_l1 * 16
                                else:
                                    l1_size = (m_l1 * 16 + n_l1 * 16) * k_l1 * 16
                                if l1_size > l1_max_size:
                                    continue
                                # Degenerate axes are encoded with sentinel
                                # values (0 for a single m/n block, 16 for a
                                # single k block). NOTE(review): this rebinds
                                # the loop variables themselves, which also
                                # affects the divisibility tests of later
                                # iterations — verify this is intentional.
                                if nmax == 1:
                                    n_l1 = 0
                                    n_l0 = 0
                                if mmax == 1:
                                    m_l1 = 0
                                    m_l0 = 0
                                if kmax == 1:
                                    k_l1 = 16
                                    k_l0 = 16
                                tiling_space = [n_l1, n_l0, m_l1, m_l0, k_l1, k_l0, bypass]
                                if len(tuning_attrs) == 0:
                                    config_space.add(MatmulCubeConfig(*tiling_space))
                                else:
                                    # Cross the tiling with every boolean
                                    # assignment of the extra attributes.
                                    attr_options = gen_bool_list(tuning_attrs)
                                    for attr_option in attr_options:
                                        tmp = tiling_space[:]
                                        tmp.extend(attr_option)
                                        config = MatmulCubeConfig(*tmp)
                                        config_space.add(config)
    shape_xx, shape_yy, _, _, k = matmul_run.get_converted_shapes(m, n, k, batch_tuple, op_desc.adj_x, op_desc.adj_y,
                                                                  op_desc.bias, op_desc.left_format,
                                                                  op_desc.right_format, op_desc.out_format)
    return None, config_space, str((shape_xx, shape_yy, op_desc.bias, op_desc.left_format, op_desc.right_format,
                                    op_desc.out_format, op_desc.adj_x, op_desc.adj_y, op_desc.dtype,
                                    op_desc.out_dtype)), None, None
# Dispatch table: operator type name -> specialized space generator.
# Types not listed here fall back to the generic vector path in get_space.
_get_space_func = {
    'conv': _get_space_conv,
    'conv_bn1': _get_space_conv_bn1,
    'conv_backprop_input': _get_space_conv_backprop_input,
    'conv_backprop_filter': _get_space_conv_backprop_filter,
    'matmul': _get_space_matmul_cube,
}
  531. def get_space(op_type: str, op_desc: NamedTuple, tuning_attrs=[]):
  532. """get space of an operator"""
  533. func = _get_space_func.get(op_type, None)
  534. if func is None:
  535. func = partial(_get_space_vector, op_type=op_type)
  536. return func(op_desc=op_desc, tuning_attrs=tuning_attrs)