You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

equal.py 2.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """operator dsl function: equal"""
  15. import _akg.tvm
  16. import _akg.topi
  17. from _akg.utils.dsl_create import produce_shapes
  18. from _akg.utils import validation_check as vc_util
  19. @vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor)
  20. def equal(input1, input2):
  21. """
  22. check whether input1 equals to input2.
  23. Args:
  24. input1 (tvm.tensor.Tensor): Tensor.
  25. input2 (tvm.tensor.Tensor): Tensor.
  26. Returns:
  27. tvm.tensor.Tensor. If input1 equal to input2 return True, else return False.
  28. """
  29. shape1 = [x.value for x in input1.shape]
  30. shape2 = [x.value for x in input2.shape]
  31. vc_util.check_shape(shape1)
  32. vc_util.check_shape(shape2)
  33. shape1, shape2, shape = produce_shapes(shape1, shape2)
  34. vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)
  35. dtype = input1.dtype
  36. # get equal compute
  37. t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T")
  38. f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F")
  39. input1_bro = _akg.topi.broadcast_to(input1, shape)
  40. input2_bro = _akg.topi.broadcast_to(input2, shape)
  41. c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] == input2_bro[indice],
  42. t_value[indice], f_value[indice]), name="C")
  43. res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res")
  44. return res