You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """operator dsl function:div"""
  17. import akg.tvm
  18. import akg.topi
  19. from akg.ops.math.cast import cast
  20. from akg.ops.math.floor import floor
  21. from akg.utils import validation_check as vc_util
  22. from akg.utils.dsl_create import produce_shapes
  23. from akg.utils import kernel_exec as utils
  24. from akg.ops.math.reciprocal import reciprocal
  @vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor)
  def div(data1, data2):
      """
      Divide data1 by data2 element-wise, broadcasting both inputs to a common shape.

      Floating-point inputs are divided directly. Integer inputs (int32, int8, uint8)
      are converted to float16, divided, and converted back to the input dtype. For
      int8/uint8 the float16 quotient is passed through `floor` first (floor division).
      When `utils.product_is_mini()` is true, the quotient is computed as
      data1 * reciprocal(data2) instead of a direct divide.

      Note:
          div supports broadcasting.

      Args:
          data1 (tvm.tensor.Tensor): Tensor of type float16, float32, int32, int8 and uint8.
          data2 (tvm.tensor.Tensor): Tensor of type float16, float32, int32, int8 and uint8.
              Its dtype is checked against data1's by `vc_util.elemwise_dtype_check`.

      Returns:
          tvm.tensor.Tensor, has the same type as data1 and data2.
      """
      # Validate dtypes first; data1's dtype drives every later decision.
      vc_util.ops_dtype_check([data1.dtype, data2.dtype], vc_util.DtypeForDavinci.ALL_TYPES)
      vc_util.elemwise_dtype_check(data1.dtype, data2.dtype)
      dtype = data1.dtype

      # Static shapes as plain Python ints.
      shape1 = [dim.value for dim in data1.shape]
      shape2 = [dim.value for dim in data2.shape]
      for shape in (shape1, shape2):
          vc_util.check_shape(shape)

      vc_util.auto_broadcast_check(shape1, shape2)
      n_shape1, n_shape2, out_shape = produce_shapes(shape1, shape2)

      def _expand(tensor, norm_shape):
          # Insert a broadcast only when the normalized input shape differs from the output shape.
          return tensor if norm_shape == out_shape else akg.topi.broadcast_to(tensor, out_shape)

      lhs = _expand(data1, n_shape1)
      rhs = _expand(data2, n_shape2)

      is_integer = dtype in ("int32", "int8", "uint8")
      if is_integer:
          # Integer operands are divided in float16.
          # NOTE(review): float16 cannot represent every int32 value exactly (its exact
          # integer range is only +/-2048), so large int32 operands may lose precision -- confirm
          # this is acceptable for callers.
          lhs, rhs = cast(lhs, "float16"), cast(rhs, "float16")

      if utils.product_is_mini():
          # Mini products: form the quotient as lhs * (1 / rhs).
          quotient = akg.topi.multiply(lhs, reciprocal(rhs))
      else:
          quotient = akg.topi.divide(lhs, rhs)

      if dtype in ("int8", "uint8"):
          # Floor, then route back through float16 before the final cast to dtype --
          # presumably because casting floor's output straight to int8/uint8 is not
          # supported; TODO confirm.
          quotient = cast(floor(quotient), "float16")

      if is_integer:
          # NOTE(review): int32 results are not floored here; whether negative quotients
          # match floor division depends on how `cast` rounds float16 -> int32 -- confirm.
          quotient = cast(quotient, dtype)

      return quotient