You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

strided_slice.py 9.5 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """operator dsl function: strided_slice"""
  17. import copy
  18. import numpy as np
  19. import akg.topi
  20. import akg.tvm
  21. from akg.utils import validation_check as vc_util
  22. def check_args(begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask):
  23. """check args."""
  24. if len(begin) != len(end):
  25. raise Exception("len(begin) is {}, len(end) is {}. They must be identical!".format(len(begin), len(end)))
  26. if strides is not None:
  27. if len(begin) != len(strides):
  28. raise Exception("len(begin) is {}, len(strides) is {}. They must be identical!".
  29. format(len(begin), len(strides)))
  30. for s in strides:
  31. if s == 0:
  32. raise Exception("Value in strides[{}] must not be 0!".format(strides))
  33. if begin_mask < 0 or begin_mask >= (2 ** len(begin)):
  34. raise Exception("Illegal begin_mask[{}]".format(begin_mask))
  35. if end_mask < 0 or end_mask >= (2 ** len(begin)):
  36. raise Exception("Illegal end_mask[{}]".format(end_mask))
  37. if ellipsis_mask < 0 or ellipsis_mask >= (2 ** len(begin)):
  38. raise Exception("Illegal ellipsis_mask[{}]".format(ellipsis_mask))
  39. if ellipsis_mask != 0: # ellipsis_mask must be a power of two (only one ellipsis)
  40. if ellipsis_mask & (ellipsis_mask - 1) != 0:
  41. raise Exception("ellipsis_mask[{}] is not power of two (only one ellipsis).".format(ellipsis_mask))
  42. if new_axis_mask < 0 or new_axis_mask >= (2 ** len(begin)):
  43. raise Exception("Illegal new_axis_mask[{}]".format(new_axis_mask))
  44. if shrink_axis_mask < 0 or shrink_axis_mask >= (2 ** len(begin)):
  45. raise Exception("Illegal shrink_axis_mask[{}]".format(shrink_axis_mask))
  46. def args_to_slices(begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask):
  47. """args to slice."""
  48. check_args(begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
  49. slices = []
  50. for dim, bgn in enumerate(begin):
  51. if (ellipsis_mask >> dim) & 1:
  52. slices.append(Ellipsis)
  53. elif (new_axis_mask >> dim) & 1:
  54. slices.append(np.newaxis)
  55. elif (shrink_axis_mask >> dim) & 1:
  56. slices.append(bgn)
  57. else:
  58. start = None if (begin_mask >> dim) & 1 else bgn
  59. stop = None if (end_mask >> dim) & 1 else end[dim]
  60. step = strides[dim]
  61. slices.append(slice(start, stop, step))
  62. return slices
  63. def slices_to_args(slices=()):
  64. """slice to args."""
  65. begin = []
  66. end = []
  67. strides = []
  68. begin_mask = 0
  69. end_mask = 0
  70. ellipsis_mask = 0
  71. new_axis_mask = 0
  72. shrink_axis_mask = 0
  73. for i, arg in enumerate(slices):
  74. if isinstance(arg, slice):
  75. begin.append(0 if arg.start is None else arg.start)
  76. if arg.start is None:
  77. begin_mask |= 1 << i
  78. end.append(0 if arg.stop is None else arg.stop)
  79. if arg.stop is None:
  80. end_mask |= 1 << i
  81. strides.append(1 if arg.step is None else arg.step)
  82. elif arg is np.newaxis:
  83. begin.append(0)
  84. end.append(0)
  85. strides.append(1)
  86. new_axis_mask |= 1 << i
  87. elif arg is Ellipsis:
  88. begin.append(0)
  89. end.append(0)
  90. strides.append(1)
  91. ellipsis_mask |= 1 << i
  92. elif isinstance(arg, int):
  93. begin.append(arg)
  94. end.append(arg + 1)
  95. strides.append(1)
  96. shrink_axis_mask |= 1 << i
  97. else:
  98. raise Exception("arg ", arg, ' is invalid')
  99. return begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask
  100. def complete_args(inputs_shape, begin, end, strides, begin_mask,
  101. end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask):
  102. """complete args."""
  103. check_args(begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
  104. # step0: deep copy begin, end, strides
  105. begin = copy.copy(begin)
  106. end = copy.copy(end)
  107. strides = copy.copy(strides)
  108. # step1: store all bits and calculate new_axis_count
  109. check_args(begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
  110. begin_list = [(begin_mask >> dim) & 1 for dim in range(len(begin))]
  111. end_list = [(end_mask >> dim) & 1 for dim in range(len(begin))]
  112. ellipsis_list = [(ellipsis_mask >> dim) & 1 for dim in range(len(begin))]
  113. new_axis_list = [(new_axis_mask >> dim) & 1 for dim in range(len(begin))]
  114. new_axis_count = len([dim for dim in range(len(begin)) if (new_axis_mask >> dim) & 1])
  115. shrink_list = [(shrink_axis_mask >> dim) & 1 for dim in range(len(begin))]
  116. # step2: fill the ellipsis using ellipsis_list
  117. ellipsis_idx = None
  118. for idx, x in enumerate(ellipsis_list):
  119. if x:
  120. ellipsis_idx = idx
  121. break
  122. if ellipsis_idx is not None:
  123. ellipsis_length = len(inputs_shape) - (len(begin) - 1 - new_axis_count)
  124. idx = ellipsis_idx
  125. begin.pop(idx)
  126. end.pop(idx)
  127. strides.pop(idx)
  128. begin_list.pop(idx)
  129. end_list.pop(idx)
  130. ellipsis_list.pop(idx)
  131. new_axis_list.pop(idx)
  132. shrink_list.pop(idx)
  133. for _ in range(ellipsis_length):
  134. begin.insert(idx, None)
  135. end.insert(idx, None)
  136. strides.insert(idx, 1)
  137. begin_list.insert(idx, 1)
  138. end_list.insert(idx, 1)
  139. ellipsis_list.insert(idx, 0)
  140. new_axis_list.insert(idx, 0)
  141. shrink_list.insert(idx, 0)
  142. # step3: remove new_axis using new_axis_list
  143. new_axis_index = [idx for idx, x in enumerate(new_axis_list) if x]
  144. for idx in new_axis_index[::-1]:
  145. begin.pop(idx)
  146. end.pop(idx)
  147. strides.pop(idx)
  148. begin_list.pop(idx)
  149. end_list.pop(idx)
  150. ellipsis_list.pop(idx)
  151. shrink_list.pop(idx)
  152. new_axis_list.pop(idx)
  153. # step4: update (begin, end, strides) using (shrink_list, begin_list, end_list)
  154. for dim, bgn in enumerate(begin):
  155. if shrink_list[dim]:
  156. end[dim] = bgn + 1
  157. strides[dim] = 1
  158. continue
  159. if begin_list[dim]:
  160. begin[dim] = 0
  161. if end_list[dim]:
  162. end[dim] = inputs_shape[dim]
  163. return begin, end, strides, new_axis_index, shrink_list
  164. @vc_util.check_input_type(akg.tvm.tensor.Tensor, ((list, tuple), int), ((list, tuple), int),
  165. ((list, tuple), int), int, int, int, int, int)
  166. def strided_slice(inputs, begin, end, strides,
  167. begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask):
  168. """
  169. Generate an array by slicing input tensor
  170. Args:
  171. inputs (tvm.tensor.Tensor): Tensor of type float16, float32.
  172. begin (Union[list, tuple, int]): The start indexes for slicing.
  173. end (Union[list, tuple, int]): The end indexes for slicing.
  174. strides (Union[list, tuple, int]): The strides for slicing.
  175. begin_mask (int): int32 mask for begin indexes.
  176. end_mask (int): int32 mask for end indexes.
  177. ellipsis_mask (int): int32 mask for inserting unspecified dimensions.
  178. new_axis_mask (int): int32 mask for new dim with length 1.
  179. shrink_axis_mask (int): int32 mask for shrinking the dims.
  180. Returns:
  181. tvm.tensor.Tensor, with the same dtype as inputs.
  182. """
  183. shape = [x.value for x in inputs.shape]
  184. # step0~4: complete begin, end, strides
  185. begin, end, strides, new_axis_index, shrink_list = complete_args(shape, begin, end, strides,
  186. begin_mask, end_mask, ellipsis_mask,
  187. new_axis_mask, shrink_axis_mask)
  188. # step5: use topi to do strided_slice using begin, end, strides
  189. if (shape == [1] and begin == end):
  190. return akg.tvm.compute(shape, lambda *i: inputs(*i), name="out")
  191. if inputs.dtype == "uint8":
  192. inputs_cast = akg.topi.cast(inputs, "int8")
  193. else:
  194. inputs_cast = inputs
  195. out1 = akg.topi.strided_slice(inputs_cast, begin, end, strides)
  196. # step6: increase out_tensor's dim using new_axis_index
  197. new_shape = list(out1.shape)
  198. for idx in new_axis_index[::-1]:
  199. new_shape.insert(idx, 1)
  200. # step7: decrease out_tensor's dim using shrink_list
  201. for idx in new_axis_index[::-1]:
  202. shrink_list.insert(idx, 0)
  203. shrink_axis_index = [idx for idx, x in enumerate(shrink_list) if x]
  204. for idx in shrink_axis_index[::-1]:
  205. new_shape.pop(idx)
  206. # step8: reshape out_tensor
  207. out2 = akg.topi.reshape(out1, tuple(new_shape))
  208. if inputs.dtype == "uint8":
  209. out2 = akg.topi.cast(out2, "uint8")
  210. return out2