GitOrigin-RevId: 8ed95e0fb8
tags/v1.0.0-rc1
| @@ -134,15 +134,54 @@ def _logical_binary_elwise(mode, rev=False): | |||
| return f | |||
| def _remove_axis(inp: Tensor, axis) -> Tensor: | |||
| Param = builtin.AxisAddRemove.Param | |||
| def get_axes(): | |||
| if axis is None: | |||
| return [i for i, s in enumerate(inp.shape) if s == 1] | |||
| try: | |||
| return [int(axis)] | |||
| except (TypeError, ValueError): | |||
| pass | |||
| return list(map(int, axis)) | |||
| axis = get_axes() | |||
| axis = sorted(i + inp.ndim if i < 0 else i for i in axis) | |||
| axis = [a - i for i, a in enumerate(axis)] | |||
| param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis)) | |||
| op = builtin.AxisAddRemove(param=param) | |||
| (result,) = apply(op, inp) | |||
| return result | |||
| def _reduce(mode): | |||
| def f(self, axis=None): | |||
| inp = self | |||
| def f(self, axis=None, keepdims: bool = False): | |||
| data = self | |||
| (data,) = utils.convert_inputs(data) | |||
| if axis is None: | |||
| inp = self.flatten() | |||
| axis = 0 | |||
| op = builtin.Reduce(mode=mode, axis=axis) | |||
| (result,) = utils.convert_inputs(inp) | |||
| (result,) = apply(op, result) | |||
| data = data.reshape(-1) | |||
| assert not keepdims, "can not set axis=None and keepdims=True" | |||
| op = builtin.Reduce(mode=mode, axis=0) | |||
| (result,) = apply(op, data) | |||
| elif isinstance(axis, collections.Iterable): | |||
| axis = list(axis) | |||
| axis.sort(reverse=True) | |||
| for ai in axis: | |||
| op = builtin.Reduce(mode=mode, axis=ai) | |||
| (data,) = apply(op, data) | |||
| if not keepdims: | |||
| data = _remove_axis(data, ai) | |||
| result = data | |||
| else: | |||
| op = builtin.Reduce(mode=mode, axis=axis) | |||
| (result,) = apply(op, data) | |||
| if not keepdims: | |||
| result = _remove_axis(result, axis) | |||
| return result | |||
| return f | |||
| @@ -176,15 +176,15 @@ def cross_entropy_with_softmax( | |||
| num_classes = pred.shape[axis] | |||
| # Denominator of the softmax | |||
| offset = pred.max(axis=axis).detach() | |||
| offset = pred.max(axis=axis, keepdims=True).detach() | |||
| pred = pred - offset | |||
| down = exp(pred).sum(axis=axis) | |||
| down = exp(pred).sum(axis=axis, keepdims=True) | |||
| up = indexing_one_hot(pred, label, axis) | |||
| if label_smooth != 0: | |||
| factor = label_smooth / num_classes | |||
| up = up * (1 - label_smooth) + pred.sum(axis=axis) * factor | |||
| up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor | |||
| return (log(down) - up).mean() | |||
| @@ -117,40 +117,6 @@ def sign(inp: Tensor): | |||
| raise NotImplementedError | |||
| def _reduce( | |||
| data, | |||
| *, | |||
| mode, | |||
| axis: Optional[Union[int, Sequence[int]]] = None, | |||
| keepdims: bool = False | |||
| ): | |||
| (data,) = utils.convert_inputs(data) | |||
| if axis is None: | |||
| data = data.reshape(-1) | |||
| assert not keepdims, "can not set axis=None and keepdims=True" | |||
| op = builtin.Reduce(mode=mode, axis=0) | |||
| (result,) = apply(op, data) | |||
| elif isinstance(axis, collections.Iterable): | |||
| axis = list(axis) | |||
| axis.sort(reverse=True) | |||
| for ai in axis: | |||
| op = builtin.Reduce(mode=mode, axis=ai) | |||
| (data,) = apply(op, data) | |||
| if not keepdims: | |||
| data = remove_axis(data, ai) | |||
| result = data | |||
| else: | |||
| op = builtin.Reduce(mode=mode, axis=axis) | |||
| (result,) = apply(op, data) | |||
| if not keepdims: | |||
| result = remove_axis(result, axis) | |||
| return result | |||
| def sum( | |||
| inp: Tensor, | |||
| axis: Optional[Union[int, Sequence[int]]] = None, | |||
| @@ -182,7 +148,7 @@ def sum( | |||
| [21] | |||
| """ | |||
| return _reduce(inp, mode="SUM", axis=axis, keepdims=keepdims) | |||
| return inp.sum(axis=axis, keepdims=keepdims) | |||
| def prod( | |||
| @@ -215,7 +181,7 @@ def prod( | |||
| [720] | |||
| """ | |||
| return _reduce(inp, mode="PRODUCT", axis=axis, keepdims=keepdims) | |||
| return inp.prod(axis=axis, keepdims=keepdims) | |||
| def mean( | |||
| @@ -248,7 +214,7 @@ def mean( | |||
| [3.5] | |||
| """ | |||
| return _reduce(inp, mode="MEAN", axis=axis, keepdims=keepdims) | |||
| return inp.astype("float32").mean(axis=axis, keepdims=keepdims) | |||
| def median( | |||
| @@ -362,7 +328,7 @@ def min( | |||
| [1] | |||
| """ | |||
| return _reduce(inp, mode="MIN", axis=axis, keepdims=keepdims) | |||
| return inp.min(axis=axis, keepdims=keepdims) | |||
| def max( | |||
| @@ -394,7 +360,7 @@ def max( | |||
| [6] | |||
| """ | |||
| return _reduce(inp, mode="MAX", axis=axis, keepdims=keepdims) | |||
| return inp.max(axis=axis, keepdims=keepdims) | |||
| def norm( | |||
| @@ -580,7 +580,7 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor: | |||
| """ | |||
| if axis is None: | |||
| axis = _get_softmax_axis(len(inp.shape)) | |||
| offset = inp.max(axis=axis).detach() | |||
| offset = inp.max(axis=axis, keepdims=True).detach() | |||
| cached = exp(inp - offset) | |||
| down = sum(cached, axis=axis, keepdims=True) | |||
| return cached / down | |||
| @@ -38,7 +38,7 @@ def test_reduce(): | |||
| for m in ["sum", "prod", "min", "max", "mean"]: | |||
| x_np = np.random.rand(10).astype("float32") | |||
| x = TensorWrapper(x_np) | |||
| y = getattr(x, m)(-1) | |||
| y = getattr(x, m)(axis=-1, keepdims=True) | |||
| np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6) | |||