GitOrigin-RevId: 789f1511ec
tags/v1.0.0
| @@ -13,7 +13,7 @@ from ..core._imperative_rt import CompNode | |||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.ops._internal import param_defs as P | from ..core.ops._internal import param_defs as P | ||||
| from ..core.ops.special import Const | from ..core.ops.special import Const | ||||
| from ..core.tensor import utils | |||||
| from ..core.tensor import megbrain_graph, utils | |||||
| from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | ||||
| from ..core.tensor.utils import astensor1d | from ..core.tensor.utils import astensor1d | ||||
| from ..distributed import WORLD, is_distributed | from ..distributed import WORLD, is_distributed | ||||
| @@ -27,6 +27,8 @@ from .tensor import add_axis, broadcast, concat, full, ones, remove_axis, reshap | |||||
| from .types import _pair, _pair_nonzero | from .types import _pair, _pair_nonzero | ||||
| __all__ = [ | __all__ = [ | ||||
| "adaptive_avg_pool2d", | |||||
| "adaptive_max_pool2d", | |||||
| "avg_pool2d", | "avg_pool2d", | ||||
| "batched_nms", | "batched_nms", | ||||
| "batch_norm2d", | "batch_norm2d", | ||||
| @@ -324,6 +326,48 @@ def avg_pool2d( | |||||
| return output | return output | ||||
| def adaptive_max_pool2d( | |||||
| inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor], | |||||
| ) -> Tensor: | |||||
| """Applies a 2D max adaptive pooling over an input. | |||||
| Refer to :class:`~.MaxAdaptivePool2d` for more information. | |||||
| :param inp: The input tensor. | |||||
| :param oshp: (OH, OW) size of the output shape. | |||||
| :return: output tensor. | |||||
| """ | |||||
| assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type" | |||||
| if isinstance(oshp, int): | |||||
| oshp = (oshp, oshp) | |||||
| op = builtin.AdaptivePooling(mode="MAX", format="NCHW",) | |||||
| oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device) | |||||
| (output,) = apply(op, inp, oshp) | |||||
| return output | |||||
| def adaptive_avg_pool2d( | |||||
| inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor], | |||||
| ) -> Tensor: | |||||
| """Applies a 2D average adaptive pooling over an input. | |||||
| Refer to :class:`~.AvgAdaptivePool2d` for more information. | |||||
| :param inp: The input tensor. | |||||
| :param oshp: (OH, OW) size of the output shape. | |||||
| :return: output tensor. | |||||
| """ | |||||
| assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type" | |||||
| if isinstance(oshp, int): | |||||
| oshp = (oshp, oshp) | |||||
| op = builtin.AdaptivePooling(mode="AVERAGE", format="NCHW",) | |||||
| oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device) | |||||
| (output,) = apply(op, inp, oshp) | |||||
| return output | |||||
| def prelu(inp: Tensor, weight: Tensor) -> Tensor: | def prelu(inp: Tensor, weight: Tensor) -> Tensor: | ||||
| r""" | r""" | ||||
| Applies the element-wise PReLU function. | Applies the element-wise PReLU function. | ||||
| @@ -8,6 +8,7 @@ | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | ||||
| from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d | |||||
| from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | ||||
| from .concat import Concat | from .concat import Concat | ||||
| from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | ||||
| @@ -0,0 +1,114 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from abc import abstractmethod | |||||
| from typing import Tuple, Union | |||||
| from ..functional import adaptive_avg_pool2d, adaptive_max_pool2d | |||||
| from ..tensor import Parameter, Tensor | |||||
| from .module import Module | |||||
| class _AdaptivePoolNd(Module): | |||||
| def __init__( | |||||
| self, oshp: Union[Tuple[int, int], int, Tensor], | |||||
| ): | |||||
| super(_AdaptivePoolNd, self).__init__() | |||||
| self.oshp = oshp | |||||
| @abstractmethod | |||||
| def forward(self, inp): | |||||
| pass | |||||
| class AdaptiveMaxPool2d(_AdaptivePoolNd): | |||||
| r"""Applies a 2D max adaptive pooling over an input. | |||||
| For instance, given an input of the size :math:`(N, C, H, W)` and | |||||
| an output shape :math:`(OH, OW)`, this layer generates the output of | |||||
| the size :math:`(N, C, OH, OW)` through a process described as: | |||||
| .. math:: | |||||
| \begin{aligned} | |||||
| out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} | |||||
| \text{input}(N_i, C_j, \text{stride[0]} \times h + m, | |||||
| \text{stride[1]} \times w + n) | |||||
| \end{aligned} | |||||
| Kernel_size and stride can be inferred from input shape and out shape: | |||||
| padding: (0, 0) | |||||
| stride: (floor(IH / OH), floor(IW / OW)) | |||||
| kernel_size: (IH - (OH - 1) * stride_h, IW - (OW - 1) * stride_w) | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.module as M | |||||
| m = M.AdaptiveMaxPool2d((2, 2)) | |||||
| inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4)) | |||||
| oup = m(inp) | |||||
| print(oup.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[[[5. 7.] | |||||
| [13. 15.]]]] | |||||
| """ | |||||
| def forward(self, inp): | |||||
| return adaptive_max_pool2d(inp, self.oshp) | |||||
| class AdaptiveAvgPool2d(_AdaptivePoolNd): | |||||
| r"""Applies a 2D average pooling over an input. | |||||
| For instance, given an input of the size :math:`(N, C, H, W)` and | |||||
| an output shape :math:`(OH, OW)`, this layer generates the output of | |||||
| the size :math:`(N, C, OH, OW)` through a process described as: | |||||
| .. math:: | |||||
| out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} | |||||
| input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n) | |||||
| Kernel_size and stride can be inferred from input shape and out shape: | |||||
| padding: (0, 0) | |||||
| stride: (floor(IH / OH), floor(IW / OW)) | |||||
| kernel_size: (IH - (OH - 1) * stride_h, IW - (OW - 1) * stride_w) | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.module as M | |||||
| m = M.AdaptiveAvgPool2d((2, 2)) | |||||
| inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4)) | |||||
| oup = m(inp) | |||||
| print(oup.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[[[2.5 4.5] | |||||
| [10.5 12.5]]]] | |||||
| """ | |||||
| def forward(self, inp): | |||||
| return adaptive_avg_pool2d(inp, self.oshp) | |||||
| @@ -206,6 +206,66 @@ def test_roi_pooling(): | |||||
| assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | ||||
| def test_adaptive_avg_pool2d(): | |||||
| inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) | |||||
| oshp = (2, 2) | |||||
| grad = Grad().wrt(inp, callback=_save_to(inp)) | |||||
| outp = F.adaptive_avg_pool2d(inp, oshp,) | |||||
| assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||||
| np.testing.assert_equal( | |||||
| outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32) | |||||
| ) | |||||
| grad(outp, tensor(F.ones_like(outp))) | |||||
| assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) | |||||
| np.testing.assert_equal( | |||||
| inp.grad.numpy(), | |||||
| np.array( | |||||
| [ | |||||
| [ | |||||
| [ | |||||
| [0.25, 0.25, 0.25, 0.25], | |||||
| [0.25, 0.25, 0.25, 0.25], | |||||
| [0.25, 0.25, 0.25, 0.25], | |||||
| [0.25, 0.25, 0.25, 0.25], | |||||
| ] | |||||
| ] | |||||
| ], | |||||
| dtype=np.float32, | |||||
| ), | |||||
| ) | |||||
| def test_adaptive_max_pool2d(): | |||||
| inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4)) | |||||
| oshp = (2, 2) | |||||
| grad = Grad().wrt(inp, callback=_save_to(inp)) | |||||
| outp = F.adaptive_max_pool2d(inp, oshp,) | |||||
| assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,) | |||||
| np.testing.assert_equal( | |||||
| outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32) | |||||
| ) | |||||
| grad(outp, tensor(F.ones_like(outp))) | |||||
| assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape) | |||||
| np.testing.assert_equal( | |||||
| inp.grad.numpy(), | |||||
| np.array( | |||||
| [ | |||||
| [ | |||||
| [ | |||||
| [0.0, 0.0, 0.0, 0.0], | |||||
| [0.0, 1.0, 0.0, 1.0], | |||||
| [0.0, 0.0, 0.0, 0.0], | |||||
| [0.0, 1.0, 0.0, 1.0], | |||||
| ] | |||||
| ] | |||||
| ], | |||||
| dtype=np.float32, | |||||
| ), | |||||
| ) | |||||
| def test_one_hot(): | def test_one_hot(): | ||||
| def onehot_low_dimension(): | def onehot_low_dimension(): | ||||
| inp = tensor(np.arange(1, 4, dtype=np.int32)) | inp = tensor(np.arange(1, 4, dtype=np.int32)) | ||||