You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mean.py 5.5 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """operator dsl function: mean"""
  17. import akg.topi
  18. import akg.tvm
  19. from akg.ops.math import sum
  20. from akg.utils import format_transform as ft_util
  21. from akg.utils import validation_check as vc_util
  22. from akg.utils import custom_tiling as ct_util
  23. from akg.utils.dynamic_shape import shape_is_dynamic
  24. INT16_MAX = 65536
  25. def get_attrs(tensor):
  26. """generate default attrs."""
  27. if shape_is_dynamic(tensor):
  28. return {"enable_double_buffer": 0, "enable_divide_var": 1}
  29. return {}
  30. def mean_dynamic_tiling_strategy(tensor, axis):
  31. """custom tiling for mean with dynamic shape"""
  32. strategy = list()
  33. inner_most_to_full = True
  34. resnet_inner_most_axis_pos = 4
  35. reduce_axis_to_1 = True
  36. reduce_axis_to_no_iso = False
  37. multicore_axis_to_1 = True
  38. resnet_outer_most_axis_pos = 0
  39. if inner_most_to_full:
  40. strategy += ct_util.create_constraint_on_tensor(tensor=tensor,
  41. values="FULL",
  42. constraints=ct_util.TileConstraint.MAX,
  43. tensor_pos=resnet_inner_most_axis_pos)
  44. if reduce_axis_to_1:
  45. strategy += ct_util.create_constraint_on_tensor(tensor=tensor,
  46. values=[1 for _ in axis],
  47. constraints=ct_util.TileConstraint.FACTOR,
  48. tensor_pos=axis)
  49. elif reduce_axis_to_no_iso:
  50. strategy += ct_util.create_constraint_on_tensor(tensor=tensor,
  51. values=[1 for _ in axis],
  52. constraints=ct_util.TileConstraint.FORBID_ISOLATE,
  53. tensor_pos=axis)
  54. if multicore_axis_to_1:
  55. strategy += ct_util.create_constraint_on_tensor(tensor=tensor,
  56. values=1,
  57. constraints=ct_util.TileConstraint.FACTOR,
  58. tensor_pos=resnet_outer_most_axis_pos)
  59. return strategy
  60. @vc_util.check_input_type(akg.tvm.tensor.Tensor, (list, tuple, int, type(None)), (bool, type(None)))
  61. def mean(data, axis=None, keepdims=False):
  62. """
  63. Computes the mean of the values of a Tensor over the whole dataset.
  64. Note:
  65. If the tuple's elements are unsorted, this function will call preprocess_axis firstly to let these elements
  66. sorted. if tuple is empty, this function will compute all elements' sum.
  67. if the data type is folat 16 and the whole dim not less than 65536, this function will compute the mean by
  68. divide 65535 first to avoid whole dim too large.
  69. Args:
  70. data (tvm.tensor.Tensor): Tensor of type float16, float32.
  71. axis (Union[list, tuple, int, None]): If the tuple is empty, the axis equal to None.
  72. keepdims (bool): If keepdims equal to True, the result shape length is same to input shape length.
  73. Returns:
  74. tvm.tensor.Tensor, has the same type as data. If keepdims equal to True, all reduced dimensions are
  75. retained with length 1. else these reduced axis will be eliminate.
  76. """
  77. # Check types
  78. vc_util.ops_dtype_check(data.dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
  79. # Check shape
  80. shape = ft_util.get_shape(data)
  81. vc_util.reduce_axis_check(shape, axis)
  82. axis = ft_util.refine_reduce_axis(data, axis)
  83. count = 1
  84. for i in axis:
  85. count *= shape[i]
  86. output, _ = sum.sum_value(data, axis, keepdims)
  87. if shape_is_dynamic(data):
  88. res = akg.tvm.compute(output.shape, lambda *i: akg.lang.cce.divide_var(output(*i), count), name="res")
  89. else:
  90. res = akg.topi.divide(output, count)
  91. attrs = get_attrs(data)
  92. if shape_is_dynamic(data):
  93. attrs["custom_tiling"] = mean_dynamic_tiling_strategy(data, axis)
  94. return res, attrs
  95. @vc_util.check_input_type(akg.tvm.tensor.Tensor, (list, tuple, int, type(None)), (bool, type(None)))
  96. def mean_v2(data, axis=None, keepdims=False):
  97. """Simple implementation of mean."""
  98. # Check types
  99. vc_util.ops_dtype_check(data.dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
  100. # Check shape
  101. shape = [x.value for x in data.shape]
  102. vc_util.reduce_axis_check(shape, axis)
  103. axis = ft_util.refine_reduce_axis(data, axis)
  104. dtype = data.dtype
  105. count = 1
  106. for i in axis:
  107. count *= shape[i]
  108. count_rec = 1 / count
  109. output, _ = sum.sum_v2(data, axis, keepdims)
  110. res = output * akg.tvm.const(count_rec, dtype)
  111. attrs = get_attrs(data)
  112. if shape_is_dynamic(data):
  113. attrs["custom_tiling"] = mean_dynamic_tiling_strategy(data, axis)
  114. return res, attrs