GitOrigin-RevId: 9ad87a4ea9
tags/v0.3.2
@@ -53,6 +53,7 @@ from .nn import (
     batch_norm2d,
     batched_matrix_mul,
     conv2d,
+    conv_transpose2d,
     dropout,
     embedding,
     eye,
@@ -100,6 +100,69 @@ def conv2d(
     return res


+@wrap_io_tensor
+def conv_transpose2d(
+    inp: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    stride: Union[int, Tuple[int, int]] = 1,
+    padding: Union[int, Tuple[int, int]] = 0,
+    dilation: Union[int, Tuple[int, int]] = 1,
+    groups: int = 1,
+    conv_mode="CROSS_CORRELATION",
+    compute_mode="DEFAULT",
+) -> Tensor:
+    """2D transposed convolution operation.
+
+    :param inp: the feature map of the convolution operation
+    :param weight: the convolution kernel
+    :param bias: the bias added to the result of convolution (if given)
+    :param stride: stride of the 2D convolution operation. Default: 1
+    :param padding: size of the paddings added to the input on both sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: dilation of the 2D convolution operation. Default: 1
+    :param groups: number of groups to divide input and output channels into,
+        so as to perform a "grouped convolution". When ``groups`` is not 1,
+        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
+        and the shape of weight should be ``(groups, in_channels // groups,
+        out_channels // groups, height, width)``. Default: 1
+    :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
+    :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
+        'CROSS_CORRELATION'
+    :type compute_mode: string or
+        :class:`mgb.opr_param_defs.Convolution.ComputeMode`
+    :param compute_mode: when set to 'DEFAULT', no special requirements will be
+        placed on the precision of intermediate results. When set to 'FLOAT32',
+        float32 would be used for the accumulator and intermediate result, but
+        only takes effect when input and output are of float16 dtype.
+
+    Refer to :class:`~.ConvTranspose2d` for more information.
+    """
+    ph, pw = _pair(padding)
+    sh, sw = _pair_nonzero(stride)
+    dh, dw = _pair_nonzero(dilation)
+    Sparse = mgb.opr_param_defs.Convolution.Sparse
+    sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP
+    res = mgb.opr.deconvolution(
+        inp,
+        weight,
+        pad_h=ph,
+        pad_w=pw,
+        stride_h=sh,
+        stride_w=sw,
+        dilate_h=dh,
+        dilate_w=dw,
+        format="NCHW",
+        strategy=get_conv_execution_strategy(),
+        mode=conv_mode,
+        compute_mode=compute_mode,
+        sparse=sparse_type,
+    )
+    if bias is not None:
+        res += bias
+    return res
+
+
 @wrap_io_tensor
 def max_pool2d(
     inp: Tensor,
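For reference, a minimal usage sketch of the functional form added above. The concrete shapes and the output-size arithmetic are assumptions based on the docstring and the test at the bottom of this diff, not part of the change itself:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    # Dense (groups=1) case: weight layout is (in_channels, out_channels, kh, kw).
    inp = mge.tensor(np.random.randn(2, 4, 8, 8).astype(np.float32))
    weight = mge.tensor(np.random.randn(4, 6, 3, 3).astype(np.float32))

    out = F.conv_transpose2d(inp, weight, stride=2, padding=1)
    # Output spatial size: (i - 1) * stride - 2 * padding + kernel
    # = (8 - 1) * 2 - 2 * 1 + 3 = 15
    print(out.shape)  # (2, 6, 15, 15)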
@@ -8,7 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
 from .batchnorm import BatchNorm1d, BatchNorm2d
-from .conv import Conv2d
+from .conv import Conv2d, ConvTranspose2d
 from .dropout import Dropout
 from .embedding import Embedding
 from .identity import Identity
@@ -14,7 +14,7 @@ import numpy as np
 import megengine._internal as mgb

 from ..core import Parameter
-from ..functional import conv2d
+from ..functional import conv2d, conv_transpose2d
 from ..utils.types import _pair, _pair_nonzero
 from . import init
 from .module import Module
@@ -31,7 +31,6 @@ class _ConvNd(Module):
         stride: Union[int, Tuple[int, int]],
         padding: Union[int, Tuple[int, int]],
         dilation: Union[int, Tuple[int, int]],
-        output_padding: Union[int, Tuple[int, int]],
         groups: int,
         bias: bool = True,
     ):
@@ -46,7 +45,6 @@ class _ConvNd(Module):
         self.stride = stride
         self.padding = padding
         self.dilation = dilation
-        self.output_padding = output_padding
         self.groups = groups

         self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32))
@@ -154,7 +152,6 @@ class Conv2d(_ConvNd):
             stride,
             padding,
             dilation,
-            (0, 0),
             groups,
             bias,
         )
@@ -197,3 +194,112 @@ class Conv2d(_ConvNd):
             self.conv_mode,
             self.compute_mode,
         )
+
+
+class ConvTranspose2d(_ConvNd):
+    r"""Applies a 2D transposed convolution over an input tensor.
+
+    This module is also known as a deconvolution or a fractionally-strided
+    convolution. :class:`ConvTranspose2d` can be seen as the gradient of the
+    :class:`Conv2d` operation with respect to its input.
+
+    Convolution usually reduces the size of its input, while transposed
+    convolution works the other way, transforming a smaller input into a
+    larger output while preserving the connectivity pattern.
+
+    :param in_channels: number of input channels.
+    :param out_channels: number of output channels.
+    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
+        an :class:`int`, the actual kernel size would be
+        ``(kernel_size, kernel_size)``. Default: 1
+    :param stride: stride of the 2D convolution operation. Default: 1
+    :param padding: size of the paddings added to the input on both sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: dilation of the 2D convolution operation. Default: 1
+    :param groups: number of groups to divide input and output channels into,
+        so as to perform a "grouped convolution". When ``groups`` is not 1,
+        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
+        and there would be an extra dimension at the beginning of the weight's
+        shape. Specifically, the shape of weight would be ``(groups,
+        in_channels // groups, out_channels // groups, *kernel_size)``. Default: 1
+    :param bias: whether to add a bias onto the result of convolution. Default:
+        True
+    :param conv_mode: supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
+        `CROSS_CORRELATION`
+    :param compute_mode: when set to `DEFAULT`, no special requirements will be
+        placed on the precision of intermediate results. When set to `FLOAT32`,
+        float32 would be used for the accumulator and intermediate result, but
+        only takes effect when input and output are of float16 dtype.
+    """
+
+    _conv_mode_type = mgb.opr_param_defs.Convolution.Mode
+    _compute_mode_type = mgb.opr_param_defs.Convolution.ComputeMode
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        stride: Union[int, Tuple[int, int]] = 1,
+        padding: Union[int, Tuple[int, int]] = 0,
+        dilation: Union[int, Tuple[int, int]] = 1,
+        groups: int = 1,
+        bias: bool = True,
+        conv_mode: str = "CROSS_CORRELATION",
+        compute_mode: str = "DEFAULT",
+    ):
+        kernel_size = _pair_nonzero(kernel_size)
+        stride = _pair_nonzero(stride)
+        padding = _pair(padding)
+        dilation = _pair_nonzero(dilation)
+        self.conv_mode = self._conv_mode_type.convert(conv_mode)
+        self.compute_mode = self._compute_mode_type.convert(compute_mode)
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+        )
+
+    def _get_fanin(self):
+        kh, kw = self.kernel_size
+        oc = self.out_channels
+        return kh * kw * oc
+
+    def _infer_weight_shape(self):
+        group = self.groups
+        ichl = self.in_channels
+        ochl = self.out_channels
+        kh, kw = self.kernel_size
+        if group == 1:
+            # Assume format is NCHW
+            return (ichl, ochl, kh, kw)
+
+        assert (
+            ichl % group == 0 and ochl % group == 0
+        ), "invalid config: input_channels={} output_channels={} group={}".format(
+            ichl, ochl, group
+        )
+        # Assume format is NCHW
+        return (group, ichl // group, ochl // group, kh, kw)
+
+    def _infer_bias_shape(self):
+        # Assume format is NCHW
+        return (1, self.out_channels, 1, 1)
+
+    def forward(self, inp):
+        return conv_transpose2d(
+            inp,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+            self.conv_mode,
+            self.compute_mode,
+        )
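A corresponding sketch for the module form, showing the grouped weight layout produced by `_infer_weight_shape` (the concrete channel counts here are illustrative assumptions):

    import numpy as np
    from megengine import tensor
    from megengine.module import ConvTranspose2d

    # Dense case: 4 -> 8 channels, upsampling 8x8 to 15x15.
    m = ConvTranspose2d(4, 8, kernel_size=3, stride=2, padding=1)
    y = m(tensor(np.random.randn(1, 4, 8, 8).astype(np.float32)))
    print(y.shape)  # (1, 8, 15, 15), since (8 - 1) * 2 - 2 * 1 + 3 = 15

    # Grouped case: the weight gains a leading groups dimension,
    # (groups, in_channels // groups, out_channels // groups, kh, kw).
    g = ConvTranspose2d(4, 8, kernel_size=3, groups=2)
    print(g.weight.shape)  # (2, 2, 4, 3, 3)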
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import itertools
+
+import numpy as np
+import pytest
+import torch
+
+import megengine as mge
+from megengine import Parameter, tensor
+from megengine.module import Conv2d, ConvTranspose2d
+from megengine.test import assertTensorClose
+
+
+def test_conv_transpose2d():
+    SH, SW = 3, 1
+    PH, PW = 2, 0
+    N, IC, IH, IW = 4, 5, 8, 6
+    KH, KW = 3, 4
+    OC = 3
+    BIAS = True
+
+    def getsize(inp, kern, stride):
+        return (inp - 1) * stride + kern
+
+    OH = getsize(IH, KH, SH)
+    OW = getsize(IW, KW, SW)
+
+    inp = np.random.normal(size=(N, IC, IH, IW)).astype(np.float32)
+    out = np.zeros((N, OC, OH, OW), dtype=np.float32)
+    weight = np.random.normal(size=(IC, OC, KH, KW)).astype(np.float32)
+    bias = np.random.normal(size=(1, OC, 1, 1)).astype(np.float32)
+
+    # Reference implementation: scatter each input element into the
+    # uncropped output, then trim the padding from each side.
+    for n, ic, ih, iw in itertools.product(*map(range, [N, IC, IH, IW])):
+        oh, ow = ih * SH, iw * SW
+        out[n, :, oh : oh + KH, ow : ow + KW] += inp[n, ic, ih, iw] * weight[ic]
+    out = out[:, :, PH : OH - PH, PW : OW - PW]
+    if BIAS:
+        out += bias
+
+    conv_transpose2d = ConvTranspose2d(IC, OC, (KH, KW), (SH, SW), (PH, PW), bias=BIAS)
+    conv_transpose2d.weight = Parameter(weight, dtype=np.float32)
+    if BIAS:
+        conv_transpose2d.bias = Parameter(bias, dtype=np.float32)
+
+    y = conv_transpose2d(tensor(inp))
+    assertTensorClose(out, y.numpy(), max_err=2e-6)
+
+    torch_conv_transpose2d = torch.nn.ConvTranspose2d(
+        IC, OC, (KH, KW), stride=(SH, SW), padding=(PH, PW), bias=BIAS
+    )
+    torch_conv_transpose2d.weight = torch.nn.parameter.Parameter(torch.Tensor(weight))
+    if BIAS:
+        torch_conv_transpose2d.bias = torch.nn.parameter.Parameter(
+            torch.Tensor(bias).reshape(OC)
+        )
+
+    torch_y = torch_conv_transpose2d(torch.Tensor(inp))
+    assertTensorClose(torch_y.detach().numpy(), y.numpy(), max_err=2e-6)
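As a sanity check on the constants used in the test above: the reference output is built at full, uncropped size and then trimmed by the padding on each side, so the final spatial size follows the usual transposed-convolution formula.

    IH, IW = 8, 6    # input spatial size
    KH, KW = 3, 4    # kernel size
    SH, SW = 3, 1    # stride
    PH, PW = 2, 0    # padding

    OH = (IH - 1) * SH + KH          # 7 * 3 + 3 = 24, uncropped height
    OW = (IW - 1) * SW + KW          # 5 * 1 + 4 = 9, uncropped width
    print(OH - 2 * PH, OW - 2 * PW)  # 20 9, the spatial size compared above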