GitOrigin-RevId: a1bd1102a6
tags/v1.5.0
| @@ -165,7 +165,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: | |||||
| return list(map(int, axis)) | return list(map(int, axis)) | ||||
| axis = get_axes() | axis = get_axes() | ||||
| axis = sorted(i + inp.ndim if i < 0 else i for i in axis) | |||||
| axis = utils._normalize_axis(inp.ndim, axis) | |||||
| axis = [a - i for i, a in enumerate(axis)] | axis = [a - i for i, a in enumerate(axis)] | ||||
| op = builtin.RemoveAxis(axis=axis) | op = builtin.RemoveAxis(axis=axis) | ||||
| @@ -190,8 +190,7 @@ def _reduce(mode): | |||||
| op = builtin.Reduce(mode=mode, axis=0) | op = builtin.Reduce(mode=mode, axis=0) | ||||
| (result,) = apply(op, data) | (result,) = apply(op, data) | ||||
| elif isinstance(axis, collections.abc.Iterable): | elif isinstance(axis, collections.abc.Iterable): | ||||
| axis = list(axis) | |||||
| axis.sort(reverse=True) | |||||
| axis = utils._normalize_axis(self.ndim, axis, reverse=True) | |||||
| for ai in axis: | for ai in axis: | ||||
| op = builtin.Reduce(mode=mode, axis=ai) | op = builtin.Reduce(mode=mode, axis=ai) | ||||
| (data,) = apply(op, data) | (data,) = apply(op, data) | ||||
| @@ -199,6 +198,7 @@ def _reduce(mode): | |||||
| data = _remove_axis(data, ai) | data = _remove_axis(data, ai) | ||||
| result = data | result = data | ||||
| else: | else: | ||||
| # builtin.Reduce already accept negtive axis | |||||
| op = builtin.Reduce(mode=mode, axis=axis) | op = builtin.Reduce(mode=mode, axis=axis) | ||||
| (result,) = apply(op, data) | (result,) = apply(op, data) | ||||
| @@ -178,3 +178,28 @@ def make_shape_tuple(shape): | |||||
| s = [] | s = [] | ||||
| _expand_int(s, shape) | _expand_int(s, shape) | ||||
| return tuple(s) | return tuple(s) | ||||
| def _normalize_axis( | |||||
| ndim: int, axis: Union[int, Iterable], reverse=False | |||||
| ) -> Union[int, list]: | |||||
| def convert(x): | |||||
| x_org = x | |||||
| if x < 0: | |||||
| x = ndim + x | |||||
| assert ( | |||||
| x >= 0 and x < ndim | |||||
| ), "axis {} is out of bounds for tensor of dimension {}".format(x_org, ndim) | |||||
| return x | |||||
| if isinstance(axis, int): | |||||
| return convert(axis) | |||||
| elif isinstance(axis, Iterable): | |||||
| axis_org = axis | |||||
| axis = list(sorted(map(convert, axis), reverse=reverse)) | |||||
| for i in range(len(axis) - 1): | |||||
| assert axis[i] != axis[i + 1], "axis {} contains duplicated indices".format( | |||||
| axis_org | |||||
| ) | |||||
| return axis | |||||
| raise | |||||
| @@ -466,9 +466,13 @@ def argmin( | |||||
| 0 | 0 | ||||
| """ | """ | ||||
| if axis is None: | |||||
| assert not keepdims, "can not set axis=None and keepdims=True" | |||||
| inp = inp.flatten() | |||||
| axis = 0 | |||||
| axis = utils._normalize_axis(inp.ndim, axis, reverse=True) | |||||
| if isinstance(axis, collections.abc.Iterable): | if isinstance(axis, collections.abc.Iterable): | ||||
| axis = list(axis) | |||||
| axis.sort(reverse=True) | |||||
| for ai in axis: | for ai in axis: | ||||
| op = builtin.Argmin(axis=ai) | op = builtin.Argmin(axis=ai) | ||||
| @@ -479,11 +483,6 @@ def argmin( | |||||
| return inp | return inp | ||||
| if axis is None: | |||||
| assert not keepdims, "can not set axis=None and keepdims=True" | |||||
| inp = inp.flatten() | |||||
| axis = 0 | |||||
| op = builtin.Argmin(axis=axis) | op = builtin.Argmin(axis=axis) | ||||
| (result,) = apply(op, inp) | (result,) = apply(op, inp) | ||||
| if not keepdims: | if not keepdims: | ||||
| @@ -525,9 +524,13 @@ def argmax( | |||||
| 5 | 5 | ||||
| """ | """ | ||||
| if axis is None: | |||||
| assert not keepdims, "can not set axis=None and keepdims=True" | |||||
| inp = inp.flatten() | |||||
| axis = 0 | |||||
| axis = utils._normalize_axis(inp.ndim, axis, reverse=True) | |||||
| if isinstance(axis, collections.abc.Iterable): | if isinstance(axis, collections.abc.Iterable): | ||||
| axis = list(axis) | |||||
| axis.sort(reverse=True) | |||||
| for ai in axis: | for ai in axis: | ||||
| op = builtin.Argmax(axis=ai) | op = builtin.Argmax(axis=ai) | ||||
| @@ -538,11 +541,6 @@ def argmax( | |||||
| return inp | return inp | ||||
| if axis is None: | |||||
| assert not keepdims, "can not set axis=None and keepdims=True" | |||||
| inp = inp.flatten() | |||||
| axis = 0 | |||||
| op = builtin.Argmax(axis=axis) | op = builtin.Argmax(axis=axis) | ||||
| (result,) = apply(op, inp) | (result,) = apply(op, inp) | ||||
| if not keepdims: | if not keepdims: | ||||
| @@ -811,3 +811,19 @@ def test_assert_not_equal(): | |||||
| y = F.zeros(shape, dtype=np.float32) + 1.1 | y = F.zeros(shape, dtype=np.float32) + 1.1 | ||||
| with pytest.raises(RuntimeError): | with pytest.raises(RuntimeError): | ||||
| z = F.utils._assert_equal(x, y) | z = F.utils._assert_equal(x, y) | ||||
| def test_neg_axis(): | |||||
| x = tensor(np.random.normal(0, 1, (32, 5))) | |||||
| y = F.argmax(x, axis=-1) | |||||
| yy = F.argmax(x, axis=1) | |||||
| np.testing.assert_equal(y.numpy(), yy.numpy()) | |||||
| y = F.argmax(x, axis=(-1, -2)) | |||||
| yy = F.argmax(x, axis=(0, 1)) | |||||
| np.testing.assert_equal(y.numpy(), yy.numpy()) | |||||
| y = F.argmin(x, axis=(-1, -2)) | |||||
| yy = F.argmin(x, axis=(0, 1)) | |||||
| np.testing.assert_equal(y.numpy(), yy.numpy()) | |||||
| @@ -9,6 +9,7 @@ | |||||
| from functools import partial | from functools import partial | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| from utils import opr_test | from utils import opr_test | ||||
| import megengine.functional as F | import megengine.functional as F | ||||
| @@ -48,6 +49,14 @@ def common_test_reduce(opr, ref_opr): | |||||
| ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32), | ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32), | ||||
| axis=axis, | axis=axis, | ||||
| ) | ) | ||||
| # test negative axis | |||||
| axis = axis - len(data1_shape) | |||||
| opr_test( | |||||
| cases, | |||||
| opr, | |||||
| ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32), | |||||
| axis=axis, | |||||
| ) | |||||
| def test_sum(): | def test_sum(): | ||||
| @@ -137,3 +146,14 @@ def test_normalize(): | |||||
| cases[0]["input"][0, 0, 0, :] = 0 | cases[0]["input"][0, 0, 0, :] = 0 | ||||
| cases[1]["input"][0, 0, 0, :] = 0 | cases[1]["input"][0, 0, 0, :] = 0 | ||||
| opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3)) | opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3)) | ||||
| def test_sum_neg_axis(): | |||||
| shape = (2, 3) | |||||
| data = np.random.random(shape).astype(np.float32) | |||||
| for axis in (-1, -2, (-2, 1), (-1, 0)): | |||||
| get = F.sum(tensor(data), axis=axis) | |||||
| ref = np.sum(data, axis=axis) | |||||
| np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6) | |||||
| with pytest.raises(AssertionError): | |||||
| F.sum(tensor(data), axis=(-1, 1)) | |||||