You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

one_hot.py 12 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """operator dsl function:one hot"""
  17. import akg.tvm
  18. from akg.tvm.hybrid import script
  19. from akg.utils import custom_tiling as ct_util
  20. from akg.utils.validation_check import ops_dtype_check, check_shape, check_input_type, DtypeForDavinci
  def onehot_tiling_strategy(tensor, axis):
      """
      Build the custom tiling constraint used by the one-hot operators.

      Args:
          tensor: tensor the constraint is attached to (the callers in this
              file pass the one-hot output tensor).
          axis (int): forwarded as ``tensor_pos``; the callers in this file
              pass the position of the one-hot axis.

      Returns:
          Whatever ``ct_util.create_constraint_on_tensor`` returns for a
          ``SET_PRIORITY`` constraint with ``values=0``.
      """
      return ct_util.create_constraint_on_tensor(
          tensor=tensor,
          values=0,
          constraints=ct_util.TileConstraint.SET_PRIORITY,
          tensor_pos=axis,
      )
  @check_input_type(akg.tvm.tensor.Tensor, int, str, (int, float, type(None)),
                    (int, float, type(None)), (int, type(None)))
  def one_hot(indices, depth, dtype, on_value=None, off_value=None, axis=None):
      """
      Generate the one-hot encoding of the input indices.

      Args:
          indices (tvm.tensor.Tensor): index tensor with 1 to 3 dimensions.
          depth (int): size of the one-hot dimension inserted into the output shape.
          dtype (String): data type of the output. It is validated, together with
              indices.dtype, against the INT32 and ALL_FLOAT Davinci dtype groups.
          on_value (Scalar): optional. value written at the position selected by
              each index. default 1.
          off_value (Scalar): optional. value written everywhere else. default 0.
          axis (int): optional. position of the one-hot axis in the output shape,
              in the range [0, len(indices.shape)]. None or -1 means a new
              inner-most axis.

      Returns:
          tuple (out, attr_map). ``out`` is the output tensor whose shape is the
          input shape with ``depth`` inserted at ``axis``. ``attr_map`` is a dict
          holding "RewriteVarTensorIdx" and, when a tiling strategy was produced,
          "custom_tiling".

      Raises:
          RuntimeError: if the input has more than 3 dimensions, if it has a
              number of dimensions no kernel exists for, or if axis is out of range.

      Note:
          The kernels skip negative indices (those positions keep off_value).
          No upper-bound check of the indices against depth is done here.
      """
      ops_dtype_check([indices.dtype, dtype], DtypeForDavinci.INT32.value + DtypeForDavinci.ALL_FLOAT.value)
      shape = [x.value for x in indices.shape]
      check_shape(shape)
      rank = len(shape)

      # Only kernels for 1-D, 2-D and 3-D inputs are defined below.
      if rank > 3:
          raise RuntimeError("one_hot do not support input shape %d dimensions which is more than 3" % rank)

      on_value_const = akg.tvm.const(1 if on_value is None else on_value, dtype)
      off_value_const = akg.tvm.const(0 if off_value is None else off_value, dtype)

      # None and -1 both select a new inner-most axis; other negative values are rejected.
      if axis is None or axis == -1:
          axis = rank
      if axis < 0 or axis > rank:
          raise RuntimeError("axis(%s) is not an valid index" % axis)

      in_shape = list(indices.shape)
      in_shape.insert(axis, depth)
      out_shape = tuple(in_shape)

      # The kernels below are hybrid-script functions: the @script decorator parses
      # their source, so their statements are kept in the plain loop form.
      # `axis` and `out_shape` are captured from the enclosing scope.

      @script
      def one_hot_hybrid_1(indices_in, on_value_const_in, off_value_const_in):
          out = output_tensor(out_shape, on_value_const_in.dtype)
          m, n = out_shape
          # fill the whole output with the off value first
          for i in range(m):
              for j in range(n):
                  out[i, j] = off_value_const_in
          # then scatter the on value, skipping negative indices
          if axis == 0:
              for i in range(n):
                  if indices_in[i] >= 0:
                      out[indices_in[i], i] = on_value_const_in
          else:
              for i in range(m):
                  if indices_in[i] >= 0:
                      out[i, indices_in[i]] = on_value_const_in
          return out

      @script
      def one_hot_hybrid_2(indices_in, on_value_const_in, off_value_const_in):
          out = output_tensor(out_shape, on_value_const_in.dtype)
          m, n, k = out.shape
          # fill the whole output with the off value first
          for x in range(m):
              for y in range(n):
                  for z in range(k):
                      out[x, y, z] = off_value_const_in
          # then scatter the on value, skipping negative indices
          if axis == 0:
              for i in range(n):
                  for j in range(k):
                      if indices_in[i, j] >= 0:
                          out[indices_in[i, j], i, j] = on_value_const_in
          elif axis == 1:
              for i in range(m):
                  for j in range(k):
                      if indices_in[i, j] >= 0:
                          out[i, indices_in[i, j], j] = on_value_const_in
          else:
              for i in range(m):
                  for j in range(n):
                      if indices_in[i, j] >= 0:
                          out[i, j, indices_in[i, j]] = on_value_const_in
          return out

      @script
      def one_hot_hybrid_3(indices_in, on_value_const_in, off_value_const_in):
          out = output_tensor(out_shape, on_value_const_in.dtype)
          m, n, k, t = out.shape
          # fill the whole output with the off value first
          for x in range(m):
              for y in range(n):
                  for z in range(k):
                      for u in range(t):
                          out[x, y, z, u] = off_value_const_in
          # then scatter the on value, skipping negative indices
          if axis == 0:
              for i in range(n):
                  for j in range(k):
                      for c in range(t):
                          if indices_in[i, j, c] >= 0:
                              out[indices_in[i, j, c], i, j, c] = on_value_const_in
          elif axis == 1:
              for i in range(m):
                  for j in range(k):
                      for c in range(t):
                          if indices_in[i, j, c] >= 0:
                              out[i, indices_in[i, j, c], j, c] = on_value_const_in
          elif axis == 2:
              for i in range(m):
                  for j in range(n):
                      for c in range(t):
                          if indices_in[i, j, c] >= 0:
                              out[i, j, indices_in[i, j, c], c] = on_value_const_in
          else:
              for i in range(m):
                  for j in range(n):
                      for c in range(k):
                          if indices_in[i, j, c] >= 0:
                              out[i, j, c, indices_in[i, j, c]] = on_value_const_in
          return out

      if rank == 1:
          out = one_hot_hybrid_1(indices, on_value_const, off_value_const)
      elif rank == 2:
          out = one_hot_hybrid_2(indices, on_value_const, off_value_const)
      elif rank == 3:
          out = one_hot_hybrid_3(indices, on_value_const, off_value_const)
      else:
          # Without this branch `out` would be unbound below and surface as an
          # UnboundLocalError instead of a meaningful message.
          raise RuntimeError("one_hot do not support input shape with %d dimensions" % rank)

      strategy = onehot_tiling_strategy(out, axis)
      attr_map = {"RewriteVarTensorIdx": True}
      if strategy:
          attr_map["custom_tiling"] = strategy
      return out, attr_map
  @check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor,
                    int, (int, type(None)))
  def one_hot_v2(indices, on_value, off_value, depth, axis=None):
      """
      Generate the one-hot encoding of the input indices, with tensor on/off values.

      Args:
          indices (akg.tvm.tensor.Tensor): index tensor with 1 to 3 dimensions;
              its dtype is checked against DtypeForDavinci.INT32.
          on_value (akg.tvm.tensor.Tensor): tensor whose element [0] is written at
              the position selected by each index. Its dtype is the output dtype.
          off_value (akg.tvm.tensor.Tensor): tensor whose element [0] is written
              everywhere else.
          depth (int): size of the one-hot dimension inserted into the output shape.
          axis (int): optional. position of the one-hot axis in the output shape,
              in the range [0, len(indices.shape)]. None or -1 means a new
              inner-most axis.

      Returns:
          tuple (out, attr_map). ``out`` is the output tensor whose shape is the
          input shape with ``depth`` inserted at ``axis``. ``attr_map`` is a dict
          holding "RewriteVarTensorIdx" and, when a tiling strategy was produced,
          "custom_tiling".

      Raises:
          RuntimeError: if the input has more than 3 dimensions, if it has a
              number of dimensions no kernel exists for, or if axis is out of range.

      Note:
          The kernels skip negative indices (those positions keep the off value).
          No upper-bound check of the indices against depth is done here.
      """
      ops_dtype_check(indices.dtype, DtypeForDavinci.INT32)
      ops_dtype_check([on_value.dtype, off_value.dtype], [DtypeForDavinci.INT32, DtypeForDavinci.ALL_FLOAT])
      shape = [x.value for x in indices.shape]
      check_shape(shape)
      rank = len(shape)

      # Only kernels for 1-D, 2-D and 3-D inputs are defined below.
      if rank > 3:
          raise RuntimeError("one_hot do not support input shape %d dimensions which is more than 3" % rank)

      # None and -1 both select a new inner-most axis; other negative values are rejected.
      if axis is None or axis == -1:
          axis = rank
      if axis < 0 or axis > rank:
          raise RuntimeError("axis(%s) is not an valid index" % axis)

      in_shape = list(indices.shape)
      in_shape.insert(axis, depth)
      out_shape = tuple(in_shape)

      # The kernels below are hybrid-script functions: the @script decorator parses
      # their source, so their statements are kept in the plain loop form.
      # `axis` and `out_shape` are captured from the enclosing scope.
      # The on/off arguments are tensors here, so the kernels read element [0].

      @script
      def one_hot_hybrid_1(indices_in, on_value_const_in, off_value_const_in):
          out = output_tensor(out_shape, on_value_const_in.dtype)
          m, n = out_shape
          # fill the whole output with the off value first
          for i in range(m):
              for j in range(n):
                  out[i, j] = off_value_const_in[0]
          # then scatter the on value, skipping negative indices
          if axis == 0:
              for i in range(n):
                  if indices_in[i] >= 0:
                      out[indices_in[i], i] = on_value_const_in[0]
          else:
              for i in range(m):
                  if indices_in[i] >= 0:
                      out[i, indices_in[i]] = on_value_const_in[0]
          return out

      @script
      def one_hot_hybrid_2(indices_in, on_value_const_in, off_value_const_in):
          out = output_tensor(out_shape, on_value_const_in.dtype)
          m, n, k = out.shape
          # fill the whole output with the off value first
          for x in range(m):
              for y in range(n):
                  for z in range(k):
                      out[x, y, z] = off_value_const_in[0]
          # then scatter the on value, skipping negative indices
          if axis == 0:
              for i in range(n):
                  for j in range(k):
                      if indices_in[i, j] >= 0:
                          out[indices_in[i, j], i, j] = on_value_const_in[0]
          elif axis == 1:
              for i in range(m):
                  for j in range(k):
                      if indices_in[i, j] >= 0:
                          out[i, indices_in[i, j], j] = on_value_const_in[0]
          else:
              for i in range(m):
                  for j in range(n):
                      if indices_in[i, j] >= 0:
                          out[i, j, indices_in[i, j]] = on_value_const_in[0]
          return out

      @script
      def one_hot_hybrid_3(indices_in, on_value_const_in, off_value_const_in):
          out = output_tensor(out_shape, on_value_const_in.dtype)
          m, n, k, t = out.shape
          # fill the whole output with the off value first
          for x in range(m):
              for y in range(n):
                  for z in range(k):
                      for u in range(t):
                          out[x, y, z, u] = off_value_const_in[0]
          # then scatter the on value, skipping negative indices
          if axis == 0:
              for i in range(n):
                  for j in range(k):
                      for c in range(t):
                          if indices_in[i, j, c] >= 0:
                              out[indices_in[i, j, c], i, j, c] = on_value_const_in[0]
          elif axis == 1:
              for i in range(m):
                  for j in range(k):
                      for c in range(t):
                          if indices_in[i, j, c] >= 0:
                              out[i, indices_in[i, j, c], j, c] = on_value_const_in[0]
          elif axis == 2:
              for i in range(m):
                  for j in range(n):
                      for c in range(t):
                          if indices_in[i, j, c] >= 0:
                              out[i, j, indices_in[i, j, c], c] = on_value_const_in[0]
          else:
              for i in range(m):
                  for j in range(n):
                      for c in range(k):
                          if indices_in[i, j, c] >= 0:
                              out[i, j, c, indices_in[i, j, c]] = on_value_const_in[0]
          return out

      if rank == 1:
          out = one_hot_hybrid_1(indices, on_value, off_value)
      elif rank == 2:
          out = one_hot_hybrid_2(indices, on_value, off_value)
      elif rank == 3:
          out = one_hot_hybrid_3(indices, on_value, off_value)
      else:
          # Without this branch `out` would be unbound below and surface as an
          # UnboundLocalError instead of a meaningful message.
          raise RuntimeError("one_hot do not support input shape with %d dimensions" % rank)

      strategy = onehot_tiling_strategy(out, axis)
      attr_map = {"RewriteVarTensorIdx": True}
      if strategy:
          attr_map["custom_tiling"] = strategy
      return out, attr_map