You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

add.py 4.7 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """operator dsl function: add"""
  17. import akg.topi
  18. import akg.tvm
  19. from akg.lang.cce import vadd, vmuls
  20. from akg.utils import validation_check as vc_util
  21. from akg.utils.dsl_create import produce_shapes
  22. from akg.utils.format_transform import get_shape
  23. from akg.utils.dynamic_shape import shape_is_dynamic
  24. @vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor,
  25. (int, float, type(None)), (bool, type(None)), (dict, type(None)))
  26. def add(first_input, second_input, scale=1.0, polyhedral=True, attrs=None):
  27. """
  28. Computes first_input + second_input * scale elementwise.
  29. Args:
  30. first_input (tvm.tensor.Tensor): Tensor of type float16, float32, int32.
  31. second_input (tvm.tensor.Tensor): Tensor with same type as first_input.
  32. Broadcast will happen if shapes of input tensors are different.
  33. scale (float): scale factor applied on second_input, default value is 1.0.
  34. polyhedral (bool): If True, use auto-schedule, else use manual-schedule, default value is True.
  35. attrs (dict): Specifies parameters used in manual-schedule.
  36. Returns:
  37. tvm.tensor.Tensor of same type as input tensor with shape the broadcast shape of input tensors.
  38. """
  39. vc_util.check_shape(first_input.shape)
  40. vc_util.check_shape(second_input.shape)
  41. attr_map = {}
  42. first_input_shape = get_shape(first_input)
  43. second_input_shape = get_shape(second_input)
  44. if shape_is_dynamic([first_input, second_input]):
  45. if first_input_shape != second_input_shape:
  46. raise RuntimeError("Input tensors have different shapes, broadcast is not supported for dynamic.")
  47. first_broadcast = first_input
  48. second_broadcast = second_input
  49. else:
  50. if first_input_shape != second_input_shape:
  51. _, _, out_shape = produce_shapes(first_input_shape, second_input_shape)
  52. else:
  53. out_shape = first_input_shape
  54. first_broadcast = akg.topi.broadcast_to(first_input, out_shape)
  55. second_broadcast = akg.topi.broadcast_to(second_input, out_shape)
  56. first_input_type = first_input.dtype
  57. second_input_type = second_input.dtype
  58. if first_input_type != second_input_type:
  59. raise TypeError("Input tensors have different data types.")
  60. vc_util.ops_dtype_check(first_input_type, vc_util.DtypeForDavinci.ALL_TYPES)
  61. temp = vmuls(second_broadcast, scale)
  62. res = vadd(first_broadcast, temp)
  63. res_cast = res.astype(first_input_type)
  64. if polyhedral:
  65. return res_cast, attr_map
  66. def comp_func(s):
  67. first_ub = s.cache_read(first_input, "local.UB", [first_broadcast])
  68. second_ub = s.cache_read(second_input, "local.UB", [second_broadcast])
  69. res_cast_ub = s.cache_write(res_cast, "local.UB")
  70. s[first_broadcast].set_scope("local.UB")
  71. s[second_broadcast].set_scope("local.UB")
  72. s[temp].set_scope("local.UB")
  73. s[res].set_scope("local.UB")
  74. split_axis = []
  75. for i in range(len(attrs["tile"])):
  76. outer, inner = s[res_cast].split(res_cast.op.axis[i], attrs["tile"][i])
  77. axis_dict = {"outer": outer, "inner": inner}
  78. split_axis.append(axis_dict)
  79. s[first_ub].compute_at(s[res], res.op.axis[0])
  80. s[second_ub].compute_at(s[res], res.op.axis[0])
  81. s[first_broadcast].compute_at(s[res], res.op.axis[0])
  82. s[second_broadcast].compute_at(s[res], res.op.axis[0])
  83. s[temp].compute_at(s[res], res.op.axis[0])
  84. s[res].compute_at(s[res_cast_ub], res_cast_ub.op.axis[0])
  85. s[res_cast_ub].compute_at(s[res_cast], split_axis[-1]['outer'])
  86. # no scaling nedeed
  87. if scale == 1:
  88. s[temp].compute_inline()
  89. # no broadcast needed
  90. if first_input_shape == second_input_shape:
  91. s[first_broadcast].compute_inline()
  92. s[second_broadcast].compute_inline()
  93. return res_cast, comp_func, attr_map