You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

argmin_argmax_common.py 6.7 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """operator dsl function:argmin_argmax_common"""
  17. import akg.tvm
  18. import akg.topi
  19. from akg.lang import cce as dav
  20. from akg.utils import custom_tiling as ct_util, validation_check as vc_util
  21. from akg.utils.dsl_create import get_reduce_out_shape
  22. from akg.utils.format_transform import refine_reduce_axis, get_shape
  23. from akg.utils.dynamic_shape import shape_is_dynamic, set_dynamic_shape_limit_for_tensor
  24. def argminmax_tiling_strategy(out_shape, axis):
  25. """Custom tiling strategy for argminmax op."""
  26. strategy = list()
  27. # when reduce axis is one, we do not need any strategy
  28. if out_shape[axis] == 1:
  29. return strategy
  30. # if reduce first axis, it will transpose to last axis
  31. # so here we adapt to this change
  32. if axis == 0:
  33. temp = out_shape[0]
  34. out_shape = out_shape[1:]
  35. out_shape.append(temp)
  36. axis = len(out_shape) - 1
  37. # eliminate single axis, which will automatically disappear in halide ir
  38. # and adjust axis if it is influenced
  39. shrink = list()
  40. for i, shp in enumerate(out_shape):
  41. if shp == 1:
  42. if i < axis:
  43. axis -= 1
  44. else:
  45. shrink.append(shp)
  46. for i, _ in enumerate(shrink):
  47. if i == axis:
  48. strategy.append(ct_util.create_constraint_on_axis(
  49. values="FULL",
  50. constraints=ct_util.TileConstraint.MAX,
  51. axis=i)[0])
  52. else:
  53. strategy.append(ct_util.create_constraint_on_axis(
  54. values=1,
  55. constraints=ct_util.TileConstraint.FACTOR,
  56. axis=i)[0])
  57. return strategy
  58. @vc_util.check_input_type(akg.tvm.tensor.Tensor, int, (str, type(None)))
  59. def common(data, axis, method="min"):
  60. """
  61. Returns the index with the max or min value across axes of a tensor.
  62. Note:
  63. method can be "max" or "min" to get argmax or argmin.
  64. Args:
  65. data (tvm.tensor.Tensor): Tensor of type float16, float32, int8, int32.
  66. axis (int): Describe the axis of input tensor.
  67. method (str): Can be "max" or "min".
  68. Returns:
  69. tvm.tensor.Tensor, has type of int32.
  70. """
  71. shape = get_shape(data)
  72. dtype = data.dtype
  73. vc_util.ops_dtype_check(data.dtype, [vc_util.DtypeForDavinci.ALL_FLOAT, vc_util.DtypeForDavinci.ALL_INT])
  74. vc_util.reduce_axis_check(shape, axis)
  75. real_axis = refine_reduce_axis(shape, axis)[0]
  76. out_shape = get_reduce_out_shape(shape, axis=axis)
  77. attr_map = {}
  78. if shape_is_dynamic(data):
  79. attr_map["dynamic_shape"] = set_dynamic_shape_limit_for_tensor(data, 4096, real_axis)
  80. if dtype != "float16":
  81. data = akg.topi.cast(data, "float16")
  82. k = akg.tvm.reduce_axis((0, data.shape[real_axis]), "k")
  83. if axis in (len(shape) - 1, -1):
  84. if method == "min":
  85. reducer = akg.tvm.comm_reducer(
  86. lambda x, y: dav.fargmin(x, y), lambda t: akg.tvm.max_value(t))
  87. elif method == "max":
  88. reducer = akg.tvm.comm_reducer(
  89. lambda x, y: dav.fargmax(x, y), lambda t: akg.tvm.min_value(t))
  90. else:
  91. raise ValueError("not support " + method)
  92. if len(data.shape) == 1:
  93. res = akg.tvm.compute((1,), lambda i: reducer(data[k], axis=k))
  94. else:
  95. res = akg.tvm.compute(out_shape,
  96. lambda *indice:
  97. reducer(data(*indice, k), axis=k))
  98. res = akg.tvm.compute(out_shape,
  99. lambda *indice: res(*indice).astype("int32"),
  100. "argred_output")
  101. elif axis in (0, -len(shape)):
  102. tmp_idx = akg.tvm.compute(shape[1:],
  103. lambda *indice: akg.tvm.const(0.0, "float16"),
  104. name='tmp_index')
  105. local_data = akg.tvm.compute(shape[1:],
  106. lambda *indice: data(0, *indice),
  107. name="tmp_data")
  108. for idx in range(shape[axis] - 1):
  109. if method == 'min':
  110. tmp_idx = akg.tvm.compute(
  111. shape[1:],
  112. lambda *indice, ite_idx=idx:
  113. akg.tvm.expr.Select(
  114. local_data(*indice) > data(ite_idx + 1, *indice),
  115. akg.tvm.const(ite_idx + 1, "float16"),
  116. tmp_idx(*indice)
  117. ))
  118. local_data = akg.tvm.compute(
  119. shape[1:],
  120. lambda *indice, ite_idx=idx:
  121. akg.tvm.expr.Select(
  122. local_data(*indice) > data(ite_idx + 1, *indice),
  123. data(ite_idx + 1, *indice),
  124. local_data(*indice)
  125. ))
  126. elif method == "max":
  127. tmp_idx = akg.tvm.compute(
  128. shape[1:],
  129. lambda *indice, ite_idx=idx:
  130. akg.tvm.expr.Select(
  131. local_data(*indice) < data(ite_idx + 1, *indice),
  132. akg.tvm.const(ite_idx + 1, "float16"),
  133. tmp_idx(*indice)
  134. ))
  135. local_data = akg.tvm.compute(
  136. shape[1:],
  137. lambda *indice, ite_idx=idx:
  138. akg.tvm.expr.Select(
  139. local_data(*indice) < data(ite_idx + 1, *indice),
  140. data(ite_idx + 1, *indice),
  141. local_data(*indice)
  142. ))
  143. else:
  144. raise ValueError("not support " + method)
  145. res = akg.tvm.compute(out_shape,
  146. lambda *indice: tmp_idx(*indice).astype("int32"),
  147. "cast1")
  148. else:
  149. raise ValueError("Argmax only support first axis and is last axis now!")
  150. lager = out_shape if len(out_shape) > len(shape) else shape
  151. strategy = argminmax_tiling_strategy(lager, real_axis)
  152. if strategy:
  153. attr_map["custom_tiling"] = strategy
  154. return res, attr_map