@@ -22,8 +22,8 @@ from ..device import get_default_device
 from ..distributed import WORLD, is_distributed
 from ..random import uniform
 from ..tensor import Tensor
-from ..utils.tuple_function import _pair, _pair_nonzero
-from .debug_param import get_conv_execution_strategy, get_execution_strategy
+from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
+from .debug_param import get_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum
 from .math import argsort, matmul, max, prod, sum
@@ -43,7 +43,9 @@ __all__ = [
     "adaptive_max_pool2d",
     "avg_pool2d",
     "batch_norm",
+    "conv1d",
     "conv2d",
+    "conv3d",
     "conv_transpose2d",
     "deformable_conv2d",
     "deformable_psroi_pooling",
@@ -166,6 +168,66 @@ def conv2d(
     return output


+def conv3d(
+    inp: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    stride: Union[int, Tuple[int, int, int]] = 1,
+    padding: Union[int, Tuple[int, int, int]] = 0,
+    dilation: Union[int, Tuple[int, int, int]] = 1,
+    groups: int = 1,
+    conv_mode: str = "CROSS_CORRELATION",
+) -> Tensor:
+    """
+    3D convolution operation.
+
+    Refer to :class:`~.Conv3d` for more information.
+
+    :param inp: feature map of the convolution operation.
+    :param weight: convolution kernel.
+    :param bias: bias added to the result of convolution (if given).
+    :param stride: stride of the 3D convolution operation. Default: 1
+    :param padding: size of the paddings added to the input on both sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: dilation of the 3D convolution operation. Default: 1
+    :param groups: number of groups into which the input and output channels are
+        divided, so as to perform a ``grouped convolution``. When ``groups`` is not 1,
+        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
+        and the shape of weight should be ``(groups, out_channels // groups,
+        in_channels // groups, t, height, width)``.
+    :param conv_mode: supports "CROSS_CORRELATION". Default:
+        "CROSS_CORRELATION"
+    :return: output tensor.
+    """
+    assert conv_mode == "CROSS_CORRELATION"
+
+    D, H, W = 0, 1, 2
+
+    pad = _triple(padding)
+    stride = _triple_nonzero(stride)
+    dilate = _triple_nonzero(dilation)
+
+    sparse_type = "DENSE" if groups == 1 else "GROUP"
+    op = builtin.Convolution3D(
+        pad_d=pad[D],
+        pad_h=pad[H],
+        pad_w=pad[W],
+        stride_d=stride[D],
+        stride_h=stride[H],
+        stride_w=stride[W],
+        dilate_d=dilate[D],
+        dilate_h=dilate[H],
+        dilate_w=dilate[W],
+        strategy=get_execution_strategy(),
+        mode=conv_mode,
+        sparse=sparse_type,
+    )
+    inp, weight = utils.convert_inputs(inp, weight)
+    (output,) = apply(op, inp, weight)
+    if bias is not None:
+        output += bias
+    return output
+
+
 def conv_transpose2d(
     inp: Tensor,
     weight: Tensor,
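
A minimal usage sketch of the new functional entry point, assuming only what this hunk defines; the spatial output size follows the standard convolution formula `(in + 2*pad - dilation*(k-1) - 1) // stride + 1`:

```python
import numpy as np
import megengine as mge
import megengine.functional as F

# N=2, C_in=3, T=H=W=4; a 3x3x3 kernel, stride 1, no padding.
inp = mge.tensor(np.random.randn(2, 3, 4, 4, 4).astype("float32"))
# Dense weight layout per the docstring: (out_channels, in_channels, t, h, w).
weight = mge.tensor(np.random.randn(1, 3, 3, 3, 3).astype("float32"))

out = F.conv3d(inp, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
# Each spatial dim: (4 + 2*0 - 1*(3 - 1) - 1) // 1 + 1 = 2
print(out.shape)  # (2, 1, 2, 2, 2)
```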
@@ -1094,7 +1156,7 @@ def matmul(
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
-            strategy=get_conv_execution_strategy(),
+            strategy=get_execution_strategy(),
         )
     else:
         op = builtin.MatrixMul(
@@ -1102,7 +1164,7 @@ def matmul(
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
-            strategy=get_conv_execution_strategy(),
+            strategy=get_execution_strategy(),
         )
     (result,) = apply(op, inp1, inp2)
@@ -15,6 +15,7 @@ from .concat import Concat
 from .conv import (
     Conv1d,
     Conv2d,
+    Conv3d,
     ConvRelu2d,
     ConvTranspose2d,
     DeformableConv2d,
@@ -13,13 +13,14 @@ import numpy as np
 from ..functional import (
     conv1d,
     conv2d,
+    conv3d,
     conv_transpose2d,
     deformable_conv2d,
     local_conv2d,
     relu,
 )
 from ..tensor import Parameter
-from ..utils.tuple_function import _pair, _pair_nonzero
+from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
 from . import init
 from .module import Module
@@ -400,6 +401,142 @@ class Conv2d(_ConvNd):
         return self.calc_conv(inp, self.weight, self.bias)


+class Conv3d(_ConvNd):
+    r"""
+    Applies a 3D convolution over an input tensor.
+
+    For instance, given an input of the size :math:`(N, C_{\text{in}}, T, H, W)`,
+    this layer generates an output of the size
+    :math:`(N, C_{\text{out}}, T_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
+    process described as below:
+
+    .. math::
+        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
+        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
+
+    where :math:`\star` is the valid 3D cross-correlation operator,
+    :math:`N` is the batch size, and :math:`C` denotes the number of channels.
+
+    When ``groups == in_channels`` and ``out_channels == K * in_channels``,
+    where K is a positive integer, this operation is also known as depthwise
+    convolution.
+
+    In other words, for an input of size :math:`(N, C_{in}, T_{in}, H_{in}, W_{in})`,
+    a depthwise convolution with a depthwise multiplier `K` can be constructed
+    by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
+
+    :param in_channels: number of input channels.
+    :param out_channels: number of output channels.
+    :param kernel_size: size of weight on spatial dimensions. If kernel_size is
+        an :class:`int`, the actual kernel size would be
+        `(kernel_size, kernel_size, kernel_size)`. Default: 1
+    :param stride: stride of the 3D convolution operation. Default: 1
+    :param padding: size of the paddings added to the input on both sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: dilation of the 3D convolution operation. Default: 1
+    :param groups: number of groups into which the input and output channels are
+        divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
+        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
+        and there would be an extra dimension at the beginning of the weight's
+        shape. Specifically, the shape of weight would be `(groups,
+        out_channels // groups, in_channels // groups, *kernel_size)`.
+    :param bias: whether to add a bias onto the result of convolution. Default:
+        True
+    :param conv_mode: supports `CROSS_CORRELATION`. Default:
+        `CROSS_CORRELATION`
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        import megengine as mge
+        import megengine.module as M
+
+        m = M.Conv3d(in_channels=3, out_channels=1, kernel_size=3)
+        inp = mge.tensor(np.arange(0, 384).astype("float32").reshape(2, 3, 4, 4, 4))
+        oup = m(inp)
+        print(oup.numpy().shape)
+
+    Outputs:
+
+    .. testoutput::
+
+        (2, 1, 2, 2, 2)
+
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int, int]],
+        stride: Union[int, Tuple[int, int, int]] = 1,
+        padding: Union[int, Tuple[int, int, int]] = 0,
+        dilation: Union[int, Tuple[int, int, int]] = 1,
+        groups: int = 1,
+        bias: bool = True,
+        conv_mode: str = "CROSS_CORRELATION",
+    ):
+        kernel_size = _triple_nonzero(kernel_size)
+        stride = _triple_nonzero(stride)
+        padding = _triple(padding)
+        dilation = _triple_nonzero(dilation)
+        self.conv_mode = conv_mode
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+        )
+
+    def _get_fanin(self):
+        kt, kh, kw = self.kernel_size
+        ic = self.in_channels
+        return kt * kh * kw * ic
+
+    def _infer_weight_shape(self):
+        group = self.groups
+        ichl = self.in_channels
+        ochl = self.out_channels
+        kt, kh, kw = self.kernel_size
+        if group == 1:
+            # Assume format is NCTHW
+            return (ochl, ichl, kt, kh, kw)
+
+        assert (
+            ichl % group == 0 and ochl % group == 0
+        ), "invalid config: input_channels={} output_channels={} group={}".format(
+            ichl, ochl, group
+        )
+        # Assume format is NCTHW
+        return (group, ochl // group, ichl // group, kt, kh, kw)
+
+    def _infer_bias_shape(self):
+        # Assume format is NCTHW
+        return (1, self.out_channels, 1, 1, 1)
+
+    def calc_conv(self, inp, weight, bias):
+        return conv3d(
+            inp,
+            weight,
+            bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+            self.conv_mode,
+        )
+
+    def forward(self, inp):
+        return self.calc_conv(inp, self.weight, self.bias)
+
+
 class ConvTranspose2d(_ConvNd):
     r"""
     Applies a 2D transposed convolution over an input tensor.
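
The `groups` branch of `_infer_weight_shape` above implies the grouped weight layout below; a small sketch, with channel counts chosen purely for illustration:

```python
import megengine.module as M

# groups=4 splits the 4 input channels into 4 groups of 1;
# each group contributes out_channels // groups = 2 output channels.
m = M.Conv3d(in_channels=4, out_channels=8, kernel_size=3, groups=4)

# Grouped layout per _infer_weight_shape:
# (groups, out_channels // groups, in_channels // groups, kt, kh, kw)
print(m.weight.shape)  # (4, 2, 1, 3, 3, 3)
```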
@@ -35,4 +35,5 @@ _single = functools.partial(get_ndtuple, n=1, allow_zero=True)
 _pair = functools.partial(get_ndtuple, n=2, allow_zero=True)
 _pair_nonzero = functools.partial(get_ndtuple, n=2, allow_zero=False)
 _triple = functools.partial(get_ndtuple, n=3, allow_zero=True)
+_triple_nonzero = functools.partial(get_ndtuple, n=3, allow_zero=False)
 _quadruple = functools.partial(get_ndtuple, n=4, allow_zero=True)
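
From these partial bindings, `_triple` broadcasts a scalar to a 3-tuple and tolerates zeros (suitable for padding), while `_triple_nonzero` rejects zero entries (kernel size, stride, dilation). A sketch of the assumed behavior of `get_ndtuple`:

```python
from megengine.utils.tuple_function import _triple, _triple_nonzero

print(_triple(0))                  # (0, 0, 0) -- zero padding is legal
print(_triple_nonzero(2))          # (2, 2, 2)
print(_triple_nonzero((1, 2, 3)))  # (1, 2, 3): an explicit tuple passes through

# With allow_zero=False, a zero entry is expected to be rejected:
try:
    _triple_nonzero(0)
except Exception as e:  # exact exception type is an assumption
    print("rejected:", e)
```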
@@ -637,7 +637,7 @@ def test_batch_conv_bias():
     run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)


-def test_zero_stride_numpy_array():
+def test_conv2d_zero_stride_numpy_array():
     inp = np.random.randn(3, 224, 224).astype(np.float32)
     inp = inp[np.newaxis, :]

@@ -646,6 +646,16 @@ def test_zero_stride_numpy_array():
     out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)


+def test_conv3d_zero_stride_numpy_array():
+    inp = np.random.randn(3, 224, 224, 224).astype(np.float32)
+    inp = inp[np.newaxis, :]
+
+    inp = tensor(inp, dtype=np.float32)
+    weight = tensor(np.random.randn(16, 3, 3, 3, 3), dtype=np.float32)
+    out = F.conv3d(inp, weight, None, (2, 2, 2), (3, 3, 3), (1, 1, 1), 1)
+    out.numpy()
+
+
 def test_conv1d():
     inp = tensor(np.ones((16,), dtype=np.float32).reshape(2, 2, 4))
     weight = tensor(np.ones((12,), dtype=np.float32).reshape(3, 2, 2))
@@ -658,6 +668,16 @@ def test_conv1d():
     )


+def test_conv3d():
+    inp = tensor(np.ones((256,), dtype=np.float32).reshape(2, 2, 4, 4, 4))
+    weight = tensor(np.ones((48,), dtype=np.float32).reshape(3, 2, 2, 2, 2))
+    out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
+    print(out.numpy().shape)
+    np.testing.assert_equal(
+        out.numpy(), np.ones((2, 3, 2, 2, 2), dtype=np.float32) * 16
+    )
+
+
 def test_condtake():
     x = np.array([[1, 2, 3], [4, 5, 6]])
     y = np.array([[True, False, True], [False, True, True]])
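
The expected constant 16 in `test_conv3d` is just the all-ones dot product: each output element sums a 2x2x2 window over 2 input channels, i.e. 2*2*2*2 = 16, and with stride 2 and no padding each spatial dim shrinks from 4 to (4 - 2) // 2 + 1 = 2. The same arithmetic, standalone:

```python
# Each output element of an all-ones convolution equals
# kernel volume times the number of input channels.
kt, kh, kw, in_channels = 2, 2, 2, 2
assert kt * kh * kw * in_channels == 16

# Output spatial extent with kernel 2, stride 2, padding 0 on size 4:
assert (4 + 2 * 0 - 2) // 2 + 1 == 2
```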
@@ -0,0 +1,79 @@
+/**
+ * \file imperative/src/impl/ops/dnn/convolution.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
| #include "megbrain/imperative/ops/autogen.h" | |||||
| #include "megbrain/opr/dnn/convolution.h" | |||||
| #include "../op_trait.h" | |||||
| namespace mgb { | |||||
| namespace imperative { | |||||
| namespace { namespace convolution { | |||||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||||
| auto* node = &node_->cast_final_safe<opr::Convolution>(); | |||||
| return Convolution::make(node->param(), node->execution_policy()); | |||||
| } | |||||
| auto apply_on_var_node( | |||||
| const OpDef& def, | |||||
| const VarNodeArray& inputs) { | |||||
| auto&& conv = static_cast<const Convolution&>(def); | |||||
| OperatorNodeConfig config{conv.make_name()}; | |||||
| return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||||
| } | |||||
| OP_TRAIT_REG(Convolution, Convolution, opr::Convolution) | |||||
| .make_from_op_node(make_from_op_node) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .fallback(); | |||||
| }} // convolution | |||||
| namespace { namespace convolution_backward_data { | |||||
| auto apply_on_var_node( | |||||
| const OpDef& def, | |||||
| const VarNodeArray& inputs) { | |||||
| auto&& conv = static_cast<const ConvolutionBackwardData&>(def); | |||||
| OperatorNodeConfig config{conv.make_name()}; | |||||
| if (inputs.size() == 2) { | |||||
| return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||||
| } else { | |||||
| mgb_assert(inputs.size() == 3); | |||||
| return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); | |||||
| } | |||||
| } | |||||
| OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .fallback(); | |||||
| }} // convolution_backward_data | |||||
| namespace { namespace convolution3d { | |||||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||||
| auto* node = &node_->cast_final_safe<opr::Convolution3D>(); | |||||
| return Convolution3D::make(node->param(), node->execution_policy()); | |||||
| } | |||||
| auto apply_on_var_node( | |||||
| const OpDef& def, | |||||
| const VarNodeArray& inputs) { | |||||
| auto&& conv = static_cast<const Convolution3D&>(def); | |||||
| return opr::Convolution3D::make(inputs[0], inputs[1], conv.param(), conv.policy()); | |||||
| } | |||||
| OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D) | |||||
| .make_from_op_node(make_from_op_node) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .fallback(); | |||||
| }} // convolution3d | |||||
| } | |||||
| } | |||||
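
The `OP_TRAIT_REG(Convolution3D, ...)` entry above is what routes the `builtin.Convolution3D` op constructed by the Python `conv3d` wrapper (see the earlier hunk) onto `opr::Convolution3D`. A low-level sketch mirroring that Python code; the `apply` import path is an assumption, and `strategy` is left at its default rather than queried from `debug_param`:

```python
import numpy as np
import megengine as mge
from megengine.core.ops import builtin
# Assumed import path for the low-level dispatcher used in functional/nn.py:
from megengine.core.tensor.core import apply

# Same parameters the functional conv3d wrapper fills in.
op = builtin.Convolution3D(
    pad_d=0, pad_h=0, pad_w=0,
    stride_d=1, stride_h=1, stride_w=1,
    dilate_d=1, dilate_h=1, dilate_w=1,
    mode="CROSS_CORRELATION",
    sparse="DENSE",
)
inp = mge.tensor(np.ones((1, 2, 3, 3, 3), dtype="float32"))
weight = mge.tensor(np.ones((1, 2, 2, 2, 2), dtype="float32"))
(out,) = apply(op, inp, weight)  # dispatched through the trait registered above
print(out.shape)  # (1, 1, 2, 2, 2)
```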
@@ -36,45 +36,6 @@
 namespace mgb::imperative {

-namespace { namespace convolution {
-std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
-    auto* node = &node_->cast_final_safe<opr::Convolution>();
-    return Convolution::make(node->param(), node->execution_policy());
-}
-
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& conv = static_cast<const Convolution&>(def);
-    OperatorNodeConfig config{conv.make_name()};
-    return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
-}
-
-OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
-    .make_from_op_node(make_from_op_node)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // convolution
-
-namespace { namespace convolution_backward_data {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
-    OperatorNodeConfig config{conv.make_name()};
-    if (inputs.size() == 2) {
-        return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
-    } else {
-        mgb_assert(inputs.size() == 3);
-        return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
-    }
-}
-
-OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // convolution_backward_data
-
 namespace { namespace dimshuffle {
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
     auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
@@ -51,6 +51,8 @@ def Convolution : MgbHashableOp<"Convolution", [ConvolutionParam, ExecutionPolic
 def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;

+def Convolution3D: MgbHashableOp<"Convolution3D", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;
+
 def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;

 def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>;