You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

mean.py 1.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """operator dsl function: mean"""
  15. import _akg.topi
  16. import _akg.tvm
  17. from _akg.utils import format_transform as ft_util
  18. from _akg.utils import validation_check as vc_util
  19. from _akg.ops.math import sum
  @vc_util.check_input_type(_akg.tvm.tensor.Tensor, (list, tuple, int, type(None)), (bool, type(None)))
  def mean(data, axis=None, keepdims=False):
      """
      Compute the arithmetic mean of `data` along the given axes.

      The tensor is summed over the reduction axes with `sum.sum_value`, and that
      result is divided by the number of input elements folded into each output
      element (the product of the reduced dimension extents).

      Args:
          data (tvm.tensor.Tensor): Input tensor. Every entry of `data.shape` is read
              through `.value`, so the shape is presumably required to be static.
          axis (Union[list, tuple, int, None]): Axis or axes to reduce. It is checked by
              `vc_util.reduce_axis_check` and then rewritten by
              `ft_util.refine_reduce_axis`. None (and, per the original docs, an empty
              tuple) presumably means "all axes"; confirm against that helper.
          keepdims (Union[bool, None]): Forwarded unchanged to `sum.sum_value`. When
              True the reduced dimensions are presumably kept with length 1; otherwise
              they are removed.

      Returns:
          tvm.tensor.Tensor, the reduced sum divided by the element count via
          `_akg.topi.divide`. The original docs state it has the same dtype as `data`.
      """
      dims = [extent.value for extent in data.shape]
      # Check the caller-supplied axis against the shape before refine_reduce_axis rewrites it.
      vc_util.reduce_axis_check(dims, axis)
      reduce_axes = ft_util.refine_reduce_axis(data, axis)

      # Number of input elements averaged into each output element.
      num_elements = 1
      for extent in (dims[ax] for ax in reduce_axes):
          num_elements = num_elements * extent

      # `sum` is the `_akg.ops.math.sum` module (it shadows the builtin);
      # only the first item of the pair returned by sum_value is used.
      reduced, _ = sum.sum_value(data, reduce_axes, keepdims)
      return _akg.topi.divide(reduced, num_elements)