| @@ -62,6 +62,7 @@ __all__ = [ | |||||
| "softplus", | "softplus", | ||||
| "svd", | "svd", | ||||
| "warp_perspective", | "warp_perspective", | ||||
| "conv1d", | |||||
| ] | ] | ||||
| @@ -121,7 +122,7 @@ def conv2d( | |||||
| and the shape of weight should be `(groups, out_channel // groups, | and the shape of weight should be `(groups, out_channel // groups, | ||||
| in_channels // groups, height, width)`. | in_channels // groups, height, width)`. | ||||
| :type conv_mode: string or :class:`P.Convolution.Mode` | :type conv_mode: string or :class:`P.Convolution.Mode` | ||||
| :param conv_mode: supports "CROSS_CORRELATION" or "CONVOLUTION". Default: | |||||
| :param conv_mode: supports "CROSS_CORRELATION". Default: | |||||
| "CROSS_CORRELATION" | "CROSS_CORRELATION" | ||||
| :type compute_mode: string or | :type compute_mode: string or | ||||
| :class:`P.Convolution.ComputeMode` | :class:`P.Convolution.ComputeMode` | ||||
| @@ -187,7 +188,7 @@ def conv_transpose2d( | |||||
| and the shape of weight should be `(groups, out_channel // groups, | and the shape of weight should be `(groups, out_channel // groups, | ||||
| in_channels // groups, height, width)`. Default: 1 | in_channels // groups, height, width)`. Default: 1 | ||||
| :type conv_mode: string or :class:`P.Convolution.Mode` | :type conv_mode: string or :class:`P.Convolution.Mode` | ||||
| :param conv_mode: supports "CROSS_CORRELATION" or "CONVOLUTION". Default: | |||||
| :param conv_mode: supports "CROSS_CORRELATION". Default: | |||||
| "CROSS_CORRELATION" | "CROSS_CORRELATION" | ||||
| :type compute_mode: string or | :type compute_mode: string or | ||||
| :class:`P.Convolution.ComputeMode` | :class:`P.Convolution.ComputeMode` | ||||
| @@ -232,9 +233,7 @@ def local_conv2d( | |||||
| dilation: Union[int, Tuple[int, int]] = 1, | dilation: Union[int, Tuple[int, int]] = 1, | ||||
| conv_mode="CROSS_CORRELATION", | conv_mode="CROSS_CORRELATION", | ||||
| ): | ): | ||||
| """ | |||||
| Applies spatial 2D convolution over an groupped channeled image with untied kernels. | |||||
| """ | |||||
| """Applies spatial 2D convolution over an groupped channeled image with untied kernels.""" | |||||
| assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION" | assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION" | ||||
| stride_h, stride_w = expand_hw(stride) | stride_h, stride_w = expand_hw(stride) | ||||
| @@ -1585,6 +1584,82 @@ def indexing_one_hot( | |||||
| return result | return result | ||||
def conv1d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
    conv_mode="CROSS_CORRELATION",
    compute_mode="DEFAULT",
) -> Tensor:
    """1D convolution operation.

    Refer to :class:`~.Conv1d` for more information.

    The computation is delegated to the 2D convolution operator: a trailing
    axis of size 1 is appended to ``inp``, ``weight`` and ``bias``, the 2D
    operator is applied with stride 1, padding 0 and dilation 1 along that
    axis, and the axis is squeezed out of the result again.

    :param inp: The feature map of the convolution operation. Must be
        3-dimensional.
    :param weight: The convolution kernel. Must be 3-dimensional when
        ``groups`` is 1 and 4-dimensional otherwise (see ``groups``).
    :param bias: The bias added to the result of convolution (if given).
        Must be 3-dimensional.
    :param stride: Stride of the 1D convolution operation. Default: 1
    :param padding: Size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: Dilation of the 1D convolution operation. Default: 1
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and the shape of weight should be ``(groups, out_channel // groups,
        in_channels // groups, kernel_size)``.
    :type conv_mode: string or :class:`P.Convolution.Mode`
    :param conv_mode: Supports 'CROSS_CORRELATION'. Default:
        'CROSS_CORRELATION'.
    :type compute_mode: string or
        :class:`P.Convolution.ComputeMode`
    :param compute_mode: Only 'DEFAULT' is accepted at the moment (any other
        value fails the assertion below); with 'DEFAULT', no special
        requirements are placed on the precision of intermediate results.
    :return: The convolution result with the trailing dummy axis removed.
    """
    assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
    assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
    assert inp.ndim == 3, "the input dimension of conv1d should be 3"
    # A grouped kernel carries a leading ``groups`` axis, so it has one more
    # dimension than a dense kernel.
    weight_ndim = 3 if groups == 1 else 4
    assert (
        weight.ndim == weight_ndim
    ), "the weight dimension of conv1d should be {}".format(weight_ndim)

    inp = expand_dims(inp, 3)
    # Always append the dummy axis *after* the kernel-size axis; using a fixed
    # axis of 3 would insert it before the kernel axis for grouped weights.
    weight = expand_dims(weight, weight_ndim)
    if bias is not None:
        assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
        bias = expand_dims(bias, 3)

    stride_h = stride
    pad_h = padding
    dilate_h = dilation

    Sparse = P.Convolution.Sparse
    sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP
    op = builtin.Convolution(
        stride_h=stride_h,
        stride_w=1,
        pad_h=pad_h,
        pad_w=0,
        dilate_h=dilate_h,
        dilate_w=1,
        strategy=get_conv_execution_strategy(),
        mode=conv_mode,
        compute_mode=compute_mode,
        sparse=sparse_type,
    )
    inp, weight = utils.convert_inputs(inp, weight)
    (output,) = apply(op, inp, weight)
    if bias is not None:
        output += bias
    # Drop the dummy axis that was appended above.
    output = squeeze(output, 3)
    return output
| def nms( | def nms( | ||||
| boxes: Tensor, scores: Tensor, iou_thresh: float, max_output: Optional[int] = None | boxes: Tensor, scores: Tensor, iou_thresh: float, max_output: Optional[int] = None | ||||
| ) -> Tensor: | ) -> Tensor: | ||||
| @@ -11,7 +11,7 @@ from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | |||||
| from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d | from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d | ||||
| from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | ||||
| from .concat import Concat | from .concat import Concat | ||||
| from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | |||||
| from .conv import Conv1d, Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | |||||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | from .conv_bn import ConvBn2d, ConvBnRelu2d | ||||
| from .dropout import Dropout | from .dropout import Dropout | ||||
| from .elemwise import Elemwise | from .elemwise import Elemwise | ||||
| @@ -11,7 +11,7 @@ from typing import Tuple, Union | |||||
| import numpy as np | import numpy as np | ||||
| from ..core.ops._internal import param_defs as P | from ..core.ops._internal import param_defs as P | ||||
| from ..functional import conv2d, conv_transpose2d, local_conv2d, relu | |||||
| from ..functional import conv1d, conv2d, conv_transpose2d, local_conv2d, relu | |||||
| from ..functional.types import _pair, _pair_nonzero | from ..functional.types import _pair, _pair_nonzero | ||||
| from ..tensor import Parameter | from ..tensor import Parameter | ||||
| from . import init | from . import init | ||||
| @@ -86,6 +86,152 @@ class _ConvNd(Module): | |||||
| return s.format(**self.__dict__) | return s.format(**self.__dict__) | ||||
class Conv1d(_ConvNd):
    r"""
    Applies a 1D convolution over an input tensor.

    For instance, given an input of the size :math:`(N, C_{\text{in}}, H)`,
    this layer generates an output of the size
    :math:`(N, C_{\text{out}}, H_{\text{out}})` through the
    process described as below:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 1D cross-correlation operator,
    :math:`N` is batch size, :math:`C` denotes number of channels, and
    :math:`H` is length of 1D data element.

    When `groups == in_channels` and `out_channels == K * in_channels`,
    where K is a positive integer, this operation is also known as depthwise
    convolution.

    In other words, for an input of size :math:`(N, C_{in}, H_{in})`,
    a depthwise convolution with a depthwise multiplier `K`, can be constructed
    by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on the spatial dimension.
    :param stride: stride of the 1D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 1D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels
        are divided, so as to perform a "grouped convolution". When ``groups``
        is not 1, ``in_channels`` and ``out_channels`` must be divisible by
        ``groups``, and there would be an extra dimension at the beginning of
        the weight's shape. Specifically, the shape of weight would be
        `(groups, out_channel // groups, in_channels // groups, kernel_size)`.
        Default: 1
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default:
        `CROSS_CORRELATION`
    :param compute_mode: When set to "DEFAULT", no special requirements will be
        placed on the precision of intermediate results. When set to "FLOAT32",
        "Float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.Conv1d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 24).astype("float32").reshape(2, 3, 4))
        oup = m(inp)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2)

    """

    # Enum types used to normalise the ``conv_mode`` / ``compute_mode``
    # constructor arguments via their ``convert`` method.
    _conv_mode_type = P.Convolution.Mode
    _compute_mode_type = P.Convolution.ComputeMode

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
    ):
        # The 1D hyper-parameters are plain ints, so they are forwarded to the
        # base class as given. The modes are stored before the base
        # constructor runs, exactly as the original code did.
        self.conv_mode = self._conv_mode_type.convert(conv_mode)
        self.compute_mode = self._compute_mode_type.convert(compute_mode)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def _get_fanin(self):
        """Return ``kernel_size * in_channels``."""
        kh = self.kernel_size
        ic = self.in_channels
        return kh * ic

    def _infer_weight_shape(self):
        """Return the weight shape implied by the channel/group configuration.

        Dense (``groups == 1``): ``(out_channels, in_channels, kernel_size)``.
        Grouped: ``(groups, out_channels // groups, in_channels // groups,
        kernel_size)``; both channel counts must be divisible by ``groups``.
        """
        group = self.groups
        ichl = self.in_channels
        ochl = self.out_channels
        kh = self.kernel_size
        if group == 1:
            # Assume format is NCH(W=1)
            return (ochl, ichl, kh)

        assert (
            ichl % group == 0 and ochl % group == 0
        ), "invalid config: input_channels={} output_channels={} group={}".format(
            ichl, ochl, group
        )
        # Assume format is NCH(W=1)
        return (group, ochl // group, ichl // group, kh)

    def _infer_bias_shape(self):
        """Return the bias shape ``(1, out_channels, 1)``."""
        # Assume format is NCH(W=1)
        return (1, self.out_channels, 1)

    def calc_conv(self, inp, weight, bias):
        """Apply :func:`conv1d` with this module's stored hyper-parameters."""
        return conv1d(
            inp,
            weight,
            bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
            self.compute_mode,
        )

    def forward(self, inp):
        """Convolve ``inp`` with the module's own weight and bias."""
        return self.calc_conv(inp, self.weight, self.bias)
| class Conv2d(_ConvNd): | class Conv2d(_ConvNd): | ||||
| r""" | r""" | ||||
| Applies a 2D convolution over an input tensor. | Applies a 2D convolution over an input tensor. | ||||
| @@ -128,7 +274,7 @@ class Conv2d(_ConvNd): | |||||
| out_channel // groups, in_channels // groups, *kernel_size)`. | out_channel // groups, in_channels // groups, *kernel_size)`. | ||||
| :param bias: whether to add a bias onto the result of convolution. Default: | :param bias: whether to add a bias onto the result of convolution. Default: | ||||
| True | True | ||||
| :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | |||||
| :param conv_mode: Supports `CROSS_CORRELATION`. Default: | |||||
| `CROSS_CORRELATION` | `CROSS_CORRELATION` | ||||
| :param compute_mode: When set to "DEFAULT", no special requirements will be | :param compute_mode: When set to "DEFAULT", no special requirements will be | ||||
| placed on the precision of intermediate results. When set to "FLOAT32", | placed on the precision of intermediate results. When set to "FLOAT32", | ||||
| @@ -260,7 +406,7 @@ class ConvTranspose2d(_ConvNd): | |||||
| out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1 | out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1 | ||||
| :param bias: wether to add a bias onto the result of convolution. Default: | :param bias: wether to add a bias onto the result of convolution. Default: | ||||
| True | True | ||||
| :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | |||||
| :param conv_mode: Supports `CROSS_CORRELATION`. Default: | |||||
| `CROSS_CORRELATION` | `CROSS_CORRELATION` | ||||
| :param compute_mode: When set to "DEFAULT", no special requirements will be | :param compute_mode: When set to "DEFAULT", no special requirements will be | ||||
| placed on the precision of intermediate results. When set to "FLOAT32", | placed on the precision of intermediate results. When set to "FLOAT32", | ||||
| @@ -531,6 +531,18 @@ def test_zero_stride_numpy_array(): | |||||
| out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1) | out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1) | ||||
def test_conv1d():
    """Check ``F.conv1d`` on all-ones data with stride 2 and no bias.

    Every output element sums a window of 2 input channels x 2 kernel taps of
    ones, hence the expected value 4 everywhere.
    """
    feature = tensor(np.ones((2, 2, 4), dtype=np.float32))
    kernel = tensor(np.ones((3, 2, 2), dtype=np.float32))

    result = F.conv1d(feature, kernel, None, stride=2, padding=0, dilation=1, groups=1)

    expected = np.full((2, 3, 2), 4, dtype=np.float32)
    np.testing.assert_equal(result.numpy(), expected)
| def test_condtake(): | def test_condtake(): | ||||
| x = np.array([[1, 2, 3], [4, 5, 6]]) | x = np.array([[1, 2, 3], [4, 5, 6]]) | ||||
| y = np.array([[True, False, True], [False, True, True]]) | y = np.array([[True, False, True], [False, True, True]]) | ||||
| @@ -20,6 +20,7 @@ from megengine import Parameter, Tensor, tensor | |||||
| from megengine.module import ( | from megengine.module import ( | ||||
| BatchNorm1d, | BatchNorm1d, | ||||
| BatchNorm2d, | BatchNorm2d, | ||||
| Conv1d, | |||||
| Conv2d, | Conv2d, | ||||
| Dropout, | Dropout, | ||||
| Linear, | Linear, | ||||
| @@ -541,6 +542,43 @@ def test_shared_param(): | |||||
| np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy()) | np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy()) | ||||
class Simple2(Module):
    """Test fixture: two bias-free ``Conv1d`` layers sharing one weight."""

    def __init__(self):
        super().__init__()
        # Build both layers, then point conv1's weight at conv0's so the two
        # attributes refer to the very same parameter object.
        self.conv1 = Conv1d(in_channels=1, out_channels=1, kernel_size=3, bias=False)
        self.conv0 = Conv1d(in_channels=1, out_channels=1, kernel_size=3, bias=False)
        self.conv1.weight = self.conv0.weight

    def forward(self, inputs):
        # Intentionally a no-op: the test calls the sub-layers directly.
        return None
def test_shared_param_1d():
    """Check weight sharing between two ``Conv1d`` layers across save/load.

    Saving the whole network must keep the two layers tied to one weight
    object; saving each layer on its own must give independent weight objects
    that still produce the same output.
    """

    def _roundtrip(obj):
        # Serialise ``obj`` into an in-memory buffer and load it straight back.
        with BytesIO() as buf:
            mge.save(obj, buf)
            buf.seek(0)
            return mge.load(buf)

    net = Simple2()
    assert net.conv0.weight is net.conv1.weight
    data = tensor(np.random.random((1, 1, 8)).astype(np.float32))
    np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())

    # Whole-network round trip: the sharing relationship must survive.
    net1 = _roundtrip(net)
    assert net1.conv0.weight is net1.conv1.weight
    np.testing.assert_allclose(net1.conv0(data).numpy(), net1.conv1(data).numpy())

    # Per-layer round trips: weights become distinct objects with equal values.
    conv0 = _roundtrip(net.conv0)
    conv1 = _roundtrip(net.conv1)
    assert conv0.weight is not conv1.weight
    np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())
| def test_pickle_module(): | def test_pickle_module(): | ||||
| data_shape = (2, 28) | data_shape = (2, 28) | ||||
| data = tensor(np.random.random(data_shape)) | data = tensor(np.random.random(data_shape)) | ||||