group and compute_mode are not used by local_conv2d
GitOrigin-RevId: 8e4f25bfd8
tags/v0.4.0
| @@ -65,6 +65,7 @@ from .nn import ( | |||||
| interpolate, | interpolate, | ||||
| leaky_relu, | leaky_relu, | ||||
| linear, | linear, | ||||
| local_conv2d, | |||||
| matrix_mul, | matrix_mul, | ||||
| max_pool2d, | max_pool2d, | ||||
| one_hot, | one_hot, | ||||
| @@ -170,6 +170,34 @@ def conv_transpose2d( | |||||
| return res | return res | ||||
| @wrap_io_tensor | |||||
| def local_conv2d( | |||||
| inp: Tensor, | |||||
| weight: Tensor, | |||||
| stride: Union[int, Tuple[int, int]] = 1, | |||||
| padding: Union[int, Tuple[int, int]] = 0, | |||||
| dilation: Union[int, Tuple[int, int]] = 1, | |||||
| conv_mode="CROSS_CORRELATION", | |||||
| ) -> Tensor: | |||||
| """Applies spatial 2D convolution over an image with untied kernels. | |||||
| Refer to :class:`~.LocalConv2d` for more information. | |||||
| """ | |||||
| ret = mgb.opr.group_local( | |||||
| inp, | |||||
| weight, | |||||
| pad_h=padding[0], | |||||
| pad_w=padding[1], | |||||
| stride_h=stride[0], | |||||
| stride_w=stride[1], | |||||
| dilate_h=dilation[0], | |||||
| dilate_w=dilation[1], | |||||
| format="NCHW", | |||||
| mode=conv_mode, | |||||
| ) | |||||
| return ret | |||||
| @wrap_io_tensor | @wrap_io_tensor | ||||
| def max_pool2d( | def max_pool2d( | ||||
| inp: Tensor, | inp: Tensor, | ||||
| @@ -9,7 +9,7 @@ | |||||
| from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | ||||
| from .batchnorm import BatchNorm1d, BatchNorm2d | from .batchnorm import BatchNorm1d, BatchNorm2d | ||||
| from .concat import Concat | from .concat import Concat | ||||
| from .conv import Conv2d, ConvTranspose2d | |||||
| from .conv import Conv2d, ConvTranspose2d, LocalConv2d | |||||
| from .conv_bn_relu import ConvBn2d, ConvBnRelu2d | from .conv_bn_relu import ConvBn2d, ConvBnRelu2d | ||||
| from .dropout import Dropout | from .dropout import Dropout | ||||
| from .elemwise import Elemwise | from .elemwise import Elemwise | ||||
| @@ -14,7 +14,7 @@ import numpy as np | |||||
| import megengine._internal as mgb | import megengine._internal as mgb | ||||
| from ..core import Parameter | from ..core import Parameter | ||||
| from ..functional import conv2d, conv_transpose2d | |||||
| from ..functional import conv2d, conv_transpose2d, local_conv2d | |||||
| from ..utils.types import _pair, _pair_nonzero | from ..utils.types import _pair, _pair_nonzero | ||||
| from . import init | from . import init | ||||
| from .module import Module | from .module import Module | ||||
| @@ -224,7 +224,7 @@ class ConvTranspose2d(_ConvNd): | |||||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | ||||
| and there would be an extra dimension at the beginning of the weight's | and there would be an extra dimension at the beginning of the weight's | ||||
| shape. Specifically, the shape of weight would be ``(groups, | shape. Specifically, the shape of weight would be ``(groups, | ||||
| out_channel // groups, in_channels // groups, *kernel_size)``. Default: 1 | |||||
| out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1 | |||||
| :param bias: wether to add a bias onto the result of convolution. Default: | :param bias: wether to add a bias onto the result of convolution. Default: | ||||
| True | True | ||||
| :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | ||||
| @@ -306,3 +306,77 @@ class ConvTranspose2d(_ConvNd): | |||||
| self.conv_mode, | self.conv_mode, | ||||
| self.compute_mode, | self.compute_mode, | ||||
| ) | ) | ||||
| class LocalConv2d(Conv2d): | |||||
| r"""Applies a spatial convolution with untied kernels over an input 4D tensor. | |||||
| It is also known as the locally connected layer. | |||||
| :param in_channels: number of input channels. | |||||
| :param out_channels: number of output channels. | |||||
| :param input_height: the height of the input images. | |||||
| :param input_width: the width of the input images. | |||||
| :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||||
| an :class:`int`, the actual kernel size would be | |||||
| ``(kernel_size, kernel_size)``. Default: 1 | |||||
| :param stride: stride of the 2D convolution operation. Default: 1 | |||||
| :param padding: size of the paddings added to the input on both sides of its | |||||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||||
| :param groups: number of groups to divide input and output channels into, | |||||
| so as to perform a "grouped convolution". When ``groups`` is not 1, | |||||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``. | |||||
| The shape of weight is ``(groups, output_height, output_width, | |||||
| in_channels // groups, *kernel_size, out_channels // groups)``. | |||||
| """ | |||||
| _conv_mode_type = mgb.opr_param_defs.Convolution.Mode | |||||
| def __init__( | |||||
| self, | |||||
| in_channels: int, | |||||
| out_channels: int, | |||||
| input_height: int, | |||||
| input_width: int, | |||||
| kernel_size: Union[int, Tuple[int, int]], | |||||
| stride: Union[int, Tuple[int, int]] = 1, | |||||
| padding: Union[int, Tuple[int, int]] = 0, | |||||
| dilation: Union[int, Tuple[int, int]] = 1, | |||||
| groups: int = 1, | |||||
| conv_mode: str = "CROSS_CORRELATION", | |||||
| ): | |||||
| self.input_height = input_height | |||||
| self.input_width = input_width | |||||
| super().__init__( | |||||
| in_channels, | |||||
| out_channels, | |||||
| kernel_size, | |||||
| stride, | |||||
| padding, | |||||
| dilation, | |||||
| groups, | |||||
| bias=False, | |||||
| ) | |||||
| def _infer_weight_shape(self): | |||||
| group = self.groups | |||||
| output_height = ( | |||||
| self.input_height + self.padding[0] * 2 - self.kernel_size[0] | |||||
| ) // self.stride[0] + 1 | |||||
| output_width = ( | |||||
| self.input_width + self.padding[1] * 2 - self.kernel_size[1] | |||||
| ) // self.stride[1] + 1 | |||||
| # Assume format is NCHW | |||||
| return ( | |||||
| group, | |||||
| output_height, | |||||
| output_width, | |||||
| self.in_channels // group, | |||||
| self.kernel_size[0], | |||||
| self.kernel_size[1], | |||||
| self.out_channels // group, | |||||
| ) | |||||
| def forward(self, inp): | |||||
| return local_conv2d( | |||||
| inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode | |||||
| ) | |||||
| @@ -11,7 +11,7 @@ import itertools | |||||
| import numpy as np | import numpy as np | ||||
| from megengine import Parameter, tensor | from megengine import Parameter, tensor | ||||
| from megengine.module import ConvTranspose2d | |||||
| from megengine.module import ConvTranspose2d, LocalConv2d | |||||
| from megengine.test import assertTensorClose | from megengine.test import assertTensorClose | ||||
| @@ -50,3 +50,61 @@ def test_conv_transpose2d(): | |||||
| y = conv_transpose2d(tensor(inp)) | y = conv_transpose2d(tensor(inp)) | ||||
| assertTensorClose(out, y.numpy(), max_err=2e-6) | assertTensorClose(out, y.numpy(), max_err=2e-6) | ||||
| def test_local_conv2d(): | |||||
| batch_size = 10 | |||||
| in_channels = 4 | |||||
| out_channels = 8 | |||||
| input_height = 8 | |||||
| input_width = 8 | |||||
| kernel_size = 3 | |||||
| stride = 1 | |||||
| padding = 1 | |||||
| dilation = 1 | |||||
| groups = 1 | |||||
| local_conv2d = LocalConv2d( | |||||
| in_channels=in_channels, | |||||
| out_channels=out_channels, | |||||
| input_height=input_height, | |||||
| input_width=input_width, | |||||
| kernel_size=kernel_size, | |||||
| stride=stride, | |||||
| padding=padding, | |||||
| dilation=dilation, | |||||
| groups=groups, | |||||
| ) | |||||
| inputs = np.random.normal( | |||||
| size=(batch_size, in_channels, input_height, input_width) | |||||
| ).astype(np.float32) | |||||
| output_height = (input_height + padding * 2 - kernel_size) // stride + 1 | |||||
| output_width = (input_width + padding * 2 - kernel_size) // stride + 1 | |||||
| weights = np.random.normal( | |||||
| size=( | |||||
| groups, | |||||
| output_height, | |||||
| output_width, | |||||
| in_channels // groups, | |||||
| kernel_size, | |||||
| kernel_size, | |||||
| out_channels // groups, | |||||
| ) | |||||
| ).astype(np.float32) | |||||
| local_conv2d.weight = Parameter(weights) | |||||
| outputs = local_conv2d(tensor(inputs)) | |||||
| # naive calculation use numpy | |||||
| # only test output_height == input_height, output_width == input_width, group == 1 | |||||
| inputs = np.pad(inputs, ((0, 0), (0, 0), (1, 1), (1, 1))) | |||||
| expected = np.zeros( | |||||
| (batch_size, out_channels, output_height, output_width), dtype=np.float32, | |||||
| ) | |||||
| for n, oc, oh, ow in itertools.product( | |||||
| *map(range, [batch_size, out_channels, output_height, output_width]) | |||||
| ): | |||||
| ih, iw = oh * stride, ow * stride | |||||
| expected[n, oc, ih, iw] = np.sum( | |||||
| inputs[n, :, ih : ih + kernel_size, iw : iw + kernel_size] | |||||
| * weights[0, oh, ow, :, :, :, oc] | |||||
| ) | |||||
| assertTensorClose(outputs.numpy(), expected, max_err=1e-5) | |||||