GitOrigin-RevId: b75b792fb4
tags/v1.7.0
| @@ -372,6 +372,7 @@ def conv_transpose2d( | |||||
| Args: | Args: | ||||
| inp: feature map of the convolution operation. | inp: feature map of the convolution operation. | ||||
| weight: convolution kernel. | weight: convolution kernel. | ||||
| weight usually has shape ``(in_channels, out_channels, height, width)``. | |||||
| bias: bias added to the result of convolution (if given). | bias: bias added to the result of convolution (if given). | ||||
| stride: stride of the 2D convolution operation. Default: 1 | stride: stride of the 2D convolution operation. Default: 1 | ||||
| padding: size of the paddings added to the input on both sides of its | padding: size of the paddings added to the input on both sides of its | ||||
| @@ -405,14 +406,12 @@ def conv_transpose2d( | |||||
| if weight.dtype != dtype: | if weight.dtype != dtype: | ||||
| weight = weight.astype(dtype) | weight = weight.astype(dtype) | ||||
| if groups != 1: | |||||
| raise NotImplementedError("group transposed conv2d is not supported yet.") | |||||
| stride_h, stride_w = expand_hw(stride) | stride_h, stride_w = expand_hw(stride) | ||||
| pad_h, pad_w = expand_hw(padding) | pad_h, pad_w = expand_hw(padding) | ||||
| dilate_h, dilate_w = expand_hw(dilation) | dilate_h, dilate_w = expand_hw(dilation) | ||||
| compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) | |||||
| compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) | |||||
| sparse_type = "dense" if groups == 1 else "group" | |||||
| op = builtin.ConvolutionBackwardData( | op = builtin.ConvolutionBackwardData( | ||||
| stride_h=stride_h, | stride_h=stride_h, | ||||
| stride_w=stride_w, | stride_w=stride_w, | ||||
| @@ -422,6 +421,7 @@ def conv_transpose2d( | |||||
| dilate_w=dilate_w, | dilate_w=dilate_w, | ||||
| strategy=get_execution_strategy(), | strategy=get_execution_strategy(), | ||||
| compute_mode=compute_mode, | compute_mode=compute_mode, | ||||
| sparse=sparse_type, | |||||
| ) | ) | ||||
| (output,) = apply(op, weight, inp) | (output,) = apply(op, weight, inp) | ||||
| if bias is not None: | if bias is not None: | ||||
| @@ -447,6 +447,7 @@ def deformable_conv2d( | |||||
| Args: | Args: | ||||
| inp: input feature map. | inp: input feature map. | ||||
| weight: convolution kernel. | weight: convolution kernel. | ||||
| weight usually has shape ``(out_channels, in_channels, height, width)``. | |||||
| offset: input offset to kernel, channel of this tensor should match the deformable settings. | offset: input offset to kernel, channel of this tensor should match the deformable settings. | ||||
| mask: input mask to kernel, channel of this tensor should match the deformable settings. | mask: input mask to kernel, channel of this tensor should match the deformable settings. | ||||
| bias: bias added to the result of convolution (if given). | bias: bias added to the result of convolution (if given). | ||||
| @@ -551,6 +552,7 @@ def conv_transpose3d( | |||||
| stride: Union[int, Tuple[int, int, int]] = 1, | stride: Union[int, Tuple[int, int, int]] = 1, | ||||
| padding: Union[int, Tuple[int, int, int]] = 0, | padding: Union[int, Tuple[int, int, int]] = 0, | ||||
| dilation: Union[int, Tuple[int, int, int]] = 1, | dilation: Union[int, Tuple[int, int, int]] = 1, | ||||
| groups: int = 1, | |||||
| ) -> Tensor: | ) -> Tensor: | ||||
| r"""3D transposed convolution operation. Only support the case that groups = 1 | r"""3D transposed convolution operation. Only support the case that groups = 1 | ||||
| and conv_mode = "cross_correlation". | and conv_mode = "cross_correlation". | ||||
| @@ -581,6 +583,7 @@ def conv_transpose3d( | |||||
| if weight.dtype != dtype: | if weight.dtype != dtype: | ||||
| weight = weight.astype(dtype) | weight = weight.astype(dtype) | ||||
| sparse_type = "dense" if groups == 1 else "group" | |||||
| op = builtin.Convolution3DBackwardData( | op = builtin.Convolution3DBackwardData( | ||||
| pad_d=pad[D], | pad_d=pad[D], | ||||
| pad_h=pad[H], | pad_h=pad[H], | ||||
| @@ -592,6 +595,7 @@ def conv_transpose3d( | |||||
| dilate_h=dilate[H], | dilate_h=dilate[H], | ||||
| dilate_w=dilate[W], | dilate_w=dilate[W], | ||||
| strategy=get_execution_strategy(), | strategy=get_execution_strategy(), | ||||
| sparse=sparse_type, | |||||
| ) | ) | ||||
| (output,) = apply(op, weight, inp) | (output,) = apply(op, weight, inp) | ||||
| if bias is not None: | if bias is not None: | ||||
| @@ -891,6 +891,7 @@ class ConvTranspose3d(_ConvNd): | |||||
| padding: Union[int, Tuple[int, int, int]] = 0, | padding: Union[int, Tuple[int, int, int]] = 0, | ||||
| dilation: Union[int, Tuple[int, int, int]] = 1, | dilation: Union[int, Tuple[int, int, int]] = 1, | ||||
| bias: bool = True, | bias: bool = True, | ||||
| groups: int = 1, | |||||
| ): | ): | ||||
| kernel_size = _triple_nonzero(kernel_size) | kernel_size = _triple_nonzero(kernel_size) | ||||
| stride = _triple_nonzero(stride) | stride = _triple_nonzero(stride) | ||||
| @@ -903,7 +904,7 @@ class ConvTranspose3d(_ConvNd): | |||||
| stride=stride, | stride=stride, | ||||
| padding=padding, | padding=padding, | ||||
| dilation=dilation, | dilation=dilation, | ||||
| groups=1, | |||||
| groups=groups, | |||||
| bias=bias, | bias=bias, | ||||
| ) | ) | ||||
| @@ -913,10 +914,21 @@ class ConvTranspose3d(_ConvNd): | |||||
| return kt * kh * kw * ic | return kt * kh * kw * ic | ||||
| def _infer_weight_shape(self): | def _infer_weight_shape(self): | ||||
| group = self.groups | |||||
| ichl = self.in_channels | ichl = self.in_channels | ||||
| ochl = self.out_channels | ochl = self.out_channels | ||||
| kt, kh, kw = self.kernel_size | kt, kh, kw = self.kernel_size | ||||
| return (ichl, ochl, kt, kh, kw) | |||||
| if group == 1: | |||||
| # Assume format is NCHW | |||||
| return (ichl, ochl, kt, kh, kw) | |||||
| assert ( | |||||
| ichl % group == 0 and ochl % group == 0 | |||||
| ), "invalid config: in_channels={} out_channels={} group={}".format( | |||||
| ichl, ochl, group | |||||
| ) | |||||
| # Assume format is NCHW | |||||
| return (group, ichl // group, ochl // group, kt, kh, kw) | |||||
| def _infer_bias_shape(self): | def _infer_bias_shape(self): | ||||
| # Assume format is NCTHW | # Assume format is NCTHW | ||||
| @@ -1290,3 +1290,62 @@ def test_set_warp_perspective_config(): | |||||
| expected = F.vision.warp_perspective(inp, M, (2, 2), format="NHWC") | expected = F.vision.warp_perspective(inp, M, (2, 2), format="NHWC") | ||||
| np.testing.assert_allclose(config_out.numpy(), expected.numpy()) | np.testing.assert_allclose(config_out.numpy(), expected.numpy()) | ||||
| np.testing.assert_allclose(context_out.numpy(), expected.numpy()) | np.testing.assert_allclose(context_out.numpy(), expected.numpy()) | ||||
| @pytest.mark.parametrize("stride", [(1, 1)]) | |||||
| @pytest.mark.parametrize("padding", [(1, 1)]) | |||||
| @pytest.mark.parametrize("dilation", [(1, 1)]) | |||||
| @pytest.mark.parametrize("ksize", [(3, 3)]) | |||||
| @pytest.mark.parametrize("groups", [1, 2]) | |||||
| def test_local_conv2d(stride, padding, dilation, ksize, groups): | |||||
| batch_size, in_channels, out_channels = 2, 4, 8 | |||||
| input_height, input_width = 10, 10 | |||||
| output_height = (input_height + padding[0] * 2 - ksize[0]) // stride[0] + 1 | |||||
| output_width = (input_width + padding[1] * 2 - ksize[1]) // stride[1] + 1 | |||||
| def local_conv2d_np(data, weight, stride, padding, dialtion): | |||||
| # naive calculation use numpy | |||||
| # only test output_height == input_height, output_width == input_width | |||||
| data = np.pad(data, ((0, 0), (0, 0), (1, 1), (1, 1))) | |||||
| expected = np.zeros( | |||||
| (batch_size, out_channels, output_height, output_width), dtype=np.float32, | |||||
| ) | |||||
| ic_group_size = in_channels // groups | |||||
| oc_group_size = out_channels // groups | |||||
| for n, oc, oh, ow in itertools.product( | |||||
| *map(range, [batch_size, out_channels, output_height, output_width]) | |||||
| ): | |||||
| ih, iw = oh * stride[0], ow * stride[1] | |||||
| g_id = oc // oc_group_size | |||||
| expected[n, oc, ih, iw] = np.sum( | |||||
| data[ | |||||
| n, | |||||
| g_id * ic_group_size : (g_id + 1) * ic_group_size, | |||||
| ih : ih + ksize[0], | |||||
| iw : iw + ksize[1], | |||||
| ] | |||||
| * weight[g_id, oh, ow, :, :, :, oc % oc_group_size] | |||||
| ) | |||||
| return expected | |||||
| data = np.random.rand(batch_size, in_channels, input_height, input_width).astype( | |||||
| "float32" | |||||
| ) | |||||
| weight = np.random.rand( | |||||
| groups, | |||||
| output_height, | |||||
| output_width, | |||||
| in_channels // groups, | |||||
| *ksize, | |||||
| out_channels // groups, | |||||
| ).astype("float32") | |||||
| output = F.local_conv2d( | |||||
| tensor(data), | |||||
| tensor(weight), | |||||
| None, | |||||
| stride=stride, | |||||
| padding=padding, | |||||
| dilation=dilation, | |||||
| ) | |||||
| ref = local_conv2d_np(data, weight, stride, padding, dilation) | |||||
| np.testing.assert_almost_equal(output.numpy(), ref, 5) | |||||
| @@ -42,162 +42,3 @@ def test_conv_dtype_promotion(name, reproducible): | |||||
| m = getattr(M, name)(Ci, Co, K) | m = getattr(M, name)(Ci, Co, K) | ||||
| x = tensor(np.random.random(size=(N, Ci) + S).astype("float16")) | x = tensor(np.random.random(size=(N, Ci) + S).astype("float16")) | ||||
| np.testing.assert_equal(m(x).numpy(), m(x.astype("float32")).numpy()) | np.testing.assert_equal(m(x).numpy(), m(x.astype("float32")).numpy()) | ||||
| def test_conv_transpose2d(): | |||||
| SH, SW = 3, 1 | |||||
| PH, PW = 2, 0 | |||||
| N, IC, IH, IW = 4, 5, 8, 6 | |||||
| KH, KW = 3, 4 | |||||
| OC = 3 | |||||
| BIAS = False | |||||
| def getsize(inp, kern, stride): | |||||
| return (inp - 1) * stride + kern | |||||
| OH = getsize(IH, KH, SH) | |||||
| OW = getsize(IW, KW, SW) | |||||
| inp = np.random.normal(size=(N, IC, IH, IW)).astype(np.float32) | |||||
| out = np.zeros((N, OC, OH, OW), dtype=np.float32) | |||||
| weight = np.random.normal(size=(IC, OC, KH, KW)).astype(np.float32) | |||||
| bias = np.random.normal(size=(1, OC, 1, 1)).astype(np.float32) | |||||
| # naive calculation use numpy | |||||
| for n, ic, ih, iw in itertools.product(*map(range, [N, IC, IH, IW])): | |||||
| oh, ow = ih * SH, iw * SW | |||||
| out[n, :, oh : oh + KH, ow : ow + KW] += inp[n, ic, ih, iw] * weight[ic] | |||||
| out = out[:, :, PH : OH - PH, PW : OW - PW] | |||||
| if BIAS: | |||||
| out += bias | |||||
| # megengine conv_transpose2d calculation | |||||
| conv_transpose2d = ConvTranspose2d(IC, OC, (KH, KW), (SH, SW), (PH, PW), bias=BIAS) | |||||
| conv_transpose2d.weight = Parameter(weight, dtype=np.float32) | |||||
| if BIAS: | |||||
| conv_transpose2d.bias = Parameter(bias, dtype=np.float32) | |||||
| y = conv_transpose2d(tensor(inp)) | |||||
| np.testing.assert_almost_equal(out, y.numpy(), 2e-6) | |||||
| def test_local_conv2d(): | |||||
| def test_func( | |||||
| batch_size, | |||||
| in_channels, | |||||
| out_channels, | |||||
| input_height, | |||||
| input_width, | |||||
| kernel_size, | |||||
| stride, | |||||
| padding, | |||||
| dilation, | |||||
| groups, | |||||
| ): | |||||
| local_conv2d = LocalConv2d( | |||||
| in_channels=in_channels, | |||||
| out_channels=out_channels, | |||||
| input_height=input_height, | |||||
| input_width=input_width, | |||||
| kernel_size=kernel_size, | |||||
| stride=stride, | |||||
| padding=padding, | |||||
| dilation=dilation, | |||||
| groups=groups, | |||||
| ) | |||||
| inputs = np.random.normal( | |||||
| size=(batch_size, in_channels, input_height, input_width) | |||||
| ).astype(np.float32) | |||||
| output_height = (input_height + padding * 2 - kernel_size) // stride + 1 | |||||
| output_width = (input_width + padding * 2 - kernel_size) // stride + 1 | |||||
| weights = local_conv2d.weight.numpy() | |||||
| outputs = local_conv2d(tensor(inputs)) | |||||
| # naive calculation use numpy | |||||
| # only test output_height == input_height, output_width == input_width | |||||
| inputs = np.pad(inputs, ((0, 0), (0, 0), (1, 1), (1, 1))) | |||||
| expected = np.zeros( | |||||
| (batch_size, out_channels, output_height, output_width), dtype=np.float32, | |||||
| ) | |||||
| ic_group_size = in_channels // groups | |||||
| oc_group_size = out_channels // groups | |||||
| for n, oc, oh, ow in itertools.product( | |||||
| *map(range, [batch_size, out_channels, output_height, output_width]) | |||||
| ): | |||||
| ih, iw = oh * stride, ow * stride | |||||
| g_id = oc // oc_group_size | |||||
| expected[n, oc, ih, iw] = np.sum( | |||||
| inputs[ | |||||
| n, | |||||
| g_id * ic_group_size : (g_id + 1) * ic_group_size, | |||||
| ih : ih + kernel_size, | |||||
| iw : iw + kernel_size, | |||||
| ] | |||||
| * weights[g_id, oh, ow, :, :, :, oc % oc_group_size] | |||||
| ) | |||||
| np.testing.assert_almost_equal(outputs.numpy(), expected, 1e-5) | |||||
| test_func(10, 4, 4, 5, 5, 3, 1, 1, 1, 1) | |||||
| test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 2) | |||||
| test_func(10, 32, 32, 8, 8, 3, 1, 1, 1, 4) | |||||
| def test_conv_transpose3d(): | |||||
| def getsize(inp, kernel, stride, dilate): | |||||
| return (inp - 1) * stride + kernel * dilate - dilate + 1 | |||||
| def test_func( | |||||
| N, | |||||
| IC, | |||||
| ID, | |||||
| IH, | |||||
| IW, | |||||
| OC, | |||||
| KD, | |||||
| KH, | |||||
| KW, | |||||
| SD, | |||||
| SH, | |||||
| SW, | |||||
| PD, | |||||
| PH, | |||||
| PW, | |||||
| DD, | |||||
| DH, | |||||
| DW, | |||||
| bias=True, | |||||
| ): | |||||
| conv_transpose3d = ConvTranspose3d( | |||||
| in_channels=IC, | |||||
| out_channels=OC, | |||||
| kernel_size=(KD, KH, KW), | |||||
| stride=(SD, SH, SW), | |||||
| padding=(PD, PH, PW), | |||||
| dilation=(DD, DH, DW), | |||||
| bias=bias, | |||||
| ) | |||||
| OD = getsize(ID, KD, SD, DD) | |||||
| OH = getsize(IH, KH, SH, DH) | |||||
| OW = getsize(IW, KW, SW, DW) | |||||
| inp = np.random.normal(size=(N, IC, ID, IH, IW)) | |||||
| weight = np.random.normal(size=(IC, OC, KD, KH, KW)) | |||||
| out_np = np.zeros((N, OC, OD, OH, OW), dtype=np.float32) | |||||
| for n, ic, idepth, ih, iw in itertools.product( | |||||
| *map(range, [N, IC, ID, IH, IW]) | |||||
| ): | |||||
| od, oh, ow = idepth * SD, ih * SH, iw * SW | |||||
| out_np[n, :, od : od + KD, oh : oh + KH, ow : ow + KW] += ( | |||||
| inp[n, ic, idepth, ih, iw] * weight[ic] | |||||
| ) | |||||
| out_np = out_np[:, :, PD : OD - PD, PH : OH - PH, PW : OW - PW] | |||||
| assert conv_transpose3d.weight.numpy().shape == weight.shape | |||||
| conv_transpose3d.weight = Parameter(weight) | |||||
| out_meg = conv_transpose3d.forward(tensor(inp)) | |||||
| np.testing.assert_almost_equal(out_meg.numpy(), out_np, 1e-5) | |||||
| test_func(4, 3, 8, 16, 16, 8, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1) | |||||
| test_func(4, 8, 16, 32, 32, 16, 1, 3, 1, 2, 1, 2, 0, 1, 0, 1, 1, 1) | |||||