GitOrigin-RevId: 61097b8713
tags/v1.5.0
| @@ -48,6 +48,7 @@ __all__ = [ | |||||
| "conv2d", | "conv2d", | ||||
| "conv3d", | "conv3d", | ||||
| "conv_transpose2d", | "conv_transpose2d", | ||||
| "conv_transpose3d", | |||||
| "deformable_conv2d", | "deformable_conv2d", | ||||
| "deformable_psroi_pooling", | "deformable_psroi_pooling", | ||||
| "dropout", | "dropout", | ||||
| @@ -488,6 +489,54 @@ def local_conv2d( | |||||
| return output | return output | ||||
| def conv_transpose3d( | |||||
| inp: Tensor, | |||||
| weight: Tensor, | |||||
| bias: Optional[Tensor] = None, | |||||
| stride: Union[int, Tuple[int, int, int]] = 1, | |||||
| padding: Union[int, Tuple[int, int, int]] = 0, | |||||
| dilation: Union[int, Tuple[int, int, int]] = 1, | |||||
| ) -> Tensor: | |||||
| """ | |||||
| 3D transposed convolution operation. Only support the case that group = 1 | |||||
| and conv_mode = "cross_correlation". | |||||
| Refer to :class:`~.ConvTranspose3d` for more information. | |||||
| :param inp: feature map of the convolution operation. | |||||
| :param weight: convolution kernel. | |||||
| :param bias: bias added to the result of convolution (if given). | |||||
| :param stride: stride of the 3D convolution operation. Default: 1 | |||||
| :param padding: size of the paddings added to the input on all sides of its | |||||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||||
| :param dilation: dilation of the 3D convolution operation. Default: 1 | |||||
| :return: output tensor. | |||||
| """ | |||||
| D, H, W = 0, 1, 2 | |||||
| pad = _triple(padding) | |||||
| stride = _triple_nonzero(stride) | |||||
| dilate = _triple_nonzero(dilation) | |||||
| op = builtin.Convolution3DBackwardData( | |||||
| pad_d=pad[D], | |||||
| pad_h=pad[H], | |||||
| pad_w=pad[W], | |||||
| stride_d=stride[D], | |||||
| stride_h=stride[H], | |||||
| stride_w=stride[W], | |||||
| dilate_d=dilate[D], | |||||
| dilate_h=dilate[H], | |||||
| dilate_w=dilate[W], | |||||
| strategy=get_execution_strategy(), | |||||
| ) | |||||
| weight, inp = utils.convert_inputs(weight, inp) | |||||
| (output,) = apply(op, weight, inp) | |||||
| if bias is not None: | |||||
| output += bias | |||||
| return output | |||||
| def max_pool2d( | def max_pool2d( | ||||
| inp: Tensor, | inp: Tensor, | ||||
| kernel_size: Union[int, Tuple[int, int]], | kernel_size: Union[int, Tuple[int, int]], | ||||
| @@ -18,6 +18,7 @@ from .conv import ( | |||||
| Conv3d, | Conv3d, | ||||
| ConvRelu2d, | ConvRelu2d, | ||||
| ConvTranspose2d, | ConvTranspose2d, | ||||
| ConvTranspose3d, | |||||
| DeformableConv2d, | DeformableConv2d, | ||||
| LocalConv2d, | LocalConv2d, | ||||
| ) | ) | ||||
| @@ -15,6 +15,7 @@ from ..functional import ( | |||||
| conv2d, | conv2d, | ||||
| conv3d, | conv3d, | ||||
| conv_transpose2d, | conv_transpose2d, | ||||
| conv_transpose3d, | |||||
| deformable_conv2d, | deformable_conv2d, | ||||
| local_conv2d, | local_conv2d, | ||||
| relu, | relu, | ||||
| @@ -842,3 +843,75 @@ class DeformableConv2d(_ConvNd): | |||||
| def forward(self, inp, offset, mask): | def forward(self, inp, offset, mask): | ||||
| return self.calc_conv(inp, self.weight, offset, mask, self.bias) | return self.calc_conv(inp, self.weight, offset, mask, self.bias) | ||||
| class ConvTranspose3d(_ConvNd): | |||||
| r""" | |||||
| Applies a 3D transposed convolution over an input tensor. | |||||
| Only support the case that group = 1 and conv_mode = "cross_correlation". | |||||
| :class:`ConvTranspose3d` can be seen as the gradient of :class:`Conv3d` operation | |||||
| with respect to its input. | |||||
| Convolution3D usually reduces the size of input, while transposed convolution3d | |||||
| works the opposite way, transforming a smaller input to a larger output while | |||||
| preserving the connectivity pattern. | |||||
| :param in_channels: number of input channels. | |||||
| :param out_channels: number of output channels. | |||||
| :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||||
| an :class:`int`, the actual kernel size would be | |||||
| ``(kernel_size, kernel_size, kernel_size)``. Default: 1 | |||||
| :param stride: stride of the 3D convolution operation. Default: 1 | |||||
| :param padding: size of the paddings added to the input on all sides of its | |||||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||||
| :param dilation: dilation of the 3D convolution operation. Default: 1 | |||||
| :param bias: wether to add a bias onto the result of convolution. Default: | |||||
| True | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| in_channels: int, | |||||
| out_channels: int, | |||||
| kernel_size: Union[int, Tuple[int, int, int]], | |||||
| stride: Union[int, Tuple[int, int, int]] = 1, | |||||
| padding: Union[int, Tuple[int, int, int]] = 0, | |||||
| dilation: Union[int, Tuple[int, int, int]] = 1, | |||||
| bias: bool = True, | |||||
| ): | |||||
| kernel_size = _triple_nonzero(kernel_size) | |||||
| stride = _triple_nonzero(stride) | |||||
| padding = _triple(padding) | |||||
| dilation = _triple_nonzero(dilation) | |||||
| super().__init__( | |||||
| in_channels=in_channels, | |||||
| out_channels=out_channels, | |||||
| kernel_size=kernel_size, | |||||
| stride=stride, | |||||
| padding=padding, | |||||
| dilation=dilation, | |||||
| groups=1, | |||||
| bias=bias, | |||||
| ) | |||||
| def _get_fanin(self): | |||||
| kt, kh, kw = self.kernel_size | |||||
| ic = self.in_channels | |||||
| return kt * kh * kw * ic | |||||
| def _infer_weight_shape(self): | |||||
| ichl = self.in_channels | |||||
| ochl = self.out_channels | |||||
| kt, kh, kw = self.kernel_size | |||||
| return (ochl, ichl, kt, kh, kw) | |||||
| def _infer_bias_shape(self): | |||||
| # Assume format is NCTHW | |||||
| return (1, self.out_channels, 1, 1, 1) | |||||
| def forward(self, inp): | |||||
| return conv_transpose3d( | |||||
| inp, self.weight, self.bias, self.stride, self.padding, self.dilation, | |||||
| ) | |||||
| @@ -11,7 +11,7 @@ import itertools | |||||
| import numpy as np | import numpy as np | ||||
| from megengine import Parameter, tensor | from megengine import Parameter, tensor | ||||
| from megengine.module import ConvTranspose2d, LocalConv2d | |||||
| from megengine.module import ConvTranspose2d, ConvTranspose3d, LocalConv2d | |||||
| def test_conv_transpose2d(): | def test_conv_transpose2d(): | ||||
| @@ -120,3 +120,64 @@ def test_local_conv2d(): | |||||
| test_func(10, 4, 4, 5, 5, 3, 1, 1, 1, 1) | test_func(10, 4, 4, 5, 5, 3, 1, 1, 1, 1) | ||||
| test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 2) | test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 2) | ||||
| test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 4) | test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 4) | ||||
| def test_conv_transpose3d(): | |||||
| def getsize(inp, kernel, stride, dilate): | |||||
| return (inp - 1) * stride + kernel * dilate - dilate + 1 | |||||
| def test_func( | |||||
| N, | |||||
| IC, | |||||
| ID, | |||||
| IH, | |||||
| IW, | |||||
| OC, | |||||
| KD, | |||||
| KH, | |||||
| KW, | |||||
| SD, | |||||
| SH, | |||||
| SW, | |||||
| PD, | |||||
| PH, | |||||
| PW, | |||||
| DD, | |||||
| DH, | |||||
| DW, | |||||
| bias=True, | |||||
| ): | |||||
| conv_transpose3d = ConvTranspose3d( | |||||
| in_channels=IC, | |||||
| out_channels=OC, | |||||
| kernel_size=(KD, KH, KW), | |||||
| stride=(SD, SH, SW), | |||||
| padding=(PD, PH, PW), | |||||
| dilation=(DD, DH, DW), | |||||
| bias=bias, | |||||
| ) | |||||
| OD = getsize(ID, KD, SD, DD) | |||||
| OH = getsize(IH, KH, SH, DH) | |||||
| OW = getsize(IW, KW, SW, DW) | |||||
| inp = np.random.normal(size=(N, IC, ID, IH, IW)) | |||||
| weight = np.random.normal(size=(IC, OC, KD, KH, KW)) | |||||
| out_np = np.zeros((N, OC, OD, OH, OW), dtype=np.float32) | |||||
| for n, ic, idepth, ih, iw in itertools.product( | |||||
| *map(range, [N, IC, ID, IH, IW]) | |||||
| ): | |||||
| od, oh, ow = idepth * SD, ih * SH, iw * SW | |||||
| out_np[n, :, od : od + KD, oh : oh + KH, ow : ow + KW] += ( | |||||
| inp[n, ic, idepth, ih, iw] * weight[ic] | |||||
| ) | |||||
| out_np = out_np[:, :, PD : OD - PD, PH : OH - PH, PW : OW - PW] | |||||
| conv_transpose3d.weight = Parameter(weight) | |||||
| out_meg = conv_transpose3d.forward(tensor(inp)) | |||||
| np.testing.assert_almost_equal(out_meg.numpy(), out_np, 1e-5) | |||||
| test_func(4, 3, 8, 16, 16, 8, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1) | |||||
| test_func(4, 8, 16, 32, 32, 16, 1, 3, 1, 2, 1, 2, 0, 1, 0, 1, 1, 1) | |||||
| @@ -75,5 +75,20 @@ OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D) | |||||
| .fallback(); | .fallback(); | ||||
| }} // convolution3d | }} // convolution3d | ||||
| namespace { namespace convolution3d_backward_data { | |||||
| auto apply_on_var_node( | |||||
| const OpDef& def, | |||||
| const VarNodeArray& inputs) { | |||||
| auto&& conv = static_cast<const Convolution3DBackwardData&>(def); | |||||
| OperatorNodeConfig config{conv.make_name()}; | |||||
| mgb_assert(inputs.size() == 2); | |||||
| return opr::Convolution3DBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||||
| } | |||||
| OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .fallback(); | |||||
| }} // convolution3d_backward_data | |||||
| } | } | ||||
| } | } | ||||
| @@ -53,6 +53,8 @@ def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [Convoluti | |||||
| def Convolution3D: MgbHashableOp<"Convolution3D", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>; | def Convolution3D: MgbHashableOp<"Convolution3D", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>; | ||||
| def Convolution3DBackwardData: MgbHashableOp<"Convolution3DBackwardData", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>; | |||||
| def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>; | def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>; | ||||
| def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>; | def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>; | ||||