You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common.py 8.9 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """common"""
  17. import akg.tvm
  18. from .elewise_compute import vmuls, vadds, vmax, vmin, vabs, vrec, vmul, set_is_need_save_dtype
  19. from .cast_compute import floor, round, cast
  def fargmax(x, y):
      """
      Create the "fargmax" intrinsic call for the index of the maximum of x and y.

      The resulting call expression takes its dtype from ``x``. As the example
      shows, it is intended to be used as the combiner of a commutative reducer.

      Args:
          x (tvm.expr.Expr): Input expression.
          y (tvm.expr.Expr): Input expression.

      Returns:
          tvm.expr.Expr. The call expression.

      Examples:
          >>> n = akg.tvm.var('n')
          >>> m = akg.tvm.var('m')
          >>> data = akg.tvm.placeholder((n, m), name='data')
          >>> k = akg.tvm.reduce_axis((0, m), "k")
          >>> reducer = akg.tvm.comm_reducer(lambda x,y: akg.fargmax(x, y), lambda t: akg.tvm.min_value(t), name="argmax")
          >>> res = akg.tvm.compute((n,), lambda *indice: reducer(data(*indice, k), axis=k), name="res")
      """
      result_dtype = x.dtype
      return akg.tvm.call_pure_intrin(result_dtype, "fargmax", x, y)
  def fargmin(x, y):
      """
      Create the "fargmin" intrinsic call for the index of the minimum of x and y.

      The resulting call expression takes its dtype from ``x``.

      Args:
          x (tvm.expr.Expr): Input expression.
          y (tvm.expr.Expr): Input expression.

      Returns:
          tvm.expr.Expr. The call expression.
      """
      operands = (x, y)
      return akg.tvm.call_pure_intrin(x.dtype, "fargmin", *operands)
  def mad(x, y):
      """
      Create the "mad" intrinsic call (matrix multiply and add) on x and y.

      The resulting call expression takes its dtype from ``x``. As the example
      shows, it is intended to be used as the combiner of a commutative reducer
      over the contraction axis.

      Args:
          x (tvm.expr.Expr): Input expression.
          y (tvm.expr.Expr): Input expression.

      Returns:
          tvm.expr.Expr. The call expression.

      Examples:
          >>> n = akg.tvm.var('n')
          >>> m = akg.tvm.var('m')
          >>> k = akg.tvm.var('k')
          >>> A = akg.tvm.placeholder((m, k), name='A')
          >>> B = akg.tvm.placeholder((k, n), name='B')
          >>> kk = akg.tvm.reduce_axis((0, k), name='kk')
          >>> mmad = akg.tvm.comm_reducer(lambda x, y: akg.mad(x, y), lambda t: akg.tvm.const(0, dtype=t), name="mmad")
          >>> C = akg.tvm.compute((m, n), lambda i, j: mmad(A[i, kk] * B[kk, j], axis=kk), name="C")
      """
      call_dtype = x.dtype
      return akg.tvm.call_pure_intrin(call_dtype, "mad", x, y)
  # Module-level reducer named "mmad": values are combined with mad() and the
  # identity element is a zero constant of the reduced dtype.
  mmad = akg.tvm.comm_reducer(
      lambda x, y: mad(x, y),
      lambda t: akg.tvm.const(0, dtype=t),
      name="mmad",
  )
  def dropout(x, y):
      """
      Create the "dropout" intrinsic call on x and y.

      Note that the resulting call expression takes its dtype from ``y``,
      not from ``x``.

      Args:
          x (tvm.expr.Expr): Input expression.
          y (tvm.expr.Expr): Input expression.

      Returns:
          tvm.expr.Expr. The call expression.
      """
      out_dtype = y.dtype
      return akg.tvm.call_pure_intrin(out_dtype, "dropout", x, y)
  def iou(x, y):
      """
      Create the "iou" intrinsic call (intersection over union of boxes x and y).

      The resulting call expression takes its dtype from ``x``.

      Args:
          x (tvm.expr.Expr): Input expression.
          y (tvm.expr.Expr): Input expression.

      Returns:
          tvm.expr.Expr. The call expression.
      """
      boxes = (x, y)
      return akg.tvm.call_pure_intrin(x.dtype, "iou", *boxes)
  def nms(x, y, scalar):
      """
      Create the "nms" intrinsic call (non-maximum suppression) for boxes x and y.

      The resulting call expression takes its dtype from ``x``.

      Args:
          x (tvm.expr.Expr): Input argument of reduced tensor.
          y (tvm.expr.Expr): Input argument.
          scalar (Union[tvm.expr.Expr, float]): Score threshold of nms.

      Returns:
          z : tvm.expr.Expr. The result is stored in fp16; each fp16 is a hex
          number indicating suppression.
      """
      result_dtype = x.dtype
      return akg.tvm.call_pure_intrin(result_dtype, "nms", x, y, scalar)
  def topk_sort(dst, src, topk):
      """
      Create the "topk_sort" intrinsic call.

      Sorts the proposal boxes and returns the topk result; used when the sort
      process needs to partition the sorting loop. The resulting call expression
      takes its dtype from ``src``.

      Args:
          dst (tvm.expr.Expr): Input argument. The destination of sort generated by common reducer.
          src (tvm.expr.Expr): Input argument.
              Strictly required that the box number is divisible by 16 and the item number is 8.
          topk (tvm.expr.Expr): Input argument. Constant tvm.expr.Expr indicating the required topk number.

      Returns:
          z : tvm.expr.Expr. The result.
      """
      sort_args = (dst, src, topk)
      return akg.tvm.call_pure_intrin(src.dtype, "topk_sort", *sort_args)
  def proposal_sort(dst, src, topk):
      """
      Create the "proposal_sort" intrinsic call.

      Sorts the proposal boxes and returns the topk result. The resulting call
      expression takes its dtype from ``src``.

      Args:
          dst (tvm.expr.Expr): Input argument. The destination of sort generated by common reducer.
          src (tvm.expr.Expr): Input argument.
              Strictly required that the box number is divisible by 16 and the item number is 8.
          topk (tvm.expr.Expr): Input argument. Constant tvm.expr.Expr indicating the required topk number.

      Returns:
          z : tvm.expr.Expr. The result.
      """
      out_dtype = src.dtype
      return akg.tvm.call_pure_intrin(out_dtype, "proposal_sort", dst, src, topk)
  def fnot(x):
      """
      Create the "not" intrinsic call on x.

      The resulting call expression takes its dtype from ``x``.
      NOTE(review): whether this is a logical or bitwise negation is decided by
      the backend lowering, which is not visible here -- confirm before relying on it.

      Args:
          x (tvm.expr.Expr): Input expression.

      Returns:
          tvm.expr.Expr. The call expression.
      """
      operand_dtype = x.dtype
      return akg.tvm.call_pure_intrin(operand_dtype, "not", x)
  def round_to(data, max_, min_):
      """
      Clamp the elements of data into the range [min_, max_].

      Despite the name, no rounding is performed: the bounds are materialised as
      tensors (data * 0 + bound) and applied with vmax followed by vmin.

      Args:
          data (Tensor): tensor whose elements are limited to the range.
          max_ (float): upper bound of the result.
          min_ (float): lower bound of the result.

      Returns:
          tensor : akg.tvm.tensor, elements in tensor are in range [min_, max_].
      """
      # Multiplying by 0 yields a base tensor that each scalar bound is added to.
      base = vmuls(data, 0)
      lower_bound = vadds(base, min_)
      upper_bound = vadds(base, max_)
      # Raise to the lower bound first, then cap at the upper bound.
      return vmin(vmax(data, lower_bound), upper_bound)
  def cast_to(data, dtype, f1628_int_flag=False):
      """
      Wrapped cast operation: convert data to the requested dtype.

      Three paths are taken depending on the source and destination dtypes:

      * float16 -> int32: the result is floor(|data|) multiplied by a sign term.
        The sign term is built by clamping data to [-0.5, 0.5], scaling by 32768,
        dividing by (|scaled| + 2**-15) and rounding.
      * float16 -> int8/uint8 with f1628_int_flag False: -0.5 is added to data
        before the generic cast below (set_is_need_save_dtype() is called first).
      * otherwise: data is returned unchanged if it already has the target dtype;
        else it is cast through float16 to the target dtype.

      Args:
          data (Tensor): akg.tvm.tensor that needs to change dtype.
          dtype (String): dst dtype to cast to.
          f1628_int_flag (bool): whether the data is already all integers before
              fp16->int8/uint8. Default value is False.

      Returns:
          tensor : akg.tvm.tensor.

      Raises:
          RuntimeError: if data is not an akg.tvm.tensor.Tensor.
      """
      if not isinstance(data, akg.tvm.tensor.Tensor):
          raise RuntimeError("The cast input type must be akg.tvm.tensor")

      src_dtype = data.dtype
      from_fp16 = src_dtype == "float16"

      if from_fp16 and dtype == "int32":
          scale = akg.tvm.const(32768, dtype="float16")
          # Small offset keeps the reciprocal below finite when the scaled value is 0.
          epsilon = akg.tvm.const(2 ** (-15), dtype="float16")
          clamped = round_to(data, 0.5, -0.5)
          scaled = vmuls(clamped, scale)
          denominator = vadds(vabs(scaled), epsilon)
          sign = round(vmul(scaled, vrec(denominator)))
          magnitude = floor(vabs(data))
          return vmul(magnitude, sign)

      if from_fp16 and dtype in ("int8", "uint8") and not f1628_int_flag:
          # NOTE(review): the -0.5 offset presumably compensates for how the
          # fp16 -> 8-bit cast rounds; confirm against the cast implementation.
          half_offset = akg.tvm.const(-0.5, dtype="float16")
          set_is_need_save_dtype()
          data = vadds(data, half_offset)

      if src_dtype == dtype:
          return data

      # Non-fp16 inputs are first converted to float16, then to the target dtype.
      intermediate = data if from_fp16 else cast(data, dst_dtype="float16")
      return cast(intermediate, dst_dtype=dtype)
  def four2five_nchw(data):
      """
      Create the "four2five_nchw" intrinsic call on data.

      The resulting call expression takes its dtype from ``data``.
      NOTE(review): judging by the name this is a 4D (NCHW) to 5D layout
      conversion; the lowering is not visible here -- confirm.

      Args:
          data (tvm.expr.Expr): Input expression.

      Returns:
          tvm.expr.Expr. The call expression.
      """
      data_dtype = data.dtype
      return akg.tvm.call_pure_intrin(data_dtype, "four2five_nchw", data)
  def load_im2col_c1_buf(data, pad_h, pad_t, pad_l, pad_r,
                         fm_h, fm_w, stride_h, stride_w,
                         filter_h, filter_w, dilation_h, dilation_w, repeat_mode, jmp_offset):
      """
      Create the "load_im2col_c1_buf" intrinsic call.

      All arguments are forwarded to the intrinsic unchanged and in signature
      order; the resulting call expression takes its dtype from ``data``.
      NOTE(review): the parameter meanings below are inferred from their names;
      confirm against the intrinsic's lowering.

      Args:
          data (tvm.expr.Expr): Input expression.
          pad_h, pad_t, pad_l, pad_r: padding arguments.
          fm_h, fm_w: feature map height and width arguments.
          stride_h, stride_w: stride arguments.
          filter_h, filter_w: filter height and width arguments.
          dilation_h, dilation_w: dilation arguments.
          repeat_mode: repeat mode argument.
          jmp_offset: jump offset argument.

      Returns:
          tvm.expr.Expr. The call expression.
      """
      forwarded = (
          data, pad_h, pad_t, pad_l, pad_r,
          fm_h, fm_w, stride_h, stride_w,
          filter_h, filter_w, dilation_h, dilation_w, repeat_mode, jmp_offset,
      )
      return akg.tvm.call_pure_intrin(data.dtype, "load_im2col_c1_buf", *forwarded)
  def sin(data):
      """
      Create the "sin" intrinsic call on data.

      The resulting call expression takes its dtype from ``data``.

      Args:
          data (tvm.expr.Expr): Input expression.

      Returns:
          tvm.expr.Expr. The call expression.
      """
      data_dtype = data.dtype
      return akg.tvm.call_pure_intrin(data_dtype, "sin", data)
  def cos(data):
      """
      Create the "cos" intrinsic call on data.

      The resulting call expression takes its dtype from ``data``.

      Args:
          data (tvm.expr.Expr): Input expression.

      Returns:
          tvm.expr.Expr. The call expression.
      """
      data_dtype = data.dtype
      return akg.tvm.call_pure_intrin(data_dtype, "cos", data)
  def sinh(data):
      """
      Create the "sinh" intrinsic call on data.

      The resulting call expression takes its dtype from ``data``.

      Args:
          data (tvm.expr.Expr): Input expression.

      Returns:
          tvm.expr.Expr. The call expression.
      """
      data_dtype = data.dtype
      return akg.tvm.call_pure_intrin(data_dtype, "sinh", data)
  def cosh(data):
      """
      Create the "cosh" intrinsic call on data.

      The resulting call expression takes its dtype from ``data``.

      Args:
          data (tvm.expr.Expr): Input expression.

      Returns:
          tvm.expr.Expr. The call expression.
      """
      data_dtype = data.dtype
      return akg.tvm.call_pure_intrin(data_dtype, "cosh", data)
  def divide_var(data, divisor):
      """
      Create the "divide_var" intrinsic call on data and divisor.

      The resulting call expression takes its dtype from ``data``.

      Args:
          data (tvm.expr.Expr): Input expression (first operand of the call).
          divisor (tvm.expr.Expr): Input expression (second operand of the call).

      Returns:
          tvm.expr.Expr. The call expression.
      """
      operands = (data, divisor)
      return akg.tvm.call_pure_intrin(data.dtype, "divide_var", *operands)
  def vmadd(x, y, z):
      """
      Call the vmadd instruction to calculate :math:`x * y + z`.

      The operands are handed to the intrinsic in the order (y, z, x), and the
      resulting call expression takes its dtype from ``x``.
      NOTE(review): the (y, z, x) order presumably matches the instruction's
      operand convention -- confirm against the intrinsic's lowering.

      Args:
          x (tvm.tensor.Tensor): input x.
          y (tvm.tensor.Tensor): input y.
          z (tvm.tensor.Tensor): input z.

      Returns:
          tensor : akg.tvm.tensor.
      """
      intrin_operands = (y, z, x)
      return akg.tvm.call_pure_intrin(x.dtype, "vmadd", *intrin_operands)
  def vmla(x, y, z):
      """
      Call the vmla instruction to calculate :math:`x + y * z`.

      The operands are handed to the intrinsic in the order (y, z, x), and the
      resulting call expression takes its dtype from ``x``.
      NOTE(review): the (y, z, x) order presumably matches the instruction's
      operand convention -- confirm against the intrinsic's lowering.

      Args:
          x (tvm.tensor.Tensor): input x.
          y (tvm.tensor.Tensor): input y.
          z (tvm.tensor.Tensor): input z.

      Returns:
          tensor : akg.tvm.tensor.
      """
      intrin_operands = (y, z, x)
      return akg.tvm.call_pure_intrin(x.dtype, "vmla", *intrin_operands)