| @@ -50,7 +50,19 @@ from .loss import ( | |||
| square_loss, | |||
| triplet_margin_loss, | |||
| ) | |||
| from .math import argmax, argmin, max, mean, min, norm, normalize, prod, sqrt, sum | |||
| from .math import ( | |||
| argmax, | |||
| argmin, | |||
| logsumexp, | |||
| max, | |||
| mean, | |||
| min, | |||
| norm, | |||
| normalize, | |||
| prod, | |||
| sqrt, | |||
| sum, | |||
| ) | |||
| from .nn import ( | |||
| assert_equal, | |||
| avg_pool2d, | |||
| @@ -6,12 +6,15 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Optional | |||
| import math | |||
| import numbers | |||
| from typing import Optional, Sequence, Union | |||
| import megengine._internal as mgb | |||
| from ..core import Tensor, wrap_io_tensor | |||
| from .elemwise import clamp | |||
| from .elemwise import clamp, exp, isinf, log | |||
| from .tensor import remove_axis, where, zeros_like | |||
| @wrap_io_tensor | |||
| @@ -296,3 +299,35 @@ def normalize( | |||
| return inp / clamp(norm(inp, p), lower=eps) | |||
| else: | |||
| return inp / clamp(norm(inp, p, axis, keepdims=True), lower=eps) | |||
| def logsumexp(inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False): | |||
| r""" | |||
| Compute the log of the sum of exponentials of inputs along the given :attr:`axis`. The computation is numerically stabilized. | |||
| .. math:: | |||
| \mathsf{logsumexp}(x_1, \dots, x_n) = \log(\exp(x_1) + \cdots + \exp(x_n)) | |||
| :param inp: The input tensor. | |||
| :param axis: Axis over which the sum is taken. It can be a single axis or a list of axes. | |||
| :param keepdims: whether to retain :attr:`axis` or not for the output tensor. | |||
| """ | |||
| if isinstance(axis, numbers.Integral): | |||
| axis = (axis,) | |||
| max_value = inp | |||
| for dim in axis: | |||
| max_value = max_value.max(axis=dim, keepdims=True) | |||
| max_value = where( | |||
| isinf(max_value).astype("int32"), zeros_like(max_value), max_value | |||
| ) | |||
| x = exp(inp - max_value) | |||
| for dim in axis: | |||
| x = x.sum(axis=dim, keepdims=True) | |||
| x = max_value + log(x) | |||
| if not keepdims: | |||
| axis = sorted(axis, reverse=True) | |||
| for i in axis: | |||
| x = remove_axis(x, axis=i) | |||
| return x | |||
| @@ -9,9 +9,12 @@ | |||
| import numpy as np | |||
| def assertTensorClose(v0, v1, *, max_err=1e-6, name=None): | |||
| def assertTensorClose( | |||
| v0, v1, *, max_err: float = 1e-6, allow_special_values: bool = False, name=None | |||
| ): | |||
| """ | |||
| max_err: relative error | |||
| :param allow_special_values: whether to allow :attr:`v0` and :attr:`v1` to contain inf and nan values. | |||
| :param max_err: relative error | |||
| """ | |||
| __tracebackhide__ = True # pylint: disable=unused-variable | |||
| @@ -20,9 +23,30 @@ def assertTensorClose(v0, v1, *, max_err=1e-6, name=None): | |||
| ), "Two Tensor must have same dtype, but the inputs are {} and {}".format( | |||
| v0.dtype, v1.dtype | |||
| ) | |||
| v0 = np.ascontiguousarray(v0, dtype=np.float32) | |||
| v1 = np.ascontiguousarray(v1, dtype=np.float32) | |||
| assert np.isfinite(v0.sum()) and np.isfinite(v1.sum()), (v0, v1) | |||
| v0 = np.ascontiguousarray(v0, dtype=np.float32).copy() | |||
| v1 = np.ascontiguousarray(v1, dtype=np.float32).copy() | |||
| if allow_special_values: | |||
| # check nan and rm it | |||
| v0_nan_mask = np.isnan(v0) | |||
| if np.any(v0_nan_mask): | |||
| assert np.array_equiv(v0_nan_mask, np.isnan(v1)), (v0, v1) | |||
| v0[v0_nan_mask] = 0 | |||
| v1[v0_nan_mask] = 0 | |||
| # check inf and rm it | |||
| v0_inf_mask = v0 == float("inf") | |||
| if np.any(v0_inf_mask): | |||
| assert np.array_equiv(v0_inf_mask, v1 == float("inf")), (v0, v1) | |||
| v0[v0_inf_mask] = 0 | |||
| v1[v0_inf_mask] = 0 | |||
| # check -inf and rm it | |||
| v0_inf_mask = v0 == float("-inf") | |||
| if np.any(v0_inf_mask): | |||
| assert np.array_equiv(v0_inf_mask, v1 == float("-inf")), (v0, v1) | |||
| v0[v0_inf_mask] = 0 | |||
| v1[v0_inf_mask] = 0 | |||
| else: | |||
| assert np.isfinite(v0.sum()) and np.isfinite(v1.sum()), (v0, v1) | |||
| assert v0.shape == v1.shape, "Two tensor must have same shape({} v.s. {})".format( | |||
| v0.shape, v1.shape | |||
| ) | |||
| @@ -6,10 +6,14 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from functools import partial | |||
| import numpy as np | |||
| from helpers import opr_test | |||
| import megengine.functional as F | |||
| from megengine.test import assertTensorClose | |||
| def common_test_reduce(opr, ref_opr): | |||
| @@ -86,7 +90,6 @@ def test_sqrt(): | |||
| def test_normalize(): | |||
| from functools import partial | |||
| cases = [ | |||
| {"input": np.random.random((2, 3, 12, 12)).astype(np.float32)} for i in range(2) | |||
| @@ -112,3 +115,54 @@ def test_normalize(): | |||
| cases[0]["input"][0, 0, 0, :] = 0 | |||
| cases[1]["input"][0, 0, 0, :] = 0 | |||
| opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3)) | |||
| def test_logsumexp(): | |||
| x = np.arange(10).astype(np.float32) | |||
| expected = np.log(np.sum(np.exp(x))) | |||
| cases = [{"input": x, "output": expected}] | |||
| compare_fn = partial(assertTensorClose, allow_special_values=True) | |||
| # large value check | |||
| n = 100 | |||
| x = np.full(n, 10000, dtype=np.float32) | |||
| expected = 10000 + np.log(n) | |||
| cases.append({"input": x, "output": expected.astype(np.float32)}) | |||
| opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn) | |||
| # special value check | |||
| x = np.array([np.inf], dtype=np.float32) | |||
| expected = x | |||
| cases = [{"input": x, "output": expected}] | |||
| x = np.array([-np.inf, 0.0], dtype=np.float32) | |||
| expected = np.zeros(1).astype(np.float32) | |||
| cases.append({"input": x, "output": expected}) | |||
| opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn) | |||
| x = np.array([np.nan], dtype=np.float32) | |||
| expected = x | |||
| cases = [{"input": x, "output": expected}] | |||
| x = np.array([-np.inf, 1], dtype=np.float32) | |||
| expected = np.array([1.0], dtype=np.float32) | |||
| cases.append({"input": x, "output": expected}) | |||
| opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn) | |||
| # keepdims check | |||
| x = np.array([[1e10, 1e-10], [-1e10, -np.inf]], dtype=np.float32) | |||
| expected = np.array([[1e10], [-1e10]], dtype=np.float32) | |||
| cases = [{"input": x, "output": expected}] | |||
| x = np.array([[1e10, -1e-10, 1e-10], [1e10, 1e-10, np.inf]], dtype=np.float32) | |||
| expected = np.array([[1e10], [np.inf]], dtype=np.float32) | |||
| cases.append({"input": x, "output": expected}) | |||
| opr_test(cases, F.logsumexp, axis=1, keepdims=True, compare_fn=compare_fn) | |||
| # multiple axes check | |||
| x = np.array([[1e10, 1e-10], [-1e10, -np.inf]], dtype=np.float32) | |||
| expected = np.array([1e10], dtype=np.float32) | |||
| cases = [{"input": x, "output": expected}] | |||
| x = np.array([[1e10, -1e-10, 1e-10], [1e10, 1e-10, np.inf]], dtype=np.float32) | |||
| expected = np.array([np.inf], dtype=np.float32) | |||
| cases.append({"input": x, "output": expected}) | |||
| opr_test(cases, F.logsumexp, axis=(0, 1), keepdims=False, compare_fn=compare_fn) | |||