| @@ -25,14 +25,14 @@ from ..common import Tensor | |||
| from .dtypes import nan, pi | |||
| from .array_creations import asarray_const, ones, zeros, empty, full | |||
| from .array_creations import asarray_const, ones, zeros, empty, full, full_like | |||
| from .array_ops import where as where_ | |||
| from .array_ops import ravel, expand_dims | |||
| from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \ | |||
| _check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \ | |||
| _raise_value_error, _check_matmul_shapes, _promote, _check_axis_type, _canonicalize_axis, \ | |||
| _max, _is_shape_empty, _check_is_int | |||
| _max, _is_shape_empty, _check_is_int, _expanded_shape | |||
| from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \ | |||
| _check_input_tensor | |||
| @@ -1200,22 +1200,38 @@ def average(x, axis=None, weights=None, returned=False): | |||
| _check_axis_type(axis, True, True, False) | |||
| axis = _canonicalize_axis(axis, x.ndim) | |||
| if weights is None: | |||
| return mean(x, axis) | |||
| x_avg = full((), nan, F.dtype(x)) | |||
| sum_of_weights = None | |||
| if x.shape == weights.shape: | |||
| x_avg, sum_of_weights = comput_avg(x, axis, weights) | |||
| elif F.rank(weights) == 1: | |||
| if not isinstance(axis, int): | |||
| _raise_type_error("Axis must be specified when shapes of x and weights differ.") | |||
| weights = _broadcast_to_shape(weights, x.shape) | |||
| x_avg, sum_of_weights = comput_avg(x, axis, weights) | |||
| if weights is None: | |||
| x_avg = mean(x, axis) | |||
| if axis is None: | |||
| sum_of_weights = full((), x.size, F.dtype(x)) | |||
| else: | |||
| fill_value = 1 | |||
| if isinstance(axis, int) or isinstance(axis, tuple) and F.tuple_len(axis) == 1: | |||
| fill_value = x.shape[axis] | |||
| elif axis is None or axis == (): | |||
| for sh in x.shape: | |||
| fill_value *= sh | |||
| else: | |||
| for ax in axis: | |||
| fill_value *= x.shape[ax] | |||
| sum_of_weights = full_like(x_avg, fill_value, F.dtype(x)) | |||
| else: | |||
| _raise_type_error("Weights should be None, 1-D or the same as input x, but got shape of", weights) | |||
| if x.shape == weights.shape: | |||
| x_avg, sum_of_weights = comput_avg(x, axis, weights) | |||
| elif F.rank(weights) == 1: | |||
| if not isinstance(axis, int): | |||
| _raise_type_error("Axis must be specified when shapes of x and weights differ.") | |||
| perm = _expanded_shape(x.ndim, weights.shape[0], axis) | |||
| weights = weights.reshape(perm) | |||
| x_avg, sum_of_weights = comput_avg(x, axis, weights) | |||
| else: | |||
| _raise_type_error("Weights should be None, 1-D or the same shape as input x.") | |||
| if returned: | |||
| if x_avg.shape != sum_of_weights.shape: | |||
| sum_of_weights = _broadcast_to(sum_of_weights, sum_of_weights.shape, x_avg.shape, x_avg.ndim) | |||
| return (x_avg, sum_of_weights) | |||
| return x_avg | |||