GitOrigin-RevId: 9ad87a4ea9
tags/v0.3.2
| @@ -53,6 +53,7 @@ from .nn import ( | |||||
| batch_norm2d, | batch_norm2d, | ||||
| batched_matrix_mul, | batched_matrix_mul, | ||||
| conv2d, | conv2d, | ||||
| conv_transpose2d, | |||||
| dropout, | dropout, | ||||
| embedding, | embedding, | ||||
| eye, | eye, | ||||
| @@ -100,6 +100,69 @@ def conv2d( | |||||
| return res | return res | ||||
@wrap_io_tensor
def conv_transpose2d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    conv_mode="CROSS_CORRELATION",
    compute_mode="DEFAULT",
) -> Tensor:
    """2D transposed convolution operation.

    :param inp: The feature map of the transposed convolution operation
    :param weight: The convolution kernel. :class:`~.ConvTranspose2d` supplies
        a kernel of shape ``(in_channels, out_channels, height, width)`` when
        ``groups`` is 1.
    :param bias: The bias added to the result of the transposed convolution
        (if given). It is added with ``+=`` and therefore has to broadcast
        against the output; :class:`~.ConvTranspose2d` uses the shape
        ``(1, out_channels, 1, 1)``.
    :param stride: Stride of the 2D convolution operation. Default: 1
    :param padding: Padding of the corresponding forward convolution. For the
        transposed operation this trims that many rows/columns from each side
        of the spatial output. Default: 0
    :param dilation: Dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and the shape of weight should be ``(groups, in_channels // groups,
        out_channels // groups, height, width)``. Default: 1
    :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
    :param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
        'CROSS_CORRELATION'.
    :type compute_mode: string or
        :class:`mgb.opr_param_defs.Convolution.ComputeMode`
    :param compute_mode: When set to 'DEFAULT', no special requirements will be
        placed on the precision of intermediate results. When set to 'FLOAT32',
        Float32 would be used for accumulator and intermediate result, but only
        effective when input and output are of Float16 dtype.
    :return: The result of the transposed convolution, with ``bias`` added
        when it is given.

    Refer to :class:`~.ConvTranspose2d` for more information.
    """
    # Normalise every int-or-pair argument to an explicit (height, width) pair.
    # NOTE(review): ``_pair_nonzero`` presumably rejects zero entries - confirm.
    ph, pw = _pair(padding)
    sh, sw = _pair_nonzero(stride)
    dh, dw = _pair_nonzero(dilation)

    # A single group is an ordinary dense kernel; otherwise the weight carries
    # a leading ``groups`` dimension and the GROUP layout must be requested.
    Sparse = mgb.opr_param_defs.Convolution.Sparse
    sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP

    res = mgb.opr.deconvolution(
        inp,
        weight,
        pad_h=ph,
        pad_w=pw,
        stride_h=sh,
        stride_w=sw,
        dilate_h=dh,
        dilate_w=dw,
        format="NCHW",
        strategy=get_conv_execution_strategy(),
        mode=conv_mode,
        compute_mode=compute_mode,
        sparse=sparse_type,
    )
    # The deconvolution operator is called without a bias, so add it here.
    if bias is not None:
        res += bias
    return res
| @wrap_io_tensor | @wrap_io_tensor | ||||
| def max_pool2d( | def max_pool2d( | ||||
| inp: Tensor, | inp: Tensor, | ||||
| @@ -8,7 +8,7 @@ | |||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | ||||
| from .batchnorm import BatchNorm1d, BatchNorm2d | from .batchnorm import BatchNorm1d, BatchNorm2d | ||||
| from .conv import Conv2d | |||||
| from .conv import Conv2d, ConvTranspose2d | |||||
| from .dropout import Dropout | from .dropout import Dropout | ||||
| from .embedding import Embedding | from .embedding import Embedding | ||||
| from .identity import Identity | from .identity import Identity | ||||
| @@ -14,7 +14,7 @@ import numpy as np | |||||
| import megengine._internal as mgb | import megengine._internal as mgb | ||||
| from ..core import Parameter | from ..core import Parameter | ||||
| from ..functional import conv2d | |||||
| from ..functional import conv2d, conv_transpose2d | |||||
| from ..utils.types import _pair, _pair_nonzero | from ..utils.types import _pair, _pair_nonzero | ||||
| from . import init | from . import init | ||||
| from .module import Module | from .module import Module | ||||
| @@ -31,7 +31,6 @@ class _ConvNd(Module): | |||||
| stride: Union[int, Tuple[int, int]], | stride: Union[int, Tuple[int, int]], | ||||
| padding: Union[int, Tuple[int, int]], | padding: Union[int, Tuple[int, int]], | ||||
| dilation: Union[int, Tuple[int, int]], | dilation: Union[int, Tuple[int, int]], | ||||
| output_padding: Union[int, Tuple[int, int]], | |||||
| groups: int, | groups: int, | ||||
| bias: bool = True, | bias: bool = True, | ||||
| ): | ): | ||||
| @@ -46,7 +45,6 @@ class _ConvNd(Module): | |||||
| self.stride = stride | self.stride = stride | ||||
| self.padding = padding | self.padding = padding | ||||
| self.dilation = dilation | self.dilation = dilation | ||||
| self.output_padding = output_padding | |||||
| self.groups = groups | self.groups = groups | ||||
| self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32)) | self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32)) | ||||
| @@ -154,7 +152,6 @@ class Conv2d(_ConvNd): | |||||
| stride, | stride, | ||||
| padding, | padding, | ||||
| dilation, | dilation, | ||||
| (0, 0), | |||||
| groups, | groups, | ||||
| bias, | bias, | ||||
| ) | ) | ||||
| @@ -197,3 +194,112 @@ class Conv2d(_ConvNd): | |||||
| self.conv_mode, | self.conv_mode, | ||||
| self.compute_mode, | self.compute_mode, | ||||
| ) | ) | ||||
class ConvTranspose2d(_ConvNd):
    r"""Applies a 2D transposed convolution over an input tensor.

    This module is also known as a deconvolution or a fractionally-strided convolution.
    :class:`ConvTranspose2d` can be seen as the gradient of :class:`Conv2d` operation
    with respect to its input.

    Convolution usually reduces the size of input, while transposed convolution works
    the other way, transforming a smaller input to a larger output while preserving the
    connectivity pattern.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
        an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: padding of the corresponding forward convolution; it is
        passed unchanged to :func:`conv_transpose2d`. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be ``(groups,
        in_channels // groups, out_channels // groups, *kernel_size)``.
        When ``groups`` is 1 the shape of weight is ``(in_channels,
        out_channels, *kernel_size)``. Default: 1
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
        `CROSS_CORRELATION`.
    :param compute_mode: When set to `DEFAULT`, no special requirements will be
        placed on the precision of intermediate results. When set to `FLOAT32`,
        float32 would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
    """

    # Enum types used to convert the string arguments given to ``__init__``.
    _conv_mode_type = mgb.opr_param_defs.Convolution.Mode
    _compute_mode_type = mgb.opr_param_defs.Convolution.ComputeMode

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
    ):
        # Expand scalar arguments to (height, width) pairs before storing them.
        kernel_size = _pair_nonzero(kernel_size)
        stride = _pair_nonzero(stride)
        padding = _pair(padding)
        dilation = _pair_nonzero(dilation)
        # Store the converted enum values rather than the raw strings.
        self.conv_mode = self._conv_mode_type.convert(conv_mode)
        self.compute_mode = self._compute_mode_type.convert(compute_mode)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def _get_fanin(self):
        """Return ``kernel_height * kernel_width * out_channels``.

        NOTE(review): presumably consumed by the parameter initialisation in
        ``_ConvNd`` - confirm against the base class.
        """
        kh, kw = self.kernel_size
        oc = self.out_channels
        return kh * kw * oc

    def _infer_weight_shape(self):
        """Return the shape of the weight parameter.

        The input-channel dimension comes before the output-channel dimension.
        With more than one group a leading ``groups`` dimension is added and
        both channel counts are divided by it.
        """
        group = self.groups
        ichl = self.in_channels
        ochl = self.out_channels
        kh, kw = self.kernel_size
        if group == 1:
            # Assume format is NCHW
            return (ichl, ochl, kh, kw)

        assert (
            ichl % group == 0 and ochl % group == 0
        ), "invalid config: input_channels={} output_channels={} group={}".format(
            ichl, ochl, group
        )
        # Assume format is NCHW
        return (group, ichl // group, ochl // group, kh, kw)

    def _infer_bias_shape(self):
        """Return the bias shape: one value per output channel, NCHW-broadcastable."""
        # Assume format is NCHW
        return (1, self.out_channels, 1, 1)

    def forward(self, inp):
        """Apply :func:`conv_transpose2d` to ``inp`` with this module's settings."""
        return conv_transpose2d(
            inp,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
            self.compute_mode,
        )
| @@ -0,0 +1,64 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import itertools | |||||
| import numpy as np | |||||
| import pytest | |||||
| import torch | |||||
| import megengine as mge | |||||
| from megengine import Parameter, tensor | |||||
| from megengine.module import Conv2d, ConvTranspose2d | |||||
| from megengine.test import assertTensorClose | |||||
def test_conv_transpose2d():
    """Compare ConvTranspose2d with a NumPy scatter-add reference and with torch."""
    stride_h, stride_w = 3, 1
    pad_h, pad_w = 2, 0
    batch, in_ch, in_h, in_w = 4, 5, 8, 6
    kern_h, kern_w = 3, 4
    out_ch = 3
    use_bias = True

    # Spatial extent of the output before the padding is trimmed away.
    full_h = (in_h - 1) * stride_h + kern_h
    full_w = (in_w - 1) * stride_w + kern_w

    # Draw order (input, weight, bias) is kept so the random stream is unchanged.
    data = np.random.normal(size=(batch, in_ch, in_h, in_w)).astype(np.float32)
    expected = np.zeros((batch, out_ch, full_h, full_w), dtype=np.float32)
    kernel = np.random.normal(size=(in_ch, out_ch, kern_h, kern_w)).astype(
        np.float32
    )
    offset = np.random.normal(size=(1, out_ch, 1, 1)).astype(np.float32)

    # Reference implementation: every input element scatters a scaled copy of
    # its channel's kernel into the output window that starts at index * stride.
    for n, ic, ih, iw in np.ndindex(batch, in_ch, in_h, in_w):
        top = ih * stride_h
        left = iw * stride_w
        expected[n, :, top : top + kern_h, left : left + kern_w] += (
            data[n, ic, ih, iw] * kernel[ic]
        )
    # Padding removes pad_h / pad_w rows and columns from each side.
    expected = expected[:, :, pad_h : full_h - pad_h, pad_w : full_w - pad_w]
    if use_bias:
        expected += offset

    # MegEngine module under test, loaded with the same parameters.
    module = ConvTranspose2d(
        in_ch,
        out_ch,
        (kern_h, kern_w),
        (stride_h, stride_w),
        (pad_h, pad_w),
        bias=use_bias,
    )
    module.weight = Parameter(kernel, dtype=np.float32)
    if use_bias:
        module.bias = Parameter(offset, dtype=np.float32)
    result = module(tensor(data))
    assertTensorClose(expected, result.numpy(), max_err=2e-6)

    # Cross-check against torch; its bias is a flat vector of length out_ch.
    torch_module = torch.nn.ConvTranspose2d(
        in_ch,
        out_ch,
        (kern_h, kern_w),
        stride=(stride_h, stride_w),
        padding=(pad_h, pad_w),
        bias=use_bias,
    )
    torch_module.weight = torch.nn.parameter.Parameter(torch.Tensor(kernel))
    if use_bias:
        torch_module.bias = torch.nn.parameter.Parameter(
            torch.Tensor(offset).reshape(out_ch)
        )
    torch_result = torch_module(torch.Tensor(data))
    assertTensorClose(torch_result.detach().numpy(), result.numpy(), max_err=2e-6)