You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

select.py 3.8 kB

5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """operator dsl function: select"""
  15. import akg.topi
  16. import akg.tvm
  17. import akg.lang.cce
  18. from akg.utils import validation_check as vc_util
  19. from akg.utils.format_transform import get_shape
  20. from akg.utils import kernel_exec as utils
  21. VALUE_ONE = 1
  22. def select_compute(condition, x1, x2):
  23. """select compute implementation"""
  24. shape = get_shape(x1)
  25. con_shape = get_shape(condition)
  26. num_dtype = x1.dtype
  27. bool_dtype = condition.dtype
  28. if num_dtype in ("int8", "uint8"):
  29. x1_dtype = "float32"
  30. ones = akg.lang.cce.broadcast(akg.tvm.const(VALUE_ONE, dtype="float32"),
  31. shape, output_dtype="float32")
  32. x1 = akg.topi.cast(x1, "float32")
  33. x2 = akg.topi.cast(x2, "float32")
  34. else:
  35. x1_dtype = num_dtype
  36. ones = akg.lang.cce.broadcast(akg.tvm.const(VALUE_ONE, dtype=num_dtype),
  37. shape, output_dtype=num_dtype)
  38. if bool_dtype == "int8":
  39. if x1_dtype == "int32":
  40. condition_dtype = akg.lang.cce.ceil(condition)
  41. else:
  42. condition_dtype = akg.topi.cast(condition, x1_dtype)
  43. else:
  44. if x1_dtype == "int32":
  45. condition_dtype = condition
  46. else:
  47. condition_dtype = akg.topi.cast(condition, x1_dtype)
  48. if list(con_shape) != list(shape):
  49. condition_dtype = akg.lang.cce.broadcast(condition_dtype, shape)
  50. vinsn_support_dtype = ("float16", "float32")
  51. if utils.product_is_mini():
  52. vinsn_support_dtype = ("float16", )
  53. if num_dtype in vinsn_support_dtype:
  54. res = akg.topi.where(condition_dtype, x1, x2)
  55. else:
  56. condition_opp = akg.lang.cce.vsub(ones, condition_dtype)
  57. temp_x = akg.lang.cce.vmul(x1, condition_dtype)
  58. temp_y = akg.lang.cce.vmul(x2, condition_opp)
  59. res = akg.lang.cce.vadd(temp_x, temp_y)
  60. if num_dtype in ("int8", "uint8"):
  61. res = akg.topi.cast(res, num_dtype)
  62. return res
  63. @vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor)
  64. def select(condition, x1, x2):
  65. """
  66. Selects elements from x1 or x2, depending on condition.
  67. Note:
  68. every parmas' shape need legal, can support condition's shape broadcast.
  69. Args:
  70. condition (tvm.tensor.Tensor): Tensor of type int8, int32, must be 0 or 1.
  71. x1 (tvm.tensor.Tensor): Tensor of type float16, float32, int8, int32, uint8.
  72. x2 (tvm.tensor.Tensor): Tensor of type float16, float32, int8, int32, uint8.
  73. Returns:
  74. tvm.tensor.Tensor, has the same type and shape as x1.
  75. """
  76. shape_x1 = get_shape(x1)
  77. shape_x2 = get_shape(x2)
  78. con_shape = get_shape(condition)
  79. vc_util.elemwise_shape_check(shape_x1, shape_x2)
  80. vc_util.elemwise_dtype_check(x1.dtype, x2.dtype, [vc_util.DtypeForDavinci.ALL_FLOAT,
  81. vc_util.DtypeForDavinci.INT8, vc_util.DtypeForDavinci.INT32,
  82. vc_util.DtypeForDavinci.UINT8])
  83. vc_util.ops_dtype_check(condition.dtype, [vc_util.DtypeForDavinci.INT8, vc_util.DtypeForDavinci.INT32])
  84. vc_util.auto_broadcast_check(con_shape, shape_x1)
  85. res = select_compute(condition, x1, x2)
  86. return res