You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

custom_tiling.py 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """
  17. custom tiling function
  18. """
  19. from enum import Enum, unique
  20. from functools import wraps
  21. from numpy.core import double
  22. import akg
  23. from akg import dim
  24. from akg.utils.validation_check import check_input_type
  25. set_dim_func_map = {}
  26. gen_key_func_map = {}
  27. NODE_TYPE = "CustomTilingNode"
  28. DEFAULT_VALUE = -1
  29. DEFAULT_STRING = ""
  30. BLOCK_SIZE = 32
  31. CUBE_UNIT = 16
class TileTemplate(Enum):
    """Tensor layout templates understood by create_template.

    Unlike the other enums in this module this one is not decorated with
    @unique: DEFAULT_FORMAT has the same value as NCHW, so Enum turns it into
    an alias (TileTemplate.DEFAULT_FORMAT is TileTemplate.NCHW), and @unique
    would reject the duplicate value.
    """
    NC1HWC0 = "NC1HWC0"
    NCHW = "NCHW"
    # Alias of NCHW. It must stay after NCHW so that NCHW remains the
    # canonical member name for the value "NCHW".
    DEFAULT_FORMAT = "NCHW"
    NHWC = "NHWC"
@unique
class TileLevel(Enum):
    """Tiling level a constraint applies to.

    The member's value is stored as a string in the node's ``tile_level``
    field by create_custom_tiling_node. Every public constraint API in this
    module defaults to C1.
    """
    C1 = "C1"
    C0 = "C0"
@unique
class TileMode(Enum):
    """Scope of a custom tiling node, stored as a string in its ``tile_mode`` field.

    AXIS nodes are built by create_constraint_on_axis. TENSOR nodes are built
    by create_constraint_on_tensor and the template_* helpers. COMMON nodes are
    built by modify_common_constraints.
    """
    AXIS = "AXIS"
    TENSOR = "TENSOR"
    COMMON = "COMMON"
@unique
class TileConstraint(Enum):
    """Kinds of tiling constraint accepted by the builder APIs in this module.

    Each API handles only a subset and raises TypeError for the others:
      * create_constraint_on_axis: MIN, MOD, MAX, FACTOR, CANDIDATE,
        FORBID_ISOLATE, SET_PRIORITY, SET_AXIS_INFO.
      * create_constraint_on_tensor: MIN, MOD, MAX, FACTOR, CANDIDATE,
        FORBID_ISOLATE, SET_PRIORITY, SET_EXPANSION.
      * modify_common_constraints: SET_MEM_RATIO, THREAD_MIN/MAX/MOD,
        BLOCK_MIN/MAX/MOD.
    """
    # Axis / tensor level constraints.
    MIN = "MIN"
    MOD = "MOD"
    MAX = "MAX"
    FACTOR = "FACTOR"
    CANDIDATE = "CANDIDATE"
    FORBID_ISOLATE = "FORBID_ISOLATE"
    SET_PRIORITY = "SET_PRIORITY"
    SET_EXPANSION = "SET_EXPANSION"
    # Common (COMMON-mode) constraints.
    SET_MEM_RATIO = "SET_MEM_RATIO"
    SET_AXIS_INFO = "SET_AXIS_INFO"
    THREAD_MIN = "THREAD_MIN"
    THREAD_MAX = "THREAD_MAX"
    THREAD_MOD = "THREAD_MOD"
    BLOCK_MIN = "BLOCK_MIN"
    BLOCK_MAX = "BLOCK_MAX"
    BLOCK_MOD = "BLOCK_MOD"
  68. @check_input_type((double, float, int, list), TileConstraint, TileLevel)
  69. def modify_common_constraints(value, constraint, level=TileLevel.C1):
  70. """api for dsl to modify some default constraint used in auto tiling."""
  71. if constraint not in TileConstraint:
  72. raise ValueError("Tile constraints must be chosen from {0}".format(TileConstraint))
  73. if constraint == TileConstraint.SET_MEM_RATIO:
  74. return create_custom_tiling_node(TileMode.COMMON, tile_level=level, mem_ratio=double(value))
  75. if constraint == TileConstraint.THREAD_MIN:
  76. return create_custom_tiling_node(TileMode.COMMON, thread_min=value)
  77. if constraint == TileConstraint.THREAD_MAX:
  78. return create_custom_tiling_node(TileMode.COMMON, thread_max=value)
  79. if constraint == TileConstraint.THREAD_MOD:
  80. return create_custom_tiling_node(TileMode.COMMON, thread_mod=value)
  81. if constraint == TileConstraint.BLOCK_MIN:
  82. return create_custom_tiling_node(TileMode.COMMON, block_min=value)
  83. if constraint == TileConstraint.BLOCK_MAX:
  84. return create_custom_tiling_node(TileMode.COMMON, block_max=value)
  85. if constraint == TileConstraint.BLOCK_MOD:
  86. return create_custom_tiling_node(TileMode.COMMON, block_mod=value)
  87. raise TypeError("Constraint {} is not supported in this api, please use other api"
  88. .format(constraint.value))
  89. @check_input_type((str, int), TileConstraint, int, (int, list, tuple, type(None)), TileLevel)
  90. def create_constraint_on_axis(values, constraints, band=0, axis=None, level=TileLevel.C1):
  91. """api for dsl to create tiling constraints on certain axis."""
  92. if constraints not in TileConstraint:
  93. raise ValueError("Tile constraints must be chosen from {0}".format(TileConstraint))
  94. res = []
  95. if axis is None:
  96. axis = [i for i in range(len(values))]
  97. elif not isinstance(axis, (int, list, tuple)):
  98. raise TypeError("Axis should be int, list or tuple")
  99. if isinstance(axis, int):
  100. axis = [axis]
  101. if isinstance(values, (str, int)):
  102. values = [values]
  103. else:
  104. raise TypeError("Tiling factor must be string or int, while receives {}".format(type(values)))
  105. if len(axis) != len(values):
  106. raise ValueError("Length of axis must equal to length of values")
  107. for a, v in zip(axis, values):
  108. if constraints == TileConstraint.MIN:
  109. res.append(create_custom_tiling_node(TileMode.AXIS, tile_level=level,
  110. tile_band=band, tile_axis=a, tile_min=v))
  111. elif constraints == TileConstraint.MOD:
  112. res.append(create_custom_tiling_node(TileMode.AXIS, tile_level=level,
  113. tile_band=band, tile_axis=a, tile_mod=v))
  114. elif constraints == TileConstraint.FACTOR:
  115. res.append(create_custom_tiling_node(TileMode.AXIS, tile_level=level,
  116. tile_band=band, tile_axis=a, tile_factor=v))
  117. elif constraints == TileConstraint.CANDIDATE:
  118. res.append(create_custom_tiling_node(TileMode.AXIS, tile_level=level,
  119. tile_band=band, tile_axis=a, tile_candidate=v))
  120. elif constraints == TileConstraint.MAX:
  121. res.append(create_custom_tiling_node(TileMode.AXIS, tile_level=level,
  122. tile_band=band, tile_axis=a, tile_max=v))
  123. elif constraints == TileConstraint.FORBID_ISOLATE:
  124. res.append(create_custom_tiling_node(TileMode.AXIS, tile_level=level,
  125. tile_band=band, tile_axis=a, forbid_isolate=v))
  126. elif constraints == TileConstraint.SET_AXIS_INFO:
  127. res.append(create_custom_tiling_node(TileMode.AXIS, tile_level=level,
  128. tile_band=band, tile_axis=a, axis_info=v))
  129. elif constraints == TileConstraint.SET_PRIORITY:
  130. res.append(create_custom_tiling_node(TileMode.AXIS, tile_level=level,
  131. tile_band=band, tile_axis=a, priority=v))
  132. else:
  133. raise TypeError("Constraint {} is not supported in this api, please use other api"
  134. .format(constraints.value))
  135. return res
  136. @check_input_type((akg.tvm.tensor.Tensor, list, tuple), (str, int, list, tuple), TileConstraint,
  137. (int, list, tuple, type(None)), TileLevel)
  138. def create_constraint_on_tensor(tensor, values, constraints, tensor_pos=None, level=TileLevel.C1):
  139. """api for dsl to create tiling constraints on certain tensor."""
  140. if constraints not in TileConstraint:
  141. raise ValueError("Tile constraint must be chosen from {0}".format(TileConstraint))
  142. if isinstance(tensor, (list, tuple)):
  143. for t in tensor:
  144. if not isinstance(t, akg.tvm.tensor.Tensor):
  145. raise TypeError("Tensor should be tvm.tensor.Tensor or a list/tuple of tvm.tensor.Tensor.")
  146. tensor_name = [tensor.op.name] if isinstance(tensor, akg.tvm.tensor.Tensor) else [t.op.name for t in tensor]
  147. values = [values] if isinstance(values, (str, int)) else values
  148. if tensor_pos is None:
  149. tensor_pos = [i for i in range(len(values))]
  150. else:
  151. tensor_pos = [tensor_pos] if isinstance(tensor_pos, int) else tensor_pos
  152. if len(tensor_pos) != len(values):
  153. raise ValueError("Length of tensor position is not compatible with length of constraint values")
  154. strategy = list()
  155. for t in tensor_name:
  156. for p, v in zip(tensor_pos, values):
  157. if constraints == TileConstraint.MIN:
  158. strategy.append(create_custom_tiling_node(TileMode.TENSOR, tile_level=level,
  159. tensor_name=t, tile_pos=p, tile_min=v))
  160. elif constraints == TileConstraint.MOD:
  161. strategy.append(create_custom_tiling_node(TileMode.TENSOR, tile_level=level,
  162. tensor_name=t, tile_pos=p, tile_mod=v))
  163. elif constraints == TileConstraint.FACTOR:
  164. strategy.append(create_custom_tiling_node(TileMode.TENSOR, tile_level=level,
  165. tensor_name=t, tile_pos=p, tile_factor=v))
  166. elif constraints == TileConstraint.CANDIDATE:
  167. strategy.append(create_custom_tiling_node(TileMode.TENSOR, tile_level=level,
  168. tensor_name=t, tile_pos=p, tile_candidate=v))
  169. elif constraints == TileConstraint.MAX:
  170. strategy.append(create_custom_tiling_node(TileMode.TENSOR, tile_level=level,
  171. tensor_name=t, tile_pos=p, tile_max=v))
  172. elif constraints == TileConstraint.FORBID_ISOLATE:
  173. strategy.append(create_custom_tiling_node(TileMode.TENSOR, tile_level=level,
  174. tensor_name=t, tile_pos=p, forbid_isolate=v))
  175. elif constraints == TileConstraint.SET_PRIORITY:
  176. strategy.append(create_custom_tiling_node(TileMode.TENSOR, tile_level=level,
  177. tensor_name=t, tile_pos=p, priority=v))
  178. elif constraints == TileConstraint.SET_EXPANSION:
  179. strategy.append(create_custom_tiling_node(TileMode.TENSOR, tile_level=level,
  180. tensor_name=t, expansion=v))
  181. else:
  182. raise TypeError("Constraint {} is not supported in this api, please use other api"
  183. .format(constraints.value))
  184. return strategy
  185. @check_input_type(akg.tvm.tensor.Tensor, TileTemplate, TileLevel)
  186. def create_template(tensor, template, level=TileLevel.C1):
  187. """create template according to given template arg."""
  188. tensor_name = tensor.op.name
  189. if template not in TileTemplate:
  190. raise ValueError("Invalid template name {0}, must chosen from {1}".
  191. format(template, TileTemplate))
  192. if template in [TileTemplate.NCHW, TileTemplate.DEFAULT_FORMAT]:
  193. return template_nchw(tensor_name, level)
  194. if template == TileTemplate.NC1HWC0:
  195. return template_nc1hwc0(tensor_name, level)
  196. if template == TileTemplate.NHWC:
  197. return template_nhwc(tensor_name, level)
  198. return []
  199. def to_tvm_type(value, t_type):
  200. """transform integer and string to corresponding type in tvm."""
  201. if isinstance(value, int):
  202. return akg.tvm.expr.IntImm("int32", value)
  203. if isinstance(value, str):
  204. return akg.tvm.expr.StringImm(value)
  205. if isinstance(value, (akg.tvm.expr.IntImm, akg.tvm.expr.StringImm)):
  206. return value
  207. raise TypeError("{} only support integer or string, found {}".format(t_type, type(value)))
  208. def create_custom_tiling_node(tile_mode,
  209. tile_level=TileLevel.C1,
  210. tensor_name=DEFAULT_STRING,
  211. tile_pos=DEFAULT_VALUE,
  212. tile_band=DEFAULT_VALUE,
  213. tile_axis=DEFAULT_VALUE,
  214. tile_min=DEFAULT_VALUE,
  215. tile_max=DEFAULT_VALUE,
  216. tile_mod=DEFAULT_VALUE,
  217. tile_factor=DEFAULT_VALUE,
  218. tile_candidate=DEFAULT_VALUE,
  219. forbid_isolate=DEFAULT_VALUE,
  220. axis_info=DEFAULT_STRING,
  221. priority=DEFAULT_VALUE,
  222. expansion=DEFAULT_VALUE,
  223. mem_ratio=double(DEFAULT_VALUE),
  224. thread_min=[],
  225. thread_max=[],
  226. thread_mod=[],
  227. block_min=[],
  228. block_max=[],
  229. block_mod=[]):
  230. """default method to create custom tiling node, all values are default except tile mode."""
  231. tile_min = to_tvm_type(tile_min, "tile_min")
  232. tile_max = to_tvm_type(tile_max, "tile_max")
  233. tile_mod = to_tvm_type(tile_mod, "tile_mod")
  234. tile_factor = to_tvm_type(tile_factor, "tile_factor")
  235. tile_candidate = to_tvm_type(tile_candidate, "tile_candidate")
  236. return akg.tvm.make.node(NODE_TYPE,
  237. tile_level=akg.tvm.expr.StringImm(tile_level.value),
  238. tile_mode=akg.tvm.expr.StringImm(tile_mode.value),
  239. tensor_name=akg.tvm.expr.StringImm(tensor_name),
  240. tile_pos=tile_pos,
  241. tile_band=tile_band,
  242. tile_axis=tile_axis,
  243. tile_min=tile_min,
  244. tile_max=tile_max,
  245. tile_mod=tile_mod,
  246. tile_factor=tile_factor,
  247. tile_candidate=tile_candidate,
  248. forbid_isolate=forbid_isolate,
  249. axis_info=akg.tvm.expr.StringImm(axis_info),
  250. priority=priority,
  251. expansion=expansion,
  252. mem_ratio=mem_ratio,
  253. thread_min=thread_min,
  254. thread_max=thread_max,
  255. thread_mod=thread_mod,
  256. block_min=block_min,
  257. block_max=block_max,
  258. block_mod=block_mod)
  259. def template_nc1hwc0(tensor_name, level):
  260. """create default tiling strategy for nc1hwc0 template."""
  261. node_n = create_custom_tiling_node(TileMode.TENSOR,
  262. tile_level=level,
  263. tensor_name=tensor_name,
  264. tile_pos=0,
  265. tile_factor=to_tvm_type(1, "tile_factor"))
  266. node_c0 = create_custom_tiling_node(TileMode.TENSOR,
  267. tile_level=level,
  268. tensor_name=tensor_name,
  269. tile_pos=4,
  270. tile_max="FULL")
  271. return [node_n, node_c0]
  272. def template_nchw(tensor_name, level):
  273. """create default tiling strategy for nchw template."""
  274. node_n = create_custom_tiling_node(TileMode.TENSOR,
  275. tile_level=level,
  276. tensor_name=tensor_name,
  277. tile_pos=0,
  278. tile_factor=to_tvm_type(1, "tile_factor"))
  279. node_c = create_custom_tiling_node(TileMode.TENSOR,
  280. tile_level=level,
  281. tensor_name=tensor_name,
  282. tile_pos=1,
  283. tile_mod=to_tvm_type(CUBE_UNIT, "tile_factor"))
  284. return [node_n, node_c]
  285. def template_nhwc(tensor_name, level):
  286. """create default tiling strategy for nhwc template."""
  287. node_n = create_custom_tiling_node(TileMode.TENSOR,
  288. tile_level=level,
  289. tensor_name=tensor_name,
  290. tile_pos=0,
  291. tile_factor=to_tvm_type(1, "tile_factor"))
  292. node_c = create_custom_tiling_node(TileMode.TENSOR,
  293. tile_level=level,
  294. tensor_name=tensor_name,
  295. tile_pos=3,
  296. tile_mod=to_tvm_type(CUBE_UNIT, "tile_factor"))
  297. return [node_n, node_c]
  298. def set_dims(tiling):
  299. """Set dim for tiling."""
  300. info = dim.Dim()
  301. for d, tile_d in enumerate(tiling):
  302. if len(tile_d) == 2: # only c1 and c0 tile
  303. index = 0
  304. axis = d
  305. c1 = tile_d[0]
  306. c0 = tile_d[1]
  307. elif len(tile_d) == 4: # index, axis, c1, c0
  308. index = tile_d[0]
  309. axis = tile_d[1]
  310. c1 = tile_d[2]
  311. c0 = tile_d[3]
  312. else:
  313. raise RuntimeError("Each element in tiling should be length-2 (c1_tile, c0_tile) "
  314. "or length-4 (band_index, axis_index, c1_tile, c0_tile)")
  315. info.setdim(index=index, axis=axis, tilel1=c1, tilel0=c0)
  316. return str(info)
  317. def set_dims_by_key(key, map_):
  318. """Set dim for tiling by key."""
  319. if key in map_.keys():
  320. return set_dims(map_[key])
  321. return ""
  322. def reg_set_dim_func(set_dim_func):
  323. """register setdim function."""
  324. def decorate(func_):
  325. @wraps(func_)
  326. def wrapper(*args, **kwargs):
  327. set_dim_func_map[func_.__name__] = set_dim_func
  328. return func_(*args, **kwargs)
  329. return wrapper
  330. return decorate
  331. def reg_set_dim_func_by_func(func_, set_dim_func):
  332. """register setdim function by function."""
  333. set_dim_func_map[func_.__name__] = set_dim_func
  334. def reg_gen_key_func(gen_key_func):
  335. """register generated key by function."""
  336. def decorate(func_):
  337. @wraps(func_)
  338. def wrapper(*args, **kwargs):
  339. gen_key_func_map[func_.__name__] = gen_key_func
  340. return func_(*args, **kwargs)
  341. return wrapper
  342. return decorate