You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

four2five.py 13 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """operator dsl function: four2five"""
  17. import akg
  18. import akg.tvm
  19. from akg.tvm.hybrid import script
  20. from akg.topi.nn import pad as tvm_pad
  21. from akg.utils.format_transform import get_shape, get_bytes, to_tvm_const
  22. from akg.utils import validation_check as vc_util
  23. from akg.utils import custom_tiling as ct_util
  24. from akg.utils import dynamic_shape as ds
# Upper bound on the padded channel size (c1 * c0) when the op must also cast
# between float16 and float32; four2five raises ValueError above this limit.
# NOTE(review): presumably an empirical hardware/UB-size limit — confirm.
C_LIMIT_FOR_CAST = 3600
  26. def get_attrs():
  27. """get attrs."""
  28. attrs = {
  29. "help_tiling": 0,
  30. "pragma_sink_last_axis": False,
  31. "enable_pre_poly_loop_partition": False
  32. }
  33. return attrs
  34. def get_dynamic_attrs():
  35. """get dynamic attrs."""
  36. attrs = {
  37. "help_tiling": 0,
  38. "pragma_sink_last_axis": False,
  39. "enable_pre_poly_loop_partition": True,
  40. "dynamic_shape_bound": 65535,
  41. "enable_post_poly_loop_partition": False,
  42. "enable_double_buffer:": False,
  43. # "enable_scalar_align": True,
  44. }
  45. return attrs
# Preset tiling dims, keyed by str((shape, format_, src_dtype, dst_dtype)) as
# built in four2five_set_dim_func (batch normalized to 1 and possibly dropped;
# 'I2'/'I3' denote dynamic dims). Keys are exact-match lookup strings — do not
# reformat them.
four2five_set_dim_map = {
    "((1, 1, 7, 7), 'NCHW', 'float32', 'float16')": ((1, 1), (1, 1), (7, 1), (7, 1), (16, 1)),
    "((1, 7, 7), 'NCHW', 'float32', 'float16')": ((1, 1), (7, 1), (7, 1), (16, 1)),
    "((1, 1, I2, I3), 'NCHW', 'float32', 'float16')": ((1, 1), (1, 1), (1, 1), (129, 1), (2048, 1)),
}
  51. def four2five_set_dim_func(data, format_, dst_type):
  52. """set dim info for attr."""
  53. shape = get_shape(data)
  54. if format_ == 'NCHW':
  55. n, _, h, w = shape
  56. else:
  57. n, h, w, _ = shape
  58. shape[0] = 1
  59. if h != 1 and w != 1:
  60. if format_ == 'NCHW' and shape[1] > 16:
  61. shape[1] = 1
  62. if format_ == 'NHWC' and shape[-1] > 16:
  63. shape[-1] = 1
  64. if n == 1:
  65. shape.remove(shape[0])
  66. hash_key = str((tuple(shape), format_, data.dtype, dst_type))
  67. return ct_util.set_dims_by_key(hash_key, four2five_set_dim_map), hash_key
  68. def four2five_tiling_strategy(tensor, input_format, expansion=None):
  69. """Custom tiling strategy for four2five op."""
  70. strategy = ct_util.create_template(tensor=tensor,
  71. template=ct_util.TileTemplate.NC1HWC0)
  72. if input_format == "NHWC" or expansion:
  73. priority_map = {4: 0, 1: 1, 3: 2, 2: 3, 0: 4} # tile in C0->C1->W->H->N sequence
  74. for pos, priority in priority_map.items():
  75. strategy.append(ct_util.create_constraint_on_tensor(tensor=tensor,
  76. values=priority,
  77. constraints=ct_util.TileConstraint.SET_PRIORITY,
  78. tensor_pos=pos)[0])
  79. if expansion:
  80. strategy.append(ct_util.create_constraint_on_tensor(tensor=tensor,
  81. values=expansion,
  82. constraints=ct_util.TileConstraint.SET_EXPANSION)[0])
  83. return strategy
  84. def four2five_tiling_strategy_dynamic(tensor, input_format):
  85. """Custom tiling strategy for four2five op."""
  86. strategy = list()
  87. if input_format == "NCHW":
  88. shape = get_shape(tensor)
  89. if shape[1] == 1:
  90. strategy.append(ct_util.create_constraint_on_tensor(tensor, 1, ct_util.TileConstraint.FACTOR, 0)[0])
  91. strategy.append(ct_util.create_constraint_on_tensor(tensor, 1, ct_util.TileConstraint.FACTOR, 1)[0])
  92. strategy.append(ct_util.create_constraint_on_tensor(tensor, 1, ct_util.TileConstraint.FACTOR, 2)[0])
  93. strategy.append(ct_util.create_constraint_on_tensor(tensor, 112, ct_util.TileConstraint.FACTOR, 3)[0])
  94. strategy.append(ct_util.create_constraint_on_tensor(tensor, 16, ct_util.TileConstraint.FACTOR, 4)[0])
  95. elif shape[1] == 128:
  96. strategy.append(ct_util.create_constraint_on_tensor(tensor, 1, ct_util.TileConstraint.FACTOR, 0)[0])
  97. strategy.append(ct_util.create_constraint_on_tensor(tensor, 1, ct_util.TileConstraint.FACTOR, 1)[0])
  98. strategy.append(ct_util.create_constraint_on_tensor(tensor, 1, ct_util.TileConstraint.FACTOR, 2)[0])
  99. strategy.append(ct_util.create_constraint_on_tensor(tensor, "FULL", ct_util.TileConstraint.MAX, 3)[0])
  100. strategy.append(ct_util.create_constraint_on_tensor(tensor, 16, ct_util.TileConstraint.FACTOR, 4)[0])
  101. return strategy
@vc_util.check_input_type(akg.tvm.tensor.Tensor, str, str, bool)
def four2five(data, format_, dst_dtype='float16', need_custom_tiling=True):
    """
    Convert 4-dims "data" to 5-dims, the format of "data" is defined in "format_".

    Args:
        data (tvm.tensor.Tensor): 4-dims tensor of type float16, float32
        format_ (str): a str defined the format of "data", 'NCHW' or 'NHWC'
        dst_dtype (str): a str defined the type of output, could be float16 or float32
        need_custom_tiling (bool): attach a custom tiling strategy to the returned attrs

    Returns:
        5-dims tvm.tensor.Tensor, type is defined by dst_dtype,
        which shape is [N, ceil(C / 16), H, W, 16] and attr about tiling args

    Raises:
        ValueError: If the type of format_ is invalid, or if a dtype cast is
            requested while the padded channel size (c1 * c0) reaches
            C_LIMIT_FOR_CAST.
    """
    # Check dtype
    vc_util.ops_dtype_check(data.dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
    # Check shape
    shape = get_shape(data)
    vc_util.davinci_format_check(shape, format_, dim=4)

    # Check format
    if format_ not in ['NCHW', 'NHWC']:
        raise ValueError("{} format is not support, four2five only support NCHW and NHWC format input"
                         .format(format_))

    # c0 of the NC1HWC0 layout is fixed at 16.
    last_channel = 16
    if format_ == "NCHW":
        bs, c, h, w = get_shape(data)
    else:
        bs, h, w, c = get_shape(data)
    # pad_c: c rounded up to a multiple of last_channel.
    pad_c = c
    if c % last_channel != 0:
        pad_c = (c + 15) // last_channel * last_channel
    c1 = pad_c // last_channel
    c0 = last_channel
    is_dynamic = ds.shape_is_dynamic(data)
    if not is_dynamic:
        attrs = get_attrs()
    else:
        attrs = get_dynamic_attrs()
    # Check size c when casting happens
    if data.dtype != dst_dtype and c0 * c1 >= C_LIMIT_FOR_CAST:
        raise ValueError("When input and output data type is not matched, shape of 'c' axis should not exceed {}, "
                         "while currently set is {}".format(C_LIMIT_FOR_CAST, c0 * c1))

    # Hybrid-script kernels below capture last_channel (and, for NHWC, c) from
    # this scope via capture=locals().

    @script(capture=locals())
    def nchw_to_nc1hwc0_step(inputs, bs, c1, h, w, c0):
        # Two-step transpose: NCHW -> NC1HC0W first ...
        output = allocate((bs, c1, h, c0, w), inputs.dtype, "local")
        for n_i in range(bs):
            for c_i in range(c1):
                for h_i in range(h):
                    for w_i in range(w):
                        for c_i0 in range(c0):
                            output[n_i, c_i, h_i, c_i0, w_i] = inputs[n_i, c_i * last_channel + c_i0, h_i, w_i]
        # ... then NC1HC0W -> NC1HWC0 (swap the last two axes).
        output1 = allocate((bs, c1, h, w, c0), inputs.dtype, "local")
        for n_i in range(bs):
            for c_i in range(c1):
                for h_i in range(h):
                    for w_i in range(w):
                        for c_i0 in range(c0):
                            output1[n_i, c_i, h_i, w_i, c_i0] = output[n_i, c_i, h_i, c_i0, w_i]
        return output1

    @script(capture=locals())
    def nchw_to_nc1hwc0(inputs, bs, c1, h, w, c0):
        # Direct NCHW -> NC1HWC0 transpose; assumes c is already a multiple
        # of last_channel (no padding branch here).
        output = allocate((bs, c1, h, w, c0), inputs.dtype, "local")
        for n_i in range(bs):
            for c_i in range(c1):
                for h_i in range(h):
                    for w_i in range(w):
                        for c_i0 in range(c0):
                            output[n_i, c_i, h_i, w_i, c_i0] = inputs[n_i, c_i * last_channel + c_i0, h_i, w_i]
        return output

    @script(capture=locals())
    def nhwc_to_nc1hwc0(inputs, zero, bs, c1, h, w, c0):
        # NHWC -> NC1HWC0 with inline zero-padding of the channel tail
        # (positions >= c are filled with `zero`).
        output = allocate((bs, c1, h, w, c0), inputs.dtype, "local")
        for n_i in range(bs):
            for c_i in range(c1):
                for h_i in range(h):
                    for w_i in range(w):
                        for c_i0 in range(c0):
                            if c_i * last_channel + c_i0 < c:
                                output[n_i, c_i, h_i, w_i, c_i0] = inputs[n_i, h_i, w_i, c_i * last_channel + c_i0]
                            else:
                                output[n_i, c_i, h_i, w_i, c_i0] = zero
        return output

    cast_data = data

    need_cast = data.dtype == 'float32' and dst_dtype == 'float16'
    # Padding/casting widens intermediate buffers; tell the tiler how much.
    if c % last_channel != 0 or need_cast:
        expansion = int(ct_util.BLOCK_SIZE / get_bytes(data.dtype))
    else:
        expansion = None
    # float32 -> float16, need to cast before transform
    if need_cast:
        cast_data = akg.lang.cce.cast_to(data, dst_dtype)

    # Pad value in the (possibly cast) working dtype.
    zero_ = akg.tvm.const(0.0, cast_data.dtype)
    if format_ == "NCHW":
        if c % last_channel != 0:
            pad_shape = [bs, pad_c, h, w]

            if h == 1 and w == 1:
                # if h and w both are 1, it is pad last dim case
                output_shape = [bs, pad_c // last_channel, h, w, last_channel]

                output = akg.tvm.compute(output_shape,
                                         lambda i, c1, k, l, c0: akg.tvm.expr.Select(
                                             c0 < c - c1 * last_channel, cast_data[i, c1 * last_channel + c0, k, l],
                                             akg.tvm.const(0, cast_data.dtype)),
                                         name="output")

            else:
                # if need to pad c dim, separate transpose to two steps
                # first is nchw -> nc1hc0w, second is nc1hc0w -> nc1hwc0
                pad_data = akg.tvm.compute(pad_shape,
                                           lambda i, j, k, l: akg.tvm.expr.Select(j < c, cast_data[i, j, k, l], zero_),
                                           name="pad_data")

                output = nchw_to_nc1hwc0_step(
                    pad_data,
                    to_tvm_const(bs),
                    to_tvm_const(c1),
                    to_tvm_const(h),
                    to_tvm_const(w),
                    to_tvm_const(c0))

        else:
            # NOTE(review): the 3600 literal equals C_LIMIT_FOR_CAST but guards
            # a different quantity (h * w); presumably an empirical size limit
            # for the four2five_nchw intrinsic path — confirm before unifying.
            if not is_dynamic and data.dtype == "float16" and h * w % last_channel == 0 and h * w < 3600:
                output_shape = [bs, c1, h, w, c0]
                output = akg.tvm.compute(output_shape, lambda n, c1, h, w, c0:
                                         akg.lang.cce.four2five_nchw(cast_data[n, c1 * last_channel + c0, h, w]),
                                         name="output")
            else:
                output = nchw_to_nc1hwc0(
                    cast_data,
                    to_tvm_const(bs),
                    to_tvm_const(c1),
                    to_tvm_const(h),
                    to_tvm_const(w),
                    to_tvm_const(c0))

    else:
        if not is_dynamic and c < last_channel:
            rank = 5  # (n, c1, h, w, c0)
            pad_before = []
            pad_after = []
            for _ in range(rank):
                pad_before.append(0)
                pad_after.append(0)
            # Only the trailing c0 axis needs padding up to last_channel.
            pad_after[-1] = last_channel - c
            # As c < last_channel, c1 is 1
            output = akg.tvm.compute((bs, c1, h, w, c), lambda bs_i, _, h_i, w_i, c_i: cast_data[
                bs_i, h_i, w_i, c_i], name="output")
            output = tvm_pad(output, pad_before, pad_after=pad_after, name='pad_output')
        else:
            output = nhwc_to_nc1hwc0(
                cast_data,
                zero_,
                to_tvm_const(bs),
                to_tvm_const(c1),
                to_tvm_const(h),
                to_tvm_const(w),
                to_tvm_const(c0))

    # float16 -> float32, need to cast after transform
    if data.dtype == 'float16' and dst_dtype == 'float32':
        output = akg.lang.cce.cast_to(output, dst_dtype)

    vc_util.davinci_format_check(output.shape, "NC1HWC0", dim=5)

    if not is_dynamic:
        # Static shapes: preset dims (if any) plus the static tiling strategy.
        dim_info, _ = four2five_set_dim_func(data, format_, dst_dtype)
        if dim_info != "":
            attrs["dim"] = dim_info
        if need_custom_tiling:
            attrs["custom_tiling"] = four2five_tiling_strategy(output, format_, expansion)
    elif need_custom_tiling:
        attrs["custom_tiling"] = four2five_tiling_strategy_dynamic(output, format_)

    if is_dynamic:
        attrs["enable_feature_library_pre_poly"] = True
    return output, attrs