| @@ -15,6 +15,7 @@ from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder | |||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.ops.builtin import ( | from ..core.ops.builtin import ( | ||||
| BatchNorm, | BatchNorm, | ||||
| Dimshuffle, | |||||
| Elemwise, | Elemwise, | ||||
| GetVarShape, | GetVarShape, | ||||
| Identity, | Identity, | ||||
| @@ -86,6 +87,7 @@ __all__ = [ | |||||
| "sync_batch_norm", | "sync_batch_norm", | ||||
| "warp_affine", | "warp_affine", | ||||
| "warp_perspective", | "warp_perspective", | ||||
| "pixel_shuffle", | |||||
| ] | ] | ||||
| @@ -1733,6 +1735,69 @@ def pad( | |||||
| return output | return output | ||||
| @lru_cache(maxsize=None) | |||||
| def _get_layerPixelShuffle(device, dtype, dim_order): | |||||
| @subgraph("LayerPixelShuffle", dtype, device, 3) | |||||
| def layerPixelShuffle(inputs, f, c): | |||||
| inp, shape_0, shape_1 = inputs | |||||
| inp = f(Reshape(), inp, shape_0) | |||||
| inp = f(Dimshuffle(dim_order), inp) | |||||
| oup = f(Reshape(), inp, shape_1) | |||||
| return (oup,), (True,) | |||||
| return layerPixelShuffle | |||||
| def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor: | |||||
| """ | |||||
| Rearranges elements in a tensor of shape (*, C x r^2, H, W) to a tensor of | |||||
| shape (*, C, H x r, W x r), where r is an upscale factor, where * is zero | |||||
| or more batch dimensions. | |||||
| :param inp: input tensor. | |||||
| :param upscale_factor: upscale factor of pixel_shuffle. | |||||
| :return: output tensor. | |||||
| """ | |||||
| assert upscale_factor > 0, "upscale_factor should larger than 0" | |||||
| assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than 3" | |||||
| assert ( | |||||
| inp.shape[-3] % (upscale_factor ** 2) == 0 | |||||
| ), "the -3 dimension should be divided by (upscale_factor ** 2)" | |||||
| _device = inp.device | |||||
| _dtype = inp.dtype | |||||
| shape_ori = inp.shape | |||||
| high_dim = shape_ori[:-3] | |||||
| square = upscale_factor ** 2 | |||||
| n = 1 | |||||
| for item in high_dim: | |||||
| n *= item | |||||
| shape_0 = ( | |||||
| n, | |||||
| int(shape_ori[-3] / square), | |||||
| upscale_factor, | |||||
| upscale_factor, | |||||
| shape_ori[-2], | |||||
| shape_ori[-1], | |||||
| ) | |||||
| shape_1 = ( | |||||
| *high_dim, | |||||
| shape_ori[-3] / square, | |||||
| shape_ori[-2] * upscale_factor, | |||||
| shape_ori[-1] * upscale_factor, | |||||
| ) | |||||
| dim_order = (0, 1, 4, 2, 5, 3) | |||||
| layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order) | |||||
| shape_0 = convert_single_value(shape_0, dtype=inp.dtype, device=inp.device) | |||||
| shape_1 = convert_single_value(shape_1, dtype=inp.dtype, device=inp.device) | |||||
| outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1) | |||||
| return outvar | |||||
| from .quantized import conv_bias_activation # isort:skip | from .quantized import conv_bias_activation # isort:skip | ||||
| from .loss import * # isort:skip | from .loss import * # isort:skip | ||||
| from .metric import * # isort:skip | from .metric import * # isort:skip | ||||
| @@ -0,0 +1,24 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from ..functional.nn import pixel_shuffle | |||||
| from .module import Module | |||||
| class PixelShuffle(Module): | |||||
| r""" | |||||
| Rearranges elements in a tensor of shape (*, C x r^2, H, W) to a tensor of | |||||
| shape (*, C, H x r, W x r), where r is an upscale factor, where * is zero | |||||
| or more batch dimensions. | |||||
| """ | |||||
| def __init__(self, upscale_factor: int, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.upscale_factor = upscale_factor | |||||
| def forward(self, x): | |||||
| return pixel_shuffle(x, self.upscale_factor) | |||||
| @@ -1177,3 +1177,74 @@ def test_pad(): | |||||
| dst = np.pad(src, ((2, 2), (2, 2)), "reflect") | dst = np.pad(src, ((2, 2), (2, 2)), "reflect") | ||||
| res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "REFLECT") | res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "REFLECT") | ||||
| np.testing.assert_allclose(res, dst, atol=1e-5) | np.testing.assert_allclose(res, dst, atol=1e-5) | ||||
| def pixel_shuffle(data, r): | |||||
| high_dim = data.shape[:-3] | |||||
| data = data.reshape(-1, data.shape[-3], data.shape[-2], data.shape[-1]) | |||||
| inn, ic, ih, iw = data.shape | |||||
| res = np.zeros((inn, int(ic / (r * r)), ih * r, iw * r)) | |||||
| for n in range(inn): | |||||
| for c in range(ic): | |||||
| for h in range(ih): | |||||
| for w in range(iw): | |||||
| res[ | |||||
| n, | |||||
| int(c / r / r), | |||||
| h * r + int((c % (r * r)) / r), | |||||
| w * r + c % r, | |||||
| ] = data[n, c, h, w] | |||||
| if len(high_dim) > 0: | |||||
| res = res.reshape((*high_dim, int(ic / r / r), ih * r, iw * r)) | |||||
| else: | |||||
| res = res[0] | |||||
| return res | |||||
| def test_pixel_shuffle(): | |||||
| # ndim = 3 | |||||
| inp = np.arange(16 * 3 * 3).reshape(16, 3, 3) | |||||
| out = F.pixel_shuffle(tensor(inp), upscale_factor=4) | |||||
| golden = pixel_shuffle(inp, 4) | |||||
| np.testing.assert_equal(out.numpy(), golden) | |||||
| # ndim = 4 | |||||
| inp = np.arange(3 * 18 * 3 * 3).reshape(3, 18, 3, 3) | |||||
| out = F.pixel_shuffle(tensor(inp), upscale_factor=3) | |||||
| golden = pixel_shuffle(inp, 3) | |||||
| np.testing.assert_equal(out.numpy(), golden) | |||||
| # ndim = 5 | |||||
| inp = np.arange(5 * 3 * 20 * 3 * 4).reshape(5, 3, 20, 3, 4) | |||||
| out = F.pixel_shuffle(tensor(inp), upscale_factor=2) | |||||
| golden = pixel_shuffle(inp, 2) | |||||
| np.testing.assert_equal(out.numpy(), golden) | |||||
| # ndim = 6 | |||||
| inp = np.arange(6 * 5 * 3 * 25 * 3 * 4).reshape(6, 5, 3, 25, 3, 4) | |||||
| out = F.pixel_shuffle(tensor(inp), upscale_factor=5) | |||||
| golden = pixel_shuffle(inp, 5) | |||||
| np.testing.assert_equal(out.numpy(), golden) | |||||
| # ndim = 7 | |||||
| inp = np.arange(2 * 3 * 5 * 3 * 20 * 3 * 4).reshape(2, 3, 5, 3, 20, 3, 4) | |||||
| out = F.pixel_shuffle(tensor(inp), upscale_factor=2) | |||||
| golden = pixel_shuffle(inp, 2) | |||||
| np.testing.assert_equal(out.numpy(), golden) | |||||
| @pytest.mark.parametrize("is_symbolic", [False, True]) | |||||
| def test_pixel_shuffle_symbolic(is_symbolic): | |||||
| def fn(inp, upscale_factor): | |||||
| return F.pixel_shuffle(inp, upscale_factor=upscale_factor) | |||||
| if is_symbolic is not None: | |||||
| fn = jit.trace(symbolic=is_symbolic)(fn) | |||||
| inp = tensor(np.arange(3 * 4 * 5 * 5).reshape(3, 4, 5, 5)) | |||||
| golden = pixel_shuffle(inp, 2) | |||||
| for _ in range(3): | |||||
| out = fn(inp, 2) | |||||
| np.testing.assert_equal(out.numpy(), golden) | |||||
| if is_symbolic is None: | |||||
| break | |||||