| @@ -67,6 +67,7 @@ from .nn import ( | |||
| interpolate, | |||
| leaky_relu, | |||
| linear, | |||
| local_conv2d, | |||
| matrix_mul, | |||
| max_pool2d, | |||
| one_hot, | |||
| @@ -171,6 +171,34 @@ def conv_transpose2d( | |||
| return res | |||
| @wrap_io_tensor | |||
| def local_conv2d( | |||
| inp: Tensor, | |||
| weight: Tensor, | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| conv_mode="CROSS_CORRELATION", | |||
| ) -> Tensor: | |||
| """Applies spatial 2D convolution over an image with untied kernels. | |||
| Refer to :class:`~.LocalConv2d` for more information. | |||
| """ | |||
| ret = mgb.opr.group_local( | |||
| inp, | |||
| weight, | |||
| pad_h=padding[0], | |||
| pad_w=padding[1], | |||
| stride_h=stride[0], | |||
| stride_w=stride[1], | |||
| dilate_h=dilation[0], | |||
| dilate_w=dilation[1], | |||
| format="NCHW", | |||
| mode=conv_mode, | |||
| ) | |||
| return ret | |||
| @wrap_io_tensor | |||
| def max_pool2d( | |||
| inp: Tensor, | |||
| @@ -9,7 +9,7 @@ | |||
| from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | |||
| from .batchnorm import BatchNorm1d, BatchNorm2d | |||
| from .concat import Concat | |||
| from .conv import Conv2d, ConvTranspose2d | |||
| from .conv import Conv2d, ConvTranspose2d, LocalConv2d | |||
| from .conv_bn_relu import ConvBn2d, ConvBnRelu2d | |||
| from .dropout import Dropout | |||
| from .elemwise import Elemwise | |||
| @@ -14,7 +14,7 @@ import numpy as np | |||
| import megengine._internal as mgb | |||
| from ..core import Parameter | |||
| from ..functional import conv2d, conv_transpose2d | |||
| from ..functional import conv2d, conv_transpose2d, local_conv2d | |||
| from ..utils.types import _pair, _pair_nonzero | |||
| from . import init | |||
| from .module import Module | |||
| @@ -224,7 +224,7 @@ class ConvTranspose2d(_ConvNd): | |||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
| and there would be an extra dimension at the beginning of the weight's | |||
| shape. Specifically, the shape of weight would be ``(groups, | |||
| out_channel // groups, in_channels // groups, *kernel_size)``. Default: 1 | |||
| out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1 | |||
| :param bias: wether to add a bias onto the result of convolution. Default: | |||
| True | |||
| :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | |||
| @@ -306,3 +306,77 @@ class ConvTranspose2d(_ConvNd): | |||
| self.conv_mode, | |||
| self.compute_mode, | |||
| ) | |||
| class LocalConv2d(Conv2d): | |||
| r"""Applies a spatial convolution with untied kernels over an input 4D tensor. | |||
| It is also known as the locally connected layer. | |||
| :param in_channels: number of input channels. | |||
| :param out_channels: number of output channels. | |||
| :param input_height: the height of the input images. | |||
| :param input_width: the width of the input images. | |||
| :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||
| an :class:`int`, the actual kernel size would be | |||
| ``(kernel_size, kernel_size)``. Default: 1 | |||
| :param stride: stride of the 2D convolution operation. Default: 1 | |||
| :param padding: size of the paddings added to the input on both sides of its | |||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||
| :param groups: number of groups to divide input and output channels into, | |||
| so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``. | |||
| The shape of weight is ``(groups, output_height, output_width, | |||
| in_channels // groups, *kernel_size, out_channels // groups)``. | |||
| """ | |||
| _conv_mode_type = mgb.opr_param_defs.Convolution.Mode | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| out_channels: int, | |||
| input_height: int, | |||
| input_width: int, | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| conv_mode: str = "CROSS_CORRELATION", | |||
| ): | |||
| self.input_height = input_height | |||
| self.input_width = input_width | |||
| super().__init__( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride, | |||
| padding, | |||
| dilation, | |||
| groups, | |||
| bias=False, | |||
| ) | |||
| def _infer_weight_shape(self): | |||
| group = self.groups | |||
| output_height = ( | |||
| self.input_height + self.padding[0] * 2 - self.kernel_size[0] | |||
| ) // self.stride[0] + 1 | |||
| output_width = ( | |||
| self.input_width + self.padding[1] * 2 - self.kernel_size[1] | |||
| ) // self.stride[1] + 1 | |||
| # Assume format is NCHW | |||
| return ( | |||
| group, | |||
| output_height, | |||
| output_width, | |||
| self.in_channels // group, | |||
| self.kernel_size[0], | |||
| self.kernel_size[1], | |||
| self.out_channels // group, | |||
| ) | |||
| def forward(self, inp): | |||
| return local_conv2d( | |||
| inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode | |||
| ) | |||
| @@ -11,7 +11,7 @@ import itertools | |||
| import numpy as np | |||
| from megengine import Parameter, tensor | |||
| from megengine.module import ConvTranspose2d | |||
| from megengine.module import ConvTranspose2d, LocalConv2d | |||
| from megengine.test import assertTensorClose | |||
| @@ -50,3 +50,61 @@ def test_conv_transpose2d(): | |||
| y = conv_transpose2d(tensor(inp)) | |||
| assertTensorClose(out, y.numpy(), max_err=2e-6) | |||
| def test_local_conv2d(): | |||
| batch_size = 10 | |||
| in_channels = 4 | |||
| out_channels = 8 | |||
| input_height = 8 | |||
| input_width = 8 | |||
| kernel_size = 3 | |||
| stride = 1 | |||
| padding = 1 | |||
| dilation = 1 | |||
| groups = 1 | |||
| local_conv2d = LocalConv2d( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| input_height=input_height, | |||
| input_width=input_width, | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| padding=padding, | |||
| dilation=dilation, | |||
| groups=groups, | |||
| ) | |||
| inputs = np.random.normal( | |||
| size=(batch_size, in_channels, input_height, input_width) | |||
| ).astype(np.float32) | |||
| output_height = (input_height + padding * 2 - kernel_size) // stride + 1 | |||
| output_width = (input_width + padding * 2 - kernel_size) // stride + 1 | |||
| weights = np.random.normal( | |||
| size=( | |||
| groups, | |||
| output_height, | |||
| output_width, | |||
| in_channels // groups, | |||
| kernel_size, | |||
| kernel_size, | |||
| out_channels // groups, | |||
| ) | |||
| ).astype(np.float32) | |||
| local_conv2d.weight = Parameter(weights) | |||
| outputs = local_conv2d(tensor(inputs)) | |||
| # naive calculation use numpy | |||
| # only test output_height == input_height, output_width == input_width, group == 1 | |||
| inputs = np.pad(inputs, ((0, 0), (0, 0), (1, 1), (1, 1))) | |||
| expected = np.zeros( | |||
| (batch_size, out_channels, output_height, output_width), dtype=np.float32, | |||
| ) | |||
| for n, oc, oh, ow in itertools.product( | |||
| *map(range, [batch_size, out_channels, output_height, output_width]) | |||
| ): | |||
| ih, iw = oh * stride, ow * stride | |||
| expected[n, oc, ih, iw] = np.sum( | |||
| inputs[n, :, ih : ih + kernel_size, iw : iw + kernel_size] | |||
| * weights[0, oh, ow, :, :, :, oc] | |||
| ) | |||
| assertTensorClose(outputs.numpy(), expected, max_err=1e-5) | |||
| @@ -112,9 +112,10 @@ decl_opr('GroupLocal', | |||
| 'convolution kernel in ' | |||
| '(group, out row, out col, in channel / group, ' | |||
| 'kern row, kern col, out channel / group) format')], | |||
| params='ConvolutionV0', | |||
| params=[('param', 'Convolution')], | |||
| desc='batched convolution on groupped channeled 2D images, but ' | |||
| 'kernels are not shared across different output positions') | |||
| 'kernels are not shared across different output positions', | |||
| version=1) | |||
| decl_opr('LRN', | |||
| inputs=['src'], | |||