From 03ec134e87105511ce60fca4613bccadea473236 Mon Sep 17 00:00:00 2001 From: wangrao Date: Mon, 22 Feb 2021 15:29:55 +0800 Subject: [PATCH] fix np.average --- mindspore/numpy/math_ops.py | 42 +++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/mindspore/numpy/math_ops.py b/mindspore/numpy/math_ops.py index 4158dc409d..ea7f16a50b 100644 --- a/mindspore/numpy/math_ops.py +++ b/mindspore/numpy/math_ops.py @@ -25,14 +25,14 @@ from ..common import Tensor from .dtypes import nan, pi -from .array_creations import asarray_const, ones, zeros, empty, full +from .array_creations import asarray_const, ones, zeros, empty, full, full_like from .array_ops import where as where_ from .array_ops import ravel, expand_dims from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \ _check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \ _raise_value_error, _check_matmul_shapes, _promote, _check_axis_type, _canonicalize_axis, \ - _max, _is_shape_empty, _check_is_int + _max, _is_shape_empty, _check_is_int, _expanded_shape from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \ _check_input_tensor @@ -1200,22 +1200,38 @@ def average(x, axis=None, weights=None, returned=False): _check_axis_type(axis, True, True, False) axis = _canonicalize_axis(axis, x.ndim) - if weights is None: - return mean(x, axis) - x_avg = full((), nan, F.dtype(x)) sum_of_weights = None - if x.shape == weights.shape: - x_avg, sum_of_weights = comput_avg(x, axis, weights) - elif F.rank(weights) == 1: - if not isinstance(axis, int): - _raise_type_error("Axis must be specified when shapes of x and weights differ.") - weights = _broadcast_to_shape(weights, x.shape) - x_avg, sum_of_weights = comput_avg(x, axis, weights) + if weights is None: + x_avg = mean(x, axis) + if axis is None: + sum_of_weights = full((), x.size, F.dtype(x)) + else: + fill_value = 1 + if isinstance(axis, int) or isinstance(axis, tuple) and F.tuple_len(axis) == 1: + fill_value = x.shape[axis] + elif axis is None or axis == (): + for sh in x.shape: + fill_value *= sh + else: + for ax in axis: + fill_value *= x.shape[ax] + sum_of_weights = full_like(x_avg, fill_value, F.dtype(x)) else: - _raise_type_error("Weights should be None, 1-D or the same as input x, but got shape of", weights) + if x.shape == weights.shape: + x_avg, sum_of_weights = comput_avg(x, axis, weights) + elif F.rank(weights) == 1: + if not isinstance(axis, int): + _raise_type_error("Axis must be specified when shapes of x and weights differ.") + perm = _expanded_shape(x.ndim, weights.shape[0], axis) + weights = weights.reshape(perm) + x_avg, sum_of_weights = comput_avg(x, axis, weights) + else: + _raise_type_error("Weights should be None, 1-D or the same shape as input x.") if returned: + if x_avg.shape != sum_of_weights.shape: + sum_of_weights = _broadcast_to(sum_of_weights, sum_of_weights.shape, x_avg.shape, x_avg.ndim) return (x_avg, sum_of_weights) return x_avg