GitOrigin-RevId: 61097b8713
tags/v1.5.0
| @@ -48,6 +48,7 @@ __all__ = [ | |||
| "conv2d", | |||
| "conv3d", | |||
| "conv_transpose2d", | |||
| "conv_transpose3d", | |||
| "deformable_conv2d", | |||
| "deformable_psroi_pooling", | |||
| "dropout", | |||
| @@ -488,6 +489,54 @@ def local_conv2d( | |||
| return output | |||
| def conv_transpose3d( | |||
| inp: Tensor, | |||
| weight: Tensor, | |||
| bias: Optional[Tensor] = None, | |||
| stride: Union[int, Tuple[int, int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int, int]] = 1, | |||
| ) -> Tensor: | |||
| """ | |||
| 3D transposed convolution operation. Only support the case that group = 1 | |||
| and conv_mode = "cross_correlation". | |||
| Refer to :class:`~.ConvTranspose3d` for more information. | |||
| :param inp: feature map of the convolution operation. | |||
| :param weight: convolution kernel. | |||
| :param bias: bias added to the result of convolution (if given). | |||
| :param stride: stride of the 3D convolution operation. Default: 1 | |||
| :param padding: size of the paddings added to the input on all sides of its | |||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||
| :param dilation: dilation of the 3D convolution operation. Default: 1 | |||
| :return: output tensor. | |||
| """ | |||
| D, H, W = 0, 1, 2 | |||
| pad = _triple(padding) | |||
| stride = _triple_nonzero(stride) | |||
| dilate = _triple_nonzero(dilation) | |||
| op = builtin.Convolution3DBackwardData( | |||
| pad_d=pad[D], | |||
| pad_h=pad[H], | |||
| pad_w=pad[W], | |||
| stride_d=stride[D], | |||
| stride_h=stride[H], | |||
| stride_w=stride[W], | |||
| dilate_d=dilate[D], | |||
| dilate_h=dilate[H], | |||
| dilate_w=dilate[W], | |||
| strategy=get_execution_strategy(), | |||
| ) | |||
| weight, inp = utils.convert_inputs(weight, inp) | |||
| (output,) = apply(op, weight, inp) | |||
| if bias is not None: | |||
| output += bias | |||
| return output | |||
| def max_pool2d( | |||
| inp: Tensor, | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| @@ -18,6 +18,7 @@ from .conv import ( | |||
| Conv3d, | |||
| ConvRelu2d, | |||
| ConvTranspose2d, | |||
| ConvTranspose3d, | |||
| DeformableConv2d, | |||
| LocalConv2d, | |||
| ) | |||
| @@ -15,6 +15,7 @@ from ..functional import ( | |||
| conv2d, | |||
| conv3d, | |||
| conv_transpose2d, | |||
| conv_transpose3d, | |||
| deformable_conv2d, | |||
| local_conv2d, | |||
| relu, | |||
| @@ -842,3 +843,75 @@ class DeformableConv2d(_ConvNd): | |||
| def forward(self, inp, offset, mask): | |||
| return self.calc_conv(inp, self.weight, offset, mask, self.bias) | |||
| class ConvTranspose3d(_ConvNd): | |||
| r""" | |||
| Applies a 3D transposed convolution over an input tensor. | |||
| Only support the case that group = 1 and conv_mode = "cross_correlation". | |||
| :class:`ConvTranspose3d` can be seen as the gradient of :class:`Conv3d` operation | |||
| with respect to its input. | |||
| Convolution3D usually reduces the size of input, while transposed convolution3d | |||
| works the opposite way, transforming a smaller input to a larger output while | |||
| preserving the connectivity pattern. | |||
| :param in_channels: number of input channels. | |||
| :param out_channels: number of output channels. | |||
| :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||
| an :class:`int`, the actual kernel size would be | |||
| ``(kernel_size, kernel_size, kernel_size)``. Default: 1 | |||
| :param stride: stride of the 3D convolution operation. Default: 1 | |||
| :param padding: size of the paddings added to the input on all sides of its | |||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||
| :param dilation: dilation of the 3D convolution operation. Default: 1 | |||
| :param bias: wether to add a bias onto the result of convolution. Default: | |||
| True | |||
| """ | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| out_channels: int, | |||
| kernel_size: Union[int, Tuple[int, int, int]], | |||
| stride: Union[int, Tuple[int, int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int, int]] = 1, | |||
| bias: bool = True, | |||
| ): | |||
| kernel_size = _triple_nonzero(kernel_size) | |||
| stride = _triple_nonzero(stride) | |||
| padding = _triple(padding) | |||
| dilation = _triple_nonzero(dilation) | |||
| super().__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| padding=padding, | |||
| dilation=dilation, | |||
| groups=1, | |||
| bias=bias, | |||
| ) | |||
| def _get_fanin(self): | |||
| kt, kh, kw = self.kernel_size | |||
| ic = self.in_channels | |||
| return kt * kh * kw * ic | |||
| def _infer_weight_shape(self): | |||
| ichl = self.in_channels | |||
| ochl = self.out_channels | |||
| kt, kh, kw = self.kernel_size | |||
| return (ochl, ichl, kt, kh, kw) | |||
| def _infer_bias_shape(self): | |||
| # Assume format is NCTHW | |||
| return (1, self.out_channels, 1, 1, 1) | |||
| def forward(self, inp): | |||
| return conv_transpose3d( | |||
| inp, self.weight, self.bias, self.stride, self.padding, self.dilation, | |||
| ) | |||
| @@ -11,7 +11,7 @@ import itertools | |||
| import numpy as np | |||
| from megengine import Parameter, tensor | |||
| from megengine.module import ConvTranspose2d, LocalConv2d | |||
| from megengine.module import ConvTranspose2d, ConvTranspose3d, LocalConv2d | |||
| def test_conv_transpose2d(): | |||
| @@ -120,3 +120,64 @@ def test_local_conv2d(): | |||
| test_func(10, 4, 4, 5, 5, 3, 1, 1, 1, 1) | |||
| test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 2) | |||
| test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 4) | |||
| def test_conv_transpose3d(): | |||
| def getsize(inp, kernel, stride, dilate): | |||
| return (inp - 1) * stride + kernel * dilate - dilate + 1 | |||
| def test_func( | |||
| N, | |||
| IC, | |||
| ID, | |||
| IH, | |||
| IW, | |||
| OC, | |||
| KD, | |||
| KH, | |||
| KW, | |||
| SD, | |||
| SH, | |||
| SW, | |||
| PD, | |||
| PH, | |||
| PW, | |||
| DD, | |||
| DH, | |||
| DW, | |||
| bias=True, | |||
| ): | |||
| conv_transpose3d = ConvTranspose3d( | |||
| in_channels=IC, | |||
| out_channels=OC, | |||
| kernel_size=(KD, KH, KW), | |||
| stride=(SD, SH, SW), | |||
| padding=(PD, PH, PW), | |||
| dilation=(DD, DH, DW), | |||
| bias=bias, | |||
| ) | |||
| OD = getsize(ID, KD, SD, DD) | |||
| OH = getsize(IH, KH, SH, DH) | |||
| OW = getsize(IW, KW, SW, DW) | |||
| inp = np.random.normal(size=(N, IC, ID, IH, IW)) | |||
| weight = np.random.normal(size=(IC, OC, KD, KH, KW)) | |||
| out_np = np.zeros((N, OC, OD, OH, OW), dtype=np.float32) | |||
| for n, ic, idepth, ih, iw in itertools.product( | |||
| *map(range, [N, IC, ID, IH, IW]) | |||
| ): | |||
| od, oh, ow = idepth * SD, ih * SH, iw * SW | |||
| out_np[n, :, od : od + KD, oh : oh + KH, ow : ow + KW] += ( | |||
| inp[n, ic, idepth, ih, iw] * weight[ic] | |||
| ) | |||
| out_np = out_np[:, :, PD : OD - PD, PH : OH - PH, PW : OW - PW] | |||
| conv_transpose3d.weight = Parameter(weight) | |||
| out_meg = conv_transpose3d.forward(tensor(inp)) | |||
| np.testing.assert_almost_equal(out_meg.numpy(), out_np, 1e-5) | |||
| test_func(4, 3, 8, 16, 16, 8, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1) | |||
| test_func(4, 8, 16, 32, 32, 16, 1, 3, 1, 2, 1, 2, 0, 1, 0, 1, 1, 1) | |||
| @@ -75,5 +75,20 @@ OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D) | |||
| .fallback(); | |||
| }} // convolution3d | |||
| namespace { namespace convolution3d_backward_data { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& conv = static_cast<const Convolution3DBackwardData&>(def); | |||
| OperatorNodeConfig config{conv.make_name()}; | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::Convolution3DBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
| } | |||
| OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // convolution3d_backward_data | |||
| } | |||
| } | |||
| @@ -53,6 +53,8 @@ def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [Convoluti | |||
| def Convolution3D: MgbHashableOp<"Convolution3D", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>; | |||
| def Convolution3DBackwardData: MgbHashableOp<"Convolution3DBackwardData", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>; | |||
| def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>; | |||
| def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>; | |||