GitOrigin-RevId: c0106ade08
tags/v1.11.1
@@ -91,6 +91,7 @@ __all__ = [
     "warp_affine",
     "warp_perspective",
     "pixel_shuffle",
+    "region_restricted_conv",
 ]
@@ -1213,10 +1214,10 @@ def layer_norm(
 ):
     r"""Applies layer normalization to the input. Supports tensors of any shape as input.
     Reference: https://arxiv.org/pdf/1803.08494.pdf.
     Args:
         inp: input tensor.
         normalized_shape: the shape that you want to be normalized
         affine: whether to use weight and bias
         weight: must not be None when affine is true
         bias: must not be None when affine is true
@@ -1974,6 +1975,61 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
     return pixel_shuffle_cpp(inp, upscale_factor, _layerPixelShuffle_traceable)
+
+
+def region_restricted_conv(
+    inp: Tensor,
+    weight: Tensor,
+    rin: Tensor,
+    rout: Tensor,
+    stride: Union[int, Tuple[int, int]] = 1,
+    padding: Union[int, Tuple[int, int]] = 0,
+    dilation: Union[int, Tuple[int, int]] = 1,
+    groups: int = 1,
+    conv_mode: str = "cross_correlation",
+    compute_mode: str = "default",
+) -> Tensor:
+    r"""Region Restricted convolution operation.
+
+    Refer to :class:`~.RegionRestrictedConv` for more information.
+
+    Args:
+        inp: feature map of the convolution operation.
+        weight: convolution kernel.
+        rin: region mask of the input feature map.
+        rout: region mask of the output feature map.
+        stride: stride of the 2D region restricted convolution operation. Default: 1
+        padding: size of the paddings added to the input on both sides of its
+            spatial dimensions. Only zero-padding is supported. Default: 0
+        dilation: dilation of the 2D convolution operation. Default: 1
+        groups: number of groups into which the input and output channels are divided,
+            so as to perform a ``grouped convolution``. When ``groups`` is not 1,
+            ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
+            and the shape of weight should be ``(groups, out_channels // groups,
+            in_channels // groups, height, width)``. Default: 1
+        conv_mode: supports "cross_correlation". Default: "cross_correlation"
+
+    Returns:
+        output tensor.
+    """
+    assert conv_mode.lower() == "cross_correlation"
+
+    pad_h, pad_w = _expand_hw(padding)
+    stride_h, stride_w = _expand_hw(stride)
+    dilate_h, dilate_w = _expand_hw(dilation)
+
+    sparse_type = "dense" if groups == 1 else "group"
+    op = builtin.RegionRestrictedConvolution(
+        stride_h=stride_h,
+        stride_w=stride_w,
+        pad_h=pad_h,
+        pad_w=pad_w,
+        dilate_h=dilate_h,
+        dilate_w=dilate_w,
+        mode=conv_mode,
+        compute_mode=compute_mode,
+        sparse=sparse_type,
+    )
+    (output,) = apply(op, inp, weight, rin, rout)
+    return output
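
Reviewer note (not part of the patch): a minimal usage sketch of the new functional API. Shapes follow the unit tests below; my reading, which this hunk does not state explicitly, is that `rin`/`rout` carry per-pixel integer region ids for the input and output maps, and only positions with matching ids contribute to an output pixel.

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    inp = mge.tensor(np.random.randn(1, 2, 4, 4).astype(np.float32))
    # grouped weight layout: (groups, out_ch_per_group, in_ch_per_group, fh, fw)
    weight = mge.tensor(np.ones((2, 1, 1, 2, 2), dtype=np.float32))
    rin = mge.tensor(np.ones((1, 4, 4), dtype=np.int32))   # input region ids
    rout = mge.tensor(np.ones((1, 3, 3), dtype=np.int32))  # output region ids
    out = F.region_restricted_conv(inp, weight, rin, rout, groups=2)
    print(out.shape)  # (1, 2, 3, 3)

With all-ones masks every position matches, so the op degenerates to an ordinary depthwise/grouped convolution; that is exactly what the tests below exploit.
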
 from .quantized import conv_bias_activation  # isort:skip
 from .loss import *  # isort:skip
 from .vision import *  # isort:skip
@@ -14,6 +14,7 @@ from .conv import (
     ConvTranspose3d,
     DeformableConv2d,
     LocalConv2d,
+    RegionRestrictedConv,
 )
 from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .deformable_psroi_pooling import DeformablePSROIPooling
@@ -12,6 +12,7 @@ from ..functional import (
     deformable_conv2d,
     local_conv2d,
     pad,
+    region_restricted_conv,
     relu,
 )
 from ..tensor import Parameter
@@ -982,3 +983,174 @@ class ConvTranspose3d(_ConvNd):
             self.output_padding,
             self.dilation,
         )
+
+
+class RegionRestrictedConv(_ConvNd):
+    r"""Applies a 2D RegionRestricted Convolution over an input tensor.
+
+    For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`,
+    this layer generates an output of the size
+    :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
+    process described as below:
+
+    .. math::
+        \text{out}(N_i, C_{\text{out}_j}) =
+        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
+
+    where :math:`\star` is the valid 2D cross-correlation operator,
+    :math:`N` is batch size, :math:`C` denotes number of channels,
+    :math:`H` is height of input planes in pixels, and :math:`W` is
+    width in pixels.
+
+    In general, output feature maps' shapes can be inferred as follows:
+
+    input: :math:`(N, C_{\text{in}}, H_{\text{in}}, W_{\text{in}})`
+
+    output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` where
+
+    .. math::
+        \text{H}_{out} = \lfloor \frac{\text{H}_{in} + 2 * \text{padding[0]} -
+        \text{dilation[0]} * (\text{kernel_size[0]} - 1) - 1}{\text{stride[0]}} + 1 \rfloor
+
+    .. math::
+        \text{W}_{out} = \lfloor \frac{\text{W}_{in} + 2 * \text{padding[1]} -
+        \text{dilation[1]} * (\text{kernel_size[1]} - 1) - 1}{\text{stride[1]}} + 1 \rfloor
+
+    When ``groups == in_channels`` and ``out_channels == K * in_channels``,
+    where K is a positive integer, this operation is also known as depthwise
+    convolution.
+
+    In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
+    a depthwise convolution with a depthwise multiplier `K`, can be constructed
+    by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
+
+    Args:
+        in_channels: number of input channels.
+        out_channels: number of output channels.
+        kernel_size: size of weight on spatial dimensions. If kernel_size is
+            an :class:`int`, the actual kernel size would be
+            ``(kernel_size, kernel_size)``.
+        stride: stride of the 2D convolution operation. Default: 1
+        padding: size of the paddings added to the input on both sides of its
+            spatial dimensions. Default: 0
+        dilation: dilation of the 2D convolution operation. Default: 1
+        groups: number of groups into which the input and output channels are divided,
+            so as to perform a ``grouped convolution``. When ``groups`` is not 1,
+            ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
+            and the shape of weight should be ``(groups, out_channels // groups,
+            in_channels // groups, height, width)``. Default: 1
+        conv_mode: supports ``cross_correlation``. Default: ``cross_correlation``
+        compute_mode: when set to "default", no special requirements will be
+            placed on the precision of intermediate results. When set to "float32",
+            "float32" would be used for accumulator and intermediate result, but only
+            effective when input and output are of float16 dtype.
+        padding_mode: "zeros", "reflect" or "replicate". Default: "zeros".
+            Refer to :class:`~.module.padding.Pad` for more information.
+
+    Note:
+        * ``weight`` usually has shape ``(out_channels, in_channels, height, width)``;
+          if ``groups`` is not 1, the shape will be ``(groups, out_channels // groups,
+          in_channels // groups, height, width)``.
+
+    Examples:
+        >>> import numpy as np
+        >>> import megengine as mge
+        >>> import megengine.module as M
+        >>> rrconv = M.RegionRestrictedConv(in_channels=2, out_channels=2, kernel_size=2, groups=2)
+        >>> inp = mge.tensor(np.random.randn(1, 2, 2, 2).astype(np.float32))
+        >>> rin = mge.tensor(np.random.randn(1, 2, 2).astype(np.int32))
+        >>> rout = mge.tensor(np.random.randn(1, 1, 1).astype(np.int32))
+        >>> oup = rrconv(inp, rin, rout)
+        >>> oup.numpy().shape
+        (1, 2, 1, 1)
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        groups: int,
+        stride: Union[int, Tuple[int, int]] = 1,
+        padding: Union[int, Tuple[int, int]] = 0,
+        dilation: Union[int, Tuple[int, int]] = 1,
+        conv_mode: str = "cross_correlation",
+        compute_mode: str = "default",
+        padding_mode: str = "zeros",
+        **kwargs
+    ):
+        kernel_size = _pair_nonzero(kernel_size)
+        stride = _pair_nonzero(stride)
+        padding = _pair(padding)
+        dilation = _pair_nonzero(dilation)
+        self.conv_mode = conv_mode
+        self.compute_mode = compute_mode
+        self.padding_mode = padding_mode
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            0,  # output_padding, unused by a forward convolution
+            dilation,
+            groups,
+            False,  # bias
+            **kwargs,
+        )
+
+    def _get_fanin(self):
+        kh, kw = self.kernel_size
+        ic = self.in_channels
+        return kh * kw * ic
+
+    def _infer_weight_shape(self):
+        group = self.groups
+        ichl = self.in_channels
+        ochl = self.out_channels
+        kh, kw = self.kernel_size
+        if group == 1:
+            # Assume format is NCHW
+            return (ochl, ichl, kh, kw)
+
+        assert (
+            ichl % group == 0 and ochl % group == 0
+        ), "invalid config: in_channels={} out_channels={} group={}".format(
+            ichl, ochl, group
+        )
+        # Assume format is NCHW
+        return (group, ochl // group, ichl // group, kh, kw)
+
+    def _infer_bias_shape(self):
+        # Assume format is NCHW
+        return (1, self.out_channels, 1, 1)
+
+    def get_pad_width(self):
+        return (
+            (0, 0),
+            (0, 0),
+            (self.padding[0], self.padding[0]),
+            (self.padding[1], self.padding[1]),
+        )
+
+    def calc_conv(self, inp, weight, rin, rout):
+        assert self.padding_mode in [
+            "zeros",
+            "reflect",
+            "replicate",
+        ]
+        return region_restricted_conv(
+            inp,
+            weight,
+            rin,
+            rout,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+            self.conv_mode,
+            self.compute_mode,
+        )
+
+    def forward(self, inp, rin, rout):
+        return self.calc_conv(inp, self.weight, rin, rout)
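
Reviewer note (not part of the patch): a quick sanity check, in plain Python, of the H_out/W_out formulas in the docstring above.

    def out_size(i, pad, dilation, kernel, stride):
        # floor((i + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
        return (i + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

    # matches the docstring example: 2x2 input, 2x2 kernel -> 1x1 output
    assert out_size(2, pad=0, dilation=1, kernel=2, stride=1) == 1
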
@@ -930,6 +930,179 @@ def test_batch_conv_bias():
     run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)
+
+
+def test_region_restricted_conv_forward_backward_naive():
+    import megengine as mge
+    import megengine.module as M
+    from megengine.autodiff import GradManager
+
+    handle = "cpu0"
+    src_1 = np.arange(8).reshape(1, 2, 2, 2).astype(np.float32)
+    filter_1 = np.arange(8).reshape(2, 1, 1, 2, 2).astype(np.float32)
+    rin_1 = np.array([1, 1, 1, 1]).reshape(1, 2, 2).astype(np.int32)
+    rout_1 = np.array([1]).reshape(1, 1, 1).astype(np.int32)
+    cpu_src = tensor(src_1, device=handle)
+    cpu_filter = tensor(filter_1, device=handle)
+    gm = GradManager().attach([cpu_src, cpu_filter])
+    with gm:
+        cpu_out = F.region_restricted_conv(
+            cpu_src,
+            cpu_filter,
+            tensor(rin_1, device=handle),
+            tensor(rout_1, device=handle),
+            groups=2,
+        )
+        gm.backward(cpu_out, tensor(np.ones((1, 2, 1, 1)), device=handle))
+    np.testing.assert_allclose(
+        cpu_src.grad, np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape(1, 2, 2, 2)
+    )
+    np.testing.assert_allclose(
+        cpu_filter.grad, np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape(2, 1, 1, 2, 2)
+    )
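
Reviewer note (not part of the patch): a hedged numpy reference for the depthwise case this test exercises, under my assumed mask semantics (an input pixel contributes only where its region id in `rin` equals the output pixel's id in `rout`; stride 1, no padding/dilation, icpg == ocpg == 1). It also explains the expected gradients above: with all-ones masks and an all-ones incoming gradient, grad(src) is the filter values and grad(filter) is the src values, i.e. arange(8) in both cases.

    import numpy as np

    def rrconv_ref(src, flt, rin, rout):
        # naive sketch of region restricted depthwise conv (assumed semantics)
        n_, g_, ih, iw = src.shape
        _, _, _, fh, fw = flt.shape
        oh, ow = ih - fh + 1, iw - fw + 1
        out = np.zeros((n_, g_, oh, ow), dtype=src.dtype)
        for n in range(n_):
            for g in range(g_):
                for y in range(oh):
                    for x in range(ow):
                        win = src[n, g, y:y + fh, x:x + fw]
                        mask = rin[n, y:y + fh, x:x + fw] == rout[n, y, x]
                        out[n, g, y, x] = (win * flt[g, 0, 0] * mask).sum()
        return out

    src = np.arange(8, dtype=np.float32).reshape(1, 2, 2, 2)
    flt = np.arange(8, dtype=np.float32).reshape(2, 1, 1, 2, 2)
    rin = np.ones((1, 2, 2), dtype=np.int32)
    rout = np.ones((1, 1, 1), dtype=np.int32)
    print(rrconv_ref(src, flt, rin, rout))  # [[[[14.]], [[126.]]]]
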
+
+
+@pytest.mark.skipif(
+    not is_cuda_available(), reason="rrconv cuda kernel requires cuda available"
+)
+def test_region_restricted_conv_forward_backward_cuda():
+    import megengine as mge
+    import megengine.module as M
+    from megengine.autodiff import GradManager
+    import megengine.distributed as dist
+
+    # params
+    handle = "gpu0"
+    N = 1
+    GROUP = 3
+    FH = FW = 2
+    IH = IW = 2
+    OH = OW = 1
+    ICPG = OCPG = 1
+
+    grad_shape = (N, GROUP * ICPG, IH, IW)
+    src_shape = grad_shape
+    filter_shape = (GROUP, OCPG, ICPG, FH, FW)
+    diff_shape = (N, GROUP * OCPG, OH, OW)
+    rin_shape = (N, IH, IW)
+    rout_shape = (N, OH, OW)
+
+    def reduce(shape):
+        mul = 1
+        for x in shape:
+            mul *= x
+        return mul
+
+    def get_groundtruth():
+        src = tensor(
+            np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
+            device="cpu0",
+        )
+        filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0")
+        rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
+        rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
+        gm = GradManager().attach([src, filter])
+        with gm:
+            expected_out = F.region_restricted_conv(
+                src, filter, rin, rout, groups=GROUP
+            )
+            gm.backward(
+                expected_out,
+                tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"),
+            )
+        return src, filter
+
+    expected_src, expected_filter = get_groundtruth()
+
+    src = tensor(
+        np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
+        device=handle,
+    )
+    filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle)
+    rin = tensor(np.ones(rin_shape).astype(np.int32), device=handle)
+    rout = tensor(np.ones(rout_shape).astype(np.int32), device=handle)
+    gm = GradManager().attach([src, filter])
+    with gm:
+        gpu_out = F.region_restricted_conv(src, filter, rin, rout, groups=GROUP)
+        gm.backward(gpu_out, tensor(np.ones(diff_shape), device=handle))
+    np.testing.assert_allclose(src.grad, expected_src.grad)
+    np.testing.assert_allclose(filter.grad, expected_filter.grad)
+
+
+@pytest.mark.skipif(
+    not is_cuda_available(), reason="rrconv cuda kernel requires cuda available"
+)
+def test_region_restricted_conv_forward_backward_uint8():
+    import megengine as mge
+    import megengine.module as M
+    from megengine.autodiff import GradManager
+
+    # params
+    handle = "gpu0"
+    N = 1
+    GROUP = 2
+    FH = FW = 1
+    IH = IW = 4
+    OH = OW = 4
+    ICPG = OCPG = 1
+
+    grad_shape = (N, GROUP * ICPG, IH, IW)
+    src_shape = grad_shape
+    filter_shape = (GROUP, OCPG, ICPG, FH, FW)
+    diff_shape = (N, GROUP * OCPG, OH, OW)
+    rin_shape = (N, IH, IW)
+    rout_shape = (N, OH, OW)
+
+    def reduce(shape):
+        mul = 1
+        for x in shape:
+            mul *= x
+        return mul
+
+    def get_groundtruth():
+        src = tensor(
+            np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
+            device="cpu0",
+        )
+        filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0")
+        rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
+        rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
+        gm = GradManager().attach([src, filter])
+        with gm:
+            expected_out = F.region_restricted_conv(
+                src, filter, rin, rout, groups=GROUP
+            )
+            gm.backward(
+                expected_out,
+                tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"),
+            )
+        return src, filter
+
+    expected_src, expected_filter = get_groundtruth()
+
+    # forward and dgrad/wgrad
+    src = tensor(
+        np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
+        device=handle,
+    )
+    filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle)
+    rin = tensor(np.ones(rin_shape).astype(np.uint8), device=handle)
+    rout = tensor(np.ones(rout_shape).astype(np.uint8), device=handle)
+    gm = GradManager().attach([src, filter])
+    with gm:
+        gpu_out = F.region_restricted_conv(src, filter, rin, rout, groups=GROUP)
+        gm.backward(
+            gpu_out, tensor(np.ones(diff_shape, dtype=np.float32), device=handle)
+        )
+    # assert uint8 gpu result close to cpu result
+    np.testing.assert_allclose(src.grad, expected_src.grad)
+    np.testing.assert_allclose(filter.grad, expected_filter.grad)
+
+
+def test_region_restricted_conv():
+    test_region_restricted_conv_forward_backward_naive()
+    if is_cuda_available():
+        test_region_restricted_conv_forward_backward_cuda()
+        test_region_restricted_conv_forward_backward_uint8()
 def test_conv2d_autocast():
     """check amp's result is equal to manually converted result"""
     amp.enabled = True
@@ -3,9 +3,11 @@
 #include "../blob_manager_impl.h"
 #include "../dnn_op_helper.h"
 #include "../op_trait.h"
+#include "megbrain/common.h"
 #include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
 #include "megbrain/opr/tensor_gen.h"
+#include "megdnn/oprs/nn.h"
 namespace mgb {
 namespace imperative {
@@ -356,5 +358,174 @@ OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData)
 }  // namespace convolution3d_backward_data
 }  // namespace
+
+namespace {
+namespace region_restricted_conv {
+std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
+    auto* node = &node_->cast_final_safe<opr::RegionRestrictedConvolution>();
+    return RegionRestrictedConvolution::make(node->param());
+}
+
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& conv = static_cast<const RegionRestrictedConvolution&>(def);
+    OperatorNodeConfig config{conv.make_name()};
+    return opr::RegionRestrictedConvolution::make(
+            inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), config);
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& region_restricted_conv =
+            def.cast_final_safe<mgb::imperative::RegionRestrictedConvolution>();
+    DnnOprHelper<megdnn::RegionRestrictedConvolutionForward> dnn_opr(
+            region_restricted_conv.param());
+    auto&& src = inputs[0].layout;
+    auto&& filter = inputs[1].layout;
+    auto&& rin = inputs[2].layout;
+    auto&& rout = inputs[3].layout;
+    TensorLayout output_layout{src.dtype};
+    if (src.ndim && filter.ndim) {
+        dnn_opr.opr().deduce_layout(src, filter, rin, rout, output_layout);
+    }
+    return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0};
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    // create megdnn opr
+    auto&& region_restricted_conv = def.cast_final_safe<RegionRestrictedConvolution>();
+    CompNode cn = inputs[0]->comp_node();
+
+    auto&& param = region_restricted_conv.param();
+    DnnOprCaller<megdnn::RegionRestrictedConvolutionForward> dnn_opr(cn, param);
+
+    auto srclo = inputs[0]->layout();
+    auto filterlo = inputs[1]->layout();
+    auto rinlo = inputs[2]->layout();
+    auto routlo = inputs[3]->layout();
+    auto out_layout = [&] {
+        if (validated) {
+            return output_descs[0].layout;
+        } else {
+            TensorLayout out_layout{inputs[0]->dtype()};
+            dnn_opr.op()->deduce_layout(srclo, filterlo, rinlo, routlo, out_layout);
+            return out_layout;
+        }
+    }();
+    auto out = Tensor::make(out_layout, cn);
+
+    dnn_opr.exec_with_ws(inputs[0], inputs[1], inputs[2], inputs[3], out);
+    return {out};
+}
+
+OP_TRAIT_REG(
+        RegionRestrictedConvolution, RegionRestrictedConvolution,
+        opr::RegionRestrictedConvolution)
+        .make_from_op_node(make_from_op_node)
+        .apply_on_var_node(apply_on_var_node)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .fallback();
+}  // namespace region_restricted_conv
+}  // namespace
+
+namespace {
+namespace region_restricted_conv_backward_data {
+std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
+    auto* node =
+            &node_->cast_final_safe<opr::RegionRestrictedConvolutionBackwardData>();
+    return RegionRestrictedConvolutionBackwardData::make(node->param());
+}
+
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& conv = static_cast<const RegionRestrictedConvolutionBackwardData&>(def);
+    OperatorNodeConfig config{conv.make_name()};
+    // output_dtype may be inferred from the inputs within rrconv bwd data (deduce_dtype api)
+    CompNode cn = inputs[0]->comp_node();
+    DType output_dtype;
+    DnnOprCaller<megdnn::RegionRestrictedConvolutionBackwardData> dnn_opr(cn);
+    dnn_opr.op()->deduce_dtype(
+            inputs[0]->dtype(), inputs[1]->dtype(), inputs[2]->dtype(),
+            inputs[3]->dtype(), output_dtype);
+    if (output_dtype.valid())
+        config.output_dtype(output_dtype);
+
+    if (inputs.size() == 4) {
+        return opr::RegionRestrictedConvolutionBackwardData::make(
+                inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), config);
+    } else if (inputs.size() == 5) {
+        return opr::RegionRestrictedConvolutionBackwardData::make(
+                inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], conv.param(),
+                config);
+    }
+    mgb_assert(0);
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& convbwd = def.cast_final_safe<
+            mgb::imperative::RegionRestrictedConvolutionBackwardData>();
+    DnnOprHelper<megdnn::RegionRestrictedConvolutionBackwardData> dnn_opr(
+            convbwd.param());
+    TensorLayout filter = inputs[0].layout;
+    TensorLayout diff = inputs[1].layout;
+    TensorLayout rin = inputs[2].layout;
+    TensorLayout rout = inputs[3].layout;
+    DType output_dtype;
+    dnn_opr.opr().deduce_dtype(
+            inputs[0].layout.dtype, inputs[1].layout.dtype, inputs[2].layout.dtype,
+            inputs[3].layout.dtype, output_dtype);
+    TensorLayout output_layout{output_dtype};
+    if (diff.ndim && filter.ndim) {
+        dnn_opr.opr().deduce_layout(filter, diff, rin, rout, output_layout);
+    }
+    return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0};
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    auto&& convbwd = def.cast_final_safe<RegionRestrictedConvolutionBackwardData>();
+    CompNode cn = inputs[0]->comp_node();
+
+    DnnOprCaller<megdnn::RegionRestrictedConvolutionBackwardData> dnn_opr(
+            cn, convbwd.param());
+    auto filterlo = inputs[0]->layout();
+    auto difflo = inputs[1]->layout();
+    auto rinlo = inputs[2]->layout();
+    auto routlo = inputs[3]->layout();
+    auto out_layout = [&] {
+        if (validated) {
+            return output_descs[0].layout;
+        } else {
+            TensorLayout out_layout{inputs[0]->dtype()};
+            dnn_opr.op()->deduce_layout(filterlo, difflo, rinlo, routlo, out_layout);
+            return out_layout;
+        }
+    }();
+    auto out = Tensor::make(out_layout, cn);
+
+    dnn_opr.exec_with_ws(inputs[0], inputs[1], inputs[2], inputs[3], out);
+    return {out};
+}
+
+OP_TRAIT_REG(
+        RegionRestrictedConvolutionBackwardData,
+        RegionRestrictedConvolutionBackwardData,
+        opr::RegionRestrictedConvolutionBackwardData)
+        .make_from_op_node(make_from_op_node)
+        .apply_on_var_node(apply_on_var_node)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .fallback();
+}  // namespace region_restricted_conv_backward_data
+}  // namespace
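
Reviewer note (not part of the patch): for intuition about what BackwardData computes, a hedged numpy sketch of the data gradient for the depthwise forward sketched earlier, under the same assumed mask semantics: scatter the masked filter back onto the input, weighted by the incoming gradient. With the all-ones masks and all-ones diff of the naive test, this reproduces grad(src) == filter values.

    import numpy as np

    def rrconv_dgrad_ref(flt, diff, rin, rout):
        # gradient w.r.t. src for the depthwise forward sketch (assumed semantics)
        n_, g_, oh, ow = diff.shape
        _, _, _, fh, fw = flt.shape
        ih, iw = oh + fh - 1, ow + fw - 1
        grad = np.zeros((n_, g_, ih, iw), dtype=diff.dtype)
        for n in range(n_):
            for g in range(g_):
                for y in range(oh):
                    for x in range(ow):
                        mask = rin[n, y:y + fh, x:x + fw] == rout[n, y, x]
                        grad[n, g, y:y + fh, x:x + fw] += (
                            diff[n, g, y, x] * flt[g, 0, 0] * mask
                        )
        return grad
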
 }  // namespace imperative
 }  // namespace mgb
@@ -1,7 +1,7 @@
 905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py
-40708c56b1f05fdb7d06cc097a300330 ../../src/core/include/megbrain/ir/ops.td
-9f3af118c7fe8d0c9db433825d5ad77b generated/opdef.h.inl
-4041e44a8ba3cca3b3affa1ed9ed44a2 generated/opdef.cpp.inl
-319e1d170c989fe793a4e9c45decefc4 generated/opdef.py.inl
-26a18a7593566128ecce76e8f74dcc5d generated/opdef.cpy.inl
+da03ffe2a15411f902cd88920d3d47ec ../../src/core/include/megbrain/ir/ops.td
+5756619f37e4dc130e1b049d7706d4eb generated/opdef.h.inl
+98d1291eed73970ee087f898b6241358 generated/opdef.cpp.inl
+b1a9c7569392942294c2168d40939eb5 generated/opdef.py.inl
+3d88d5358d15a39219957f5257e32f5b generated/opdef.cpy.inl
 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
@@ -5694,6 +5694,310 @@ OP_TRAIT_REG(Reduce, Reduce)
         .props(Reduce_props_impl)
         .make_name(Reduce_make_name_impl);
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(RegionRestrictedConvolution);
+
+namespace {
+size_t RegionRestrictedConvolution_hash_impl(const OpDef& def_) {
+    auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolution>();
+    static_cast<void>(op_);
+    size_t val = mgb::hash(op_.dyn_typeinfo());
+    val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.mode));
+    val = mgb::hash_pair_combine(val, mgb::hash(op_.pad_h));
+    val = mgb::hash_pair_combine(val, mgb::hash(op_.pad_w));
+    val = mgb::hash_pair_combine(val, mgb::hash(op_.stride_h));
+    val = mgb::hash_pair_combine(val, mgb::hash(op_.stride_w));
+    val = mgb::hash_pair_combine(val, mgb::hash(op_.dilate_h));
+    val = mgb::hash_pair_combine(val, mgb::hash(op_.dilate_w));
+    val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.sparse));
+    val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format));
+    val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.compute_mode));
+    return val;
+}
+
+bool RegionRestrictedConvolution_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
+    auto &&a_ = lhs_.cast_final_safe<RegionRestrictedConvolution>(),
+         &&b_ = rhs_.cast_final_safe<RegionRestrictedConvolution>();
+    static_cast<void>(a_);
+    static_cast<void>(b_);
+    if (a_.mode != b_.mode) return false;
+    if (a_.pad_h != b_.pad_h) return false;
+    if (a_.pad_w != b_.pad_w) return false;
+    if (a_.stride_h != b_.stride_h) return false;
+    if (a_.stride_w != b_.stride_w) return false;
+    if (a_.dilate_h != b_.dilate_h) return false;
+    if (a_.dilate_w != b_.dilate_w) return false;
+    if (a_.sparse != b_.sparse) return false;
+    if (a_.format != b_.format) return false;
+    if (a_.compute_mode != b_.compute_mode) return false;
+    return true;
+}
+
+std::vector<std::pair<const char*, std::string>> RegionRestrictedConvolution_props_impl(const OpDef& def_) {
+    auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolution>();
+    static_cast<void>(op_);
+    std::vector<std::pair<const char*, std::string>> props_;
+    switch (op_.mode) {
+    case RegionRestrictedConvolution::Mode::CROSS_CORRELATION:
+        props_.emplace_back("mode", "CROSS_CORRELATION");
+        break;
+    case RegionRestrictedConvolution::Mode::CONVOLUTION:
+        props_.emplace_back("mode", "CONVOLUTION");
+        break;
+    default:
+        props_.emplace_back("mode", "INVALID");
+        break;
+    }
+    props_.emplace_back("pad_h", std::to_string(op_.pad_h));
+    props_.emplace_back("pad_w", std::to_string(op_.pad_w));
+    props_.emplace_back("stride_h", std::to_string(op_.stride_h));
+    props_.emplace_back("stride_w", std::to_string(op_.stride_w));
+    props_.emplace_back("dilate_h", std::to_string(op_.dilate_h));
+    props_.emplace_back("dilate_w", std::to_string(op_.dilate_w));
+    switch (op_.sparse) {
+    case RegionRestrictedConvolution::Sparse::DENSE:
+        props_.emplace_back("sparse", "DENSE");
+        break;
+    case RegionRestrictedConvolution::Sparse::GROUP:
+        props_.emplace_back("sparse", "GROUP");
+        break;
+    default:
+        props_.emplace_back("sparse", "INVALID");
+        break;
+    }
+    switch (op_.format) {
+    case RegionRestrictedConvolution::Format::NCHW:
+        props_.emplace_back("format", "NCHW");
+        break;
+    case RegionRestrictedConvolution::Format::NHWC:
+        props_.emplace_back("format", "NHWC");
+        break;
+    case RegionRestrictedConvolution::Format::NHWCD4:
+        props_.emplace_back("format", "NHWCD4");
+        break;
+    case RegionRestrictedConvolution::Format::NCHW4:
+        props_.emplace_back("format", "NCHW4");
+        break;
+    case RegionRestrictedConvolution::Format::NCHW8:
+        props_.emplace_back("format", "NCHW8");
+        break;
+    case RegionRestrictedConvolution::Format::NCHW32:
+        props_.emplace_back("format", "NCHW32");
+        break;
+    case RegionRestrictedConvolution::Format::NCHW88:
+        props_.emplace_back("format", "NCHW88");
+        break;
+    case RegionRestrictedConvolution::Format::NCHW44:
+        props_.emplace_back("format", "NCHW44");
+        break;
+    case RegionRestrictedConvolution::Format::NCHW44_DOT:
+        props_.emplace_back("format", "NCHW44_DOT");
+        break;
+    case RegionRestrictedConvolution::Format::NCHW4_NCHW32:
+        props_.emplace_back("format", "NCHW4_NCHW32");
+        break;
+    case RegionRestrictedConvolution::Format::NCHW32_NCHW4:
+        props_.emplace_back("format", "NCHW32_NCHW4");
+        break;
+    case RegionRestrictedConvolution::Format::NCHW4_NCHW:
+        props_.emplace_back("format", "NCHW4_NCHW");
+        break;
+    case RegionRestrictedConvolution::Format::NHWC_NCHW:
+        props_.emplace_back("format", "NHWC_NCHW");
+        break;
+    case RegionRestrictedConvolution::Format::NHWC_NCHW4_IC_SMALL:
+        props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL");
+        break;
+    case RegionRestrictedConvolution::Format::NCHW_NCHW4_IC_SMALL:
+        props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL");
+        break;
+    case RegionRestrictedConvolution::Format::CHWN4:
+        props_.emplace_back("format", "CHWN4");
+        break;
+    case RegionRestrictedConvolution::Format::NCHW64:
+        props_.emplace_back("format", "NCHW64");
+        break;
+    case RegionRestrictedConvolution::Format::NCHW4_NHWC:
+        props_.emplace_back("format", "NCHW4_NHWC");
+        break;
+    default:
+        props_.emplace_back("format", "INVALID");
+        break;
+    }
+    switch (op_.compute_mode) {
+    case RegionRestrictedConvolution::ComputeMode::DEFAULT:
+        props_.emplace_back("compute_mode", "DEFAULT");
+        break;
+    case RegionRestrictedConvolution::ComputeMode::FLOAT32:
+        props_.emplace_back("compute_mode", "FLOAT32");
+        break;
+    default:
+        props_.emplace_back("compute_mode", "INVALID");
+        break;
+    }
+    return props_;
+}
+
+std::string RegionRestrictedConvolution_make_name_impl(const OpDef& def_) {
+    auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolution>();
+    static_cast<void>(op_);
+    return "RegionRestrictedConvolution";
+}
+}  // anonymous namespace
+
+OP_TRAIT_REG(RegionRestrictedConvolution, RegionRestrictedConvolution)
+    .hash(RegionRestrictedConvolution_hash_impl)
+    .is_same_st(RegionRestrictedConvolution_is_same_st_impl)
+    .props(RegionRestrictedConvolution_props_impl)
+    .make_name(RegionRestrictedConvolution_make_name_impl);
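
Reviewer note (not part of the patch): the generated hash/is_same_st/props impls above are what make the op hashable, comparable, and printable from Python. A small sketch of poking at the builtin from the Python side; the module path and attribute names come from the py_init kwlist and getsetters in the generated bindings further down:

    from megengine.core._imperative_rt.ops import RegionRestrictedConvolution

    op = RegionRestrictedConvolution(stride_h=1, stride_w=1, pad_h=0, pad_w=0)
    print(op.stride_h, op.sparse)  # fields exposed via the generated getsetters
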
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(RegionRestrictedConvolutionBackwardData);
+
+namespace {
+size_t RegionRestrictedConvolutionBackwardData_hash_impl(const OpDef& def_) {
+    auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolutionBackwardData>();
+    static_cast<void>(op_);
+    size_t val = mgb::hash(op_.dyn_typeinfo());
+    val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.mode));
+    val = mgb::hash_pair_combine(val, mgb::hash(op_.pad_h));
+    val = mgb::hash_pair_combine(val, mgb::hash(op_.pad_w));
+    val = mgb::hash_pair_combine(val, mgb::hash(op_.stride_h));
+    val = mgb::hash_pair_combine(val, mgb::hash(op_.stride_w));
+    val = mgb::hash_pair_combine(val, mgb::hash(op_.dilate_h));
+    val = mgb::hash_pair_combine(val, mgb::hash(op_.dilate_w));
+    val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.sparse));
+    val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format));
+    val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.compute_mode));
+    return val;
+}
+
+bool RegionRestrictedConvolutionBackwardData_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
+    auto &&a_ = lhs_.cast_final_safe<RegionRestrictedConvolutionBackwardData>(),
+         &&b_ = rhs_.cast_final_safe<RegionRestrictedConvolutionBackwardData>();
+    static_cast<void>(a_);
+    static_cast<void>(b_);
+    if (a_.mode != b_.mode) return false;
+    if (a_.pad_h != b_.pad_h) return false;
+    if (a_.pad_w != b_.pad_w) return false;
+    if (a_.stride_h != b_.stride_h) return false;
+    if (a_.stride_w != b_.stride_w) return false;
+    if (a_.dilate_h != b_.dilate_h) return false;
+    if (a_.dilate_w != b_.dilate_w) return false;
+    if (a_.sparse != b_.sparse) return false;
+    if (a_.format != b_.format) return false;
+    if (a_.compute_mode != b_.compute_mode) return false;
+    return true;
+}
+
+std::vector<std::pair<const char*, std::string>> RegionRestrictedConvolutionBackwardData_props_impl(const OpDef& def_) {
+    auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolutionBackwardData>();
+    static_cast<void>(op_);
+    std::vector<std::pair<const char*, std::string>> props_;
+    switch (op_.mode) {
+    case RegionRestrictedConvolutionBackwardData::Mode::CROSS_CORRELATION:
+        props_.emplace_back("mode", "CROSS_CORRELATION");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Mode::CONVOLUTION:
+        props_.emplace_back("mode", "CONVOLUTION");
+        break;
+    default:
+        props_.emplace_back("mode", "INVALID");
+        break;
+    }
+    props_.emplace_back("pad_h", std::to_string(op_.pad_h));
+    props_.emplace_back("pad_w", std::to_string(op_.pad_w));
+    props_.emplace_back("stride_h", std::to_string(op_.stride_h));
+    props_.emplace_back("stride_w", std::to_string(op_.stride_w));
+    props_.emplace_back("dilate_h", std::to_string(op_.dilate_h));
+    props_.emplace_back("dilate_w", std::to_string(op_.dilate_w));
+    switch (op_.sparse) {
+    case RegionRestrictedConvolutionBackwardData::Sparse::DENSE:
+        props_.emplace_back("sparse", "DENSE");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Sparse::GROUP:
+        props_.emplace_back("sparse", "GROUP");
+        break;
+    default:
+        props_.emplace_back("sparse", "INVALID");
+        break;
+    }
+    switch (op_.format) {
+    case RegionRestrictedConvolutionBackwardData::Format::NCHW:
+        props_.emplace_back("format", "NCHW");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NHWC:
+        props_.emplace_back("format", "NHWC");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NHWCD4:
+        props_.emplace_back("format", "NHWCD4");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NCHW4:
+        props_.emplace_back("format", "NCHW4");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NCHW8:
+        props_.emplace_back("format", "NCHW8");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NCHW32:
+        props_.emplace_back("format", "NCHW32");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NCHW88:
+        props_.emplace_back("format", "NCHW88");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NCHW44:
+        props_.emplace_back("format", "NCHW44");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NCHW44_DOT:
+        props_.emplace_back("format", "NCHW44_DOT");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NCHW4_NCHW32:
+        props_.emplace_back("format", "NCHW4_NCHW32");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NCHW32_NCHW4:
+        props_.emplace_back("format", "NCHW32_NCHW4");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NCHW4_NCHW:
+        props_.emplace_back("format", "NCHW4_NCHW");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NHWC_NCHW:
+        props_.emplace_back("format", "NHWC_NCHW");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NHWC_NCHW4_IC_SMALL:
+        props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NCHW_NCHW4_IC_SMALL:
+        props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::CHWN4:
+        props_.emplace_back("format", "CHWN4");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NCHW64:
+        props_.emplace_back("format", "NCHW64");
+        break;
+    case RegionRestrictedConvolutionBackwardData::Format::NCHW4_NHWC:
+        props_.emplace_back("format", "NCHW4_NHWC");
+        break;
+    default:
+        props_.emplace_back("format", "INVALID");
+        break;
+    }
+    switch (op_.compute_mode) {
+    case RegionRestrictedConvolutionBackwardData::ComputeMode::DEFAULT:
+        props_.emplace_back("compute_mode", "DEFAULT");
+        break;
+    case RegionRestrictedConvolutionBackwardData::ComputeMode::FLOAT32:
+        props_.emplace_back("compute_mode", "FLOAT32");
+        break;
+    default:
+        props_.emplace_back("compute_mode", "INVALID");
+        break;
+    }
+    return props_;
+}
+
+std::string RegionRestrictedConvolutionBackwardData_make_name_impl(const OpDef& def_) {
+    auto&& op_ = def_.cast_final_safe<RegionRestrictedConvolutionBackwardData>();
+    static_cast<void>(op_);
+    return "RegionRestrictedConvolutionBackwardData";
+}
+}  // anonymous namespace
+
+OP_TRAIT_REG(RegionRestrictedConvolutionBackwardData, RegionRestrictedConvolutionBackwardData)
+    .hash(RegionRestrictedConvolutionBackwardData_hash_impl)
+    .is_same_st(RegionRestrictedConvolutionBackwardData_is_same_st_impl)
+    .props(RegionRestrictedConvolutionBackwardData_props_impl)
+    .make_name(RegionRestrictedConvolutionBackwardData_make_name_impl);
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(Remap);
 namespace {
@@ -15368,6 +15368,580 @@ void _init_py_Reduce(py::module m) {
     mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Reduce::typeinfo(), &py_type).second);
 }
| void _init_py_RegionRestrictedConvolution_Mode(PyTypeObject& py_type) { | |||||
| auto& e_type = EnumWrapper<RegionRestrictedConvolution::Mode>::type; | |||||
| Py_INCREF(e_type); | |||||
| mgb_assert(PyDict_SetItemString( | |||||
| py_type.tp_dict, "Mode", reinterpret_cast<PyObject*>(e_type)) >= 0); | |||||
| } | |||||
| void _init_py_RegionRestrictedConvolution_Sparse(PyTypeObject& py_type) { | |||||
| auto& e_type = EnumWrapper<RegionRestrictedConvolution::Sparse>::type; | |||||
| Py_INCREF(e_type); | |||||
| mgb_assert(PyDict_SetItemString( | |||||
| py_type.tp_dict, "Sparse", reinterpret_cast<PyObject*>(e_type)) >= 0); | |||||
| } | |||||
| void _init_py_RegionRestrictedConvolution_Format(PyTypeObject& py_type) { | |||||
| auto& e_type = EnumWrapper<RegionRestrictedConvolution::Format>::type; | |||||
| Py_INCREF(e_type); | |||||
| mgb_assert(PyDict_SetItemString( | |||||
| py_type.tp_dict, "Format", reinterpret_cast<PyObject*>(e_type)) >= 0); | |||||
| } | |||||
| void _init_py_RegionRestrictedConvolution_ComputeMode(PyTypeObject& py_type) { | |||||
| auto& e_type = EnumWrapper<RegionRestrictedConvolution::ComputeMode>::type; | |||||
| Py_INCREF(e_type); | |||||
| mgb_assert(PyDict_SetItemString( | |||||
| py_type.tp_dict, "ComputeMode", reinterpret_cast<PyObject*>(e_type)) >= 0); | |||||
| } | |||||
| PyOpDefBegin(RegionRestrictedConvolution) // { | |||||
| static PyGetSetDef py_getsetters[]; | |||||
| static PyMethodDef tp_methods[]; | |||||
| static PyObject* getstate(PyObject* self, PyObject*) { | |||||
| auto& opdef = reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst(); | |||||
| static_cast<void>(opdef); | |||||
| std::unordered_map<std::string, py::object> state { | |||||
| {"mode", serialization<decltype(opdef.mode)>::dump(opdef.mode)}, | |||||
| {"pad_h", serialization<decltype(opdef.pad_h)>::dump(opdef.pad_h)}, | |||||
| {"pad_w", serialization<decltype(opdef.pad_w)>::dump(opdef.pad_w)}, | |||||
| {"stride_h", serialization<decltype(opdef.stride_h)>::dump(opdef.stride_h)}, | |||||
| {"stride_w", serialization<decltype(opdef.stride_w)>::dump(opdef.stride_w)}, | |||||
| {"dilate_h", serialization<decltype(opdef.dilate_h)>::dump(opdef.dilate_h)}, | |||||
| {"dilate_w", serialization<decltype(opdef.dilate_w)>::dump(opdef.dilate_w)}, | |||||
| {"sparse", serialization<decltype(opdef.sparse)>::dump(opdef.sparse)}, | |||||
| {"format", serialization<decltype(opdef.format)>::dump(opdef.format)}, | |||||
| {"compute_mode", serialization<decltype(opdef.compute_mode)>::dump(opdef.compute_mode)} | |||||
| }; | |||||
| return py::cast(state).release().ptr(); | |||||
| } | |||||
| static PyObject* setstate(PyObject* self, PyObject* args) { | |||||
| PyObject* dict = PyTuple_GetItem(args, 0); | |||||
| if (!dict) return NULL; | |||||
| auto state = py::cast<std::unordered_map<std::string, py::object>>(dict); | |||||
| auto& opdef = reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst(); | |||||
| static_cast<void>(opdef); | |||||
| { | |||||
| auto&& iter = state.find("mode"); | |||||
| if (iter != state.end()) { | |||||
| opdef.mode = serialization<decltype(opdef.mode)>::load(iter->second); | |||||
| } | |||||
| } | |||||
| { | |||||
| auto&& iter = state.find("pad_h"); | |||||
| if (iter != state.end()) { | |||||
| opdef.pad_h = serialization<decltype(opdef.pad_h)>::load(iter->second); | |||||
| } | |||||
| } | |||||
| { | |||||
| auto&& iter = state.find("pad_w"); | |||||
| if (iter != state.end()) { | |||||
| opdef.pad_w = serialization<decltype(opdef.pad_w)>::load(iter->second); | |||||
| } | |||||
| } | |||||
| { | |||||
| auto&& iter = state.find("stride_h"); | |||||
| if (iter != state.end()) { | |||||
| opdef.stride_h = serialization<decltype(opdef.stride_h)>::load(iter->second); | |||||
| } | |||||
| } | |||||
| { | |||||
| auto&& iter = state.find("stride_w"); | |||||
| if (iter != state.end()) { | |||||
| opdef.stride_w = serialization<decltype(opdef.stride_w)>::load(iter->second); | |||||
| } | |||||
| } | |||||
| { | |||||
| auto&& iter = state.find("dilate_h"); | |||||
| if (iter != state.end()) { | |||||
| opdef.dilate_h = serialization<decltype(opdef.dilate_h)>::load(iter->second); | |||||
| } | |||||
| } | |||||
| { | |||||
| auto&& iter = state.find("dilate_w"); | |||||
| if (iter != state.end()) { | |||||
| opdef.dilate_w = serialization<decltype(opdef.dilate_w)>::load(iter->second); | |||||
| } | |||||
| } | |||||
| { | |||||
| auto&& iter = state.find("sparse"); | |||||
| if (iter != state.end()) { | |||||
| opdef.sparse = serialization<decltype(opdef.sparse)>::load(iter->second); | |||||
| } | |||||
| } | |||||
| { | |||||
| auto&& iter = state.find("format"); | |||||
| if (iter != state.end()) { | |||||
| opdef.format = serialization<decltype(opdef.format)>::load(iter->second); | |||||
| } | |||||
| } | |||||
| { | |||||
| auto&& iter = state.find("compute_mode"); | |||||
| if (iter != state.end()) { | |||||
| opdef.compute_mode = serialization<decltype(opdef.compute_mode)>::load(iter->second); | |||||
| } | |||||
| } | |||||
| Py_RETURN_NONE; | |||||
| } | |||||
| static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||||
| // }; | |||||
| PyOpDefEnd(RegionRestrictedConvolution) | |||||
| int PyOp(RegionRestrictedConvolution)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { | |||||
| static const char* kwlist[] = {"mode", "pad_h", "pad_w", "stride_h", "stride_w", "dilate_h", "dilate_w", "sparse", "format", "compute_mode", "scope", NULL}; | |||||
| PyObject *mode = NULL, *pad_h = NULL, *pad_w = NULL, *stride_h = NULL, *stride_w = NULL, *dilate_h = NULL, *dilate_w = NULL, *sparse = NULL, *format = NULL, *compute_mode = NULL, *scope = NULL; | |||||
| if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOO", const_cast<char**>(kwlist), &mode, &pad_h, &pad_w, &stride_h, &stride_w, &dilate_h, &dilate_w, &sparse, &format, &compute_mode, &scope)) | |||||
| return -1; | |||||
| if (mode) { | |||||
| try { | |||||
| // TODO: remove this guard which is used for pybind11 implicit conversion | |||||
| py::detail::loader_life_support guard{}; | |||||
| reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().mode = | |||||
| py::cast<decltype(RegionRestrictedConvolution::mode)>(py::handle(mode)); | |||||
| } CATCH_ALL(-1) | |||||
| } | |||||
| if (pad_h) { | |||||
| try { | |||||
| // TODO: remove this guard which is used for pybind11 implicit conversion | |||||
| py::detail::loader_life_support guard{}; | |||||
| reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().pad_h = | |||||
| py::cast<decltype(RegionRestrictedConvolution::pad_h)>(py::handle(pad_h)); | |||||
| } CATCH_ALL(-1) | |||||
| } | |||||
| if (pad_w) { | |||||
| try { | |||||
| // TODO: remove this guard which is used for pybind11 implicit conversion | |||||
| py::detail::loader_life_support guard{}; | |||||
| reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().pad_w = | |||||
| py::cast<decltype(RegionRestrictedConvolution::pad_w)>(py::handle(pad_w)); | |||||
| } CATCH_ALL(-1) | |||||
| } | |||||
| if (stride_h) { | |||||
| try { | |||||
| // TODO: remove this guard which is used for pybind11 implicit conversion | |||||
| py::detail::loader_life_support guard{}; | |||||
| reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().stride_h = | |||||
| py::cast<decltype(RegionRestrictedConvolution::stride_h)>(py::handle(stride_h)); | |||||
| } CATCH_ALL(-1) | |||||
| } | |||||
| if (stride_w) { | |||||
| try { | |||||
| // TODO: remove this guard which is used for pybind11 implicit conversion | |||||
| py::detail::loader_life_support guard{}; | |||||
| reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().stride_w = | |||||
| py::cast<decltype(RegionRestrictedConvolution::stride_w)>(py::handle(stride_w)); | |||||
| } CATCH_ALL(-1) | |||||
| } | |||||
| if (dilate_h) { | |||||
| try { | |||||
| // TODO: remove this guard which is used for pybind11 implicit conversion | |||||
| py::detail::loader_life_support guard{}; | |||||
| reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().dilate_h = | |||||
| py::cast<decltype(RegionRestrictedConvolution::dilate_h)>(py::handle(dilate_h)); | |||||
| } CATCH_ALL(-1) | |||||
| } | |||||
| if (dilate_w) { | |||||
| try { | |||||
| // TODO: remove this guard which is used for pybind11 implicit conversion | |||||
| py::detail::loader_life_support guard{}; | |||||
| reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().dilate_w = | |||||
| py::cast<decltype(RegionRestrictedConvolution::dilate_w)>(py::handle(dilate_w)); | |||||
| } CATCH_ALL(-1) | |||||
| } | |||||
| if (sparse) { | |||||
| try { | |||||
| // TODO: remove this guard which is used for pybind11 implicit conversion | |||||
| py::detail::loader_life_support guard{}; | |||||
| reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().sparse = | |||||
| py::cast<decltype(RegionRestrictedConvolution::sparse)>(py::handle(sparse)); | |||||
| } CATCH_ALL(-1) | |||||
| } | |||||
| if (format) { | |||||
| try { | |||||
| // TODO: remove this guard which is used for pybind11 implicit conversion | |||||
| py::detail::loader_life_support guard{}; | |||||
| reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().format = | |||||
| py::cast<decltype(RegionRestrictedConvolution::format)>(py::handle(format)); | |||||
| } CATCH_ALL(-1) | |||||
| } | |||||
| if (compute_mode) { | |||||
| try { | |||||
| // TODO: remove this guard which is used for pybind11 implicit conversion | |||||
| py::detail::loader_life_support guard{}; | |||||
| reinterpret_cast<PyOp(RegionRestrictedConvolution)*>(self)->inst().compute_mode = | |||||
| py::cast<decltype(RegionRestrictedConvolution::compute_mode)>(py::handle(compute_mode)); | |||||
| } CATCH_ALL(-1) | |||||
| } | |||||
| if (scope) { | |||||
| try { | |||||
| reinterpret_cast<PyOp(OpDef)*>(self)->op | |||||
| ->set_scope(py::cast<std::string>(py::handle(scope))); | |||||
| } CATCH_ALL(-1) | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| PyGetSetDef PyOp(RegionRestrictedConvolution)::py_getsetters[] = { | |||||
| {const_cast<char*>("mode"), py_get_generic(RegionRestrictedConvolution, mode), py_set_generic(RegionRestrictedConvolution, mode), const_cast<char*>("mode"), NULL}, | |||||
| {const_cast<char*>("pad_h"), py_get_generic(RegionRestrictedConvolution, pad_h), py_set_generic(RegionRestrictedConvolution, pad_h), const_cast<char*>("pad_h"), NULL}, | |||||
| {const_cast<char*>("pad_w"), py_get_generic(RegionRestrictedConvolution, pad_w), py_set_generic(RegionRestrictedConvolution, pad_w), const_cast<char*>("pad_w"), NULL}, | |||||
| {const_cast<char*>("stride_h"), py_get_generic(RegionRestrictedConvolution, stride_h), py_set_generic(RegionRestrictedConvolution, stride_h), const_cast<char*>("stride_h"), NULL}, | |||||
| {const_cast<char*>("stride_w"), py_get_generic(RegionRestrictedConvolution, stride_w), py_set_generic(RegionRestrictedConvolution, stride_w), const_cast<char*>("stride_w"), NULL}, | |||||
| {const_cast<char*>("dilate_h"), py_get_generic(RegionRestrictedConvolution, dilate_h), py_set_generic(RegionRestrictedConvolution, dilate_h), const_cast<char*>("dilate_h"), NULL}, | |||||
| {const_cast<char*>("dilate_w"), py_get_generic(RegionRestrictedConvolution, dilate_w), py_set_generic(RegionRestrictedConvolution, dilate_w), const_cast<char*>("dilate_w"), NULL}, | |||||
| {const_cast<char*>("sparse"), py_get_generic(RegionRestrictedConvolution, sparse), py_set_generic(RegionRestrictedConvolution, sparse), const_cast<char*>("sparse"), NULL}, | |||||
| {const_cast<char*>("format"), py_get_generic(RegionRestrictedConvolution, format), py_set_generic(RegionRestrictedConvolution, format), const_cast<char*>("format"), NULL}, | |||||
| {const_cast<char*>("compute_mode"), py_get_generic(RegionRestrictedConvolution, compute_mode), py_set_generic(RegionRestrictedConvolution, compute_mode), const_cast<char*>("compute_mode"), NULL}, | |||||
| {NULL} /* Sentinel */ | |||||
| }; | |||||
| PyMethodDef PyOp(RegionRestrictedConvolution)::tp_methods[] = { | |||||
| {const_cast<char*>("__getstate__"), PyOp(RegionRestrictedConvolution)::getstate, METH_NOARGS, "RegionRestrictedConvolution getstate"}, | |||||
| {const_cast<char*>("__setstate__"), PyOp(RegionRestrictedConvolution)::setstate, METH_VARARGS, "RegionRestrictedConvolution setstate"}, | |||||
| {NULL} /* Sentinel */ | |||||
| }; | |||||
| void _init_py_RegionRestrictedConvolution(py::module m) { | |||||
| using py_op = PyOp(RegionRestrictedConvolution); | |||||
| auto& py_type = PyOpType(RegionRestrictedConvolution); | |||||
| py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||||
| py_type.tp_name = "megengine.core._imperative_rt.ops.RegionRestrictedConvolution"; | |||||
| py_type.tp_basicsize = sizeof(PyOp(RegionRestrictedConvolution)); | |||||
| py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||||
| py_type.tp_doc = "RegionRestrictedConvolution"; | |||||
| py_type.tp_base = &PyOpType(OpDef); | |||||
| py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||||
| py_type.tp_new = py_new_generic<py_op>; | |||||
| py_type.tp_init = py_op::py_init; | |||||
| py_type.tp_methods = py_op::tp_methods; | |||||
| py_type.tp_getset = py_op::py_getsetters; | |||||
| mgb_assert(PyType_Ready(&py_type) >= 0); | |||||
| _init_py_RegionRestrictedConvolution_Mode(py_type); | |||||
| _init_py_RegionRestrictedConvolution_Sparse(py_type); | |||||
| _init_py_RegionRestrictedConvolution_Format(py_type); | |||||
| _init_py_RegionRestrictedConvolution_ComputeMode(py_type); | |||||
| PyType_Modified(&py_type); | |||||
| m.add_object("RegionRestrictedConvolution", reinterpret_cast<PyObject*>(&py_type)); | |||||
| mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(RegionRestrictedConvolution::typeinfo(), &py_type).second); | |||||
| } | |||||
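// The block above follows the generated-binding pattern used for every opdef
// in this file: a getter/setter pair is exposed per param field,
// __getstate__/__setstate__ enable pickling, PyType_Ready finalizes the type,
// the _init_py_*_{Mode,Sparse,Format,ComputeMode} helpers attach the shared
// enum wrappers as class attributes, and the ctype2pytype entry maps the C++
// typeinfo back to this Python type so OpDef instances downcast correctly.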
void _init_py_RegionRestrictedConvolutionBackwardData_Mode(PyTypeObject& py_type) {
    auto& e_type = EnumWrapper<RegionRestrictedConvolutionBackwardData::Mode>::type;
    Py_INCREF(e_type);
    mgb_assert(PyDict_SetItemString(
            py_type.tp_dict, "Mode", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
void _init_py_RegionRestrictedConvolutionBackwardData_Sparse(PyTypeObject& py_type) {
    auto& e_type = EnumWrapper<RegionRestrictedConvolutionBackwardData::Sparse>::type;
    Py_INCREF(e_type);
    mgb_assert(PyDict_SetItemString(
            py_type.tp_dict, "Sparse", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
void _init_py_RegionRestrictedConvolutionBackwardData_Format(PyTypeObject& py_type) {
    auto& e_type = EnumWrapper<RegionRestrictedConvolutionBackwardData::Format>::type;
    Py_INCREF(e_type);
    mgb_assert(PyDict_SetItemString(
            py_type.tp_dict, "Format", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
void _init_py_RegionRestrictedConvolutionBackwardData_ComputeMode(PyTypeObject& py_type) {
    auto& e_type = EnumWrapper<RegionRestrictedConvolutionBackwardData::ComputeMode>::type;
    Py_INCREF(e_type);
    mgb_assert(PyDict_SetItemString(
            py_type.tp_dict, "ComputeMode", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
PyOpDefBegin(RegionRestrictedConvolutionBackwardData) // {
    static PyGetSetDef py_getsetters[];
    static PyMethodDef tp_methods[];
    static PyObject* getstate(PyObject* self, PyObject*) {
        auto& opdef = reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst();
        static_cast<void>(opdef);
        std::unordered_map<std::string, py::object> state {
            {"mode", serialization<decltype(opdef.mode)>::dump(opdef.mode)},
            {"pad_h", serialization<decltype(opdef.pad_h)>::dump(opdef.pad_h)},
            {"pad_w", serialization<decltype(opdef.pad_w)>::dump(opdef.pad_w)},
            {"stride_h", serialization<decltype(opdef.stride_h)>::dump(opdef.stride_h)},
            {"stride_w", serialization<decltype(opdef.stride_w)>::dump(opdef.stride_w)},
            {"dilate_h", serialization<decltype(opdef.dilate_h)>::dump(opdef.dilate_h)},
            {"dilate_w", serialization<decltype(opdef.dilate_w)>::dump(opdef.dilate_w)},
            {"sparse", serialization<decltype(opdef.sparse)>::dump(opdef.sparse)},
            {"format", serialization<decltype(opdef.format)>::dump(opdef.format)},
            {"compute_mode", serialization<decltype(opdef.compute_mode)>::dump(opdef.compute_mode)}
        };
        return py::cast(state).release().ptr();
    }
    static PyObject* setstate(PyObject* self, PyObject* args) {
        PyObject* dict = PyTuple_GetItem(args, 0);
        if (!dict) return NULL;
        auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
        auto& opdef = reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst();
        static_cast<void>(opdef);
        {
            auto&& iter = state.find("mode");
            if (iter != state.end()) {
                opdef.mode = serialization<decltype(opdef.mode)>::load(iter->second);
            }
        }
        {
            auto&& iter = state.find("pad_h");
            if (iter != state.end()) {
                opdef.pad_h = serialization<decltype(opdef.pad_h)>::load(iter->second);
            }
        }
        {
            auto&& iter = state.find("pad_w");
            if (iter != state.end()) {
                opdef.pad_w = serialization<decltype(opdef.pad_w)>::load(iter->second);
            }
        }
        {
            auto&& iter = state.find("stride_h");
            if (iter != state.end()) {
                opdef.stride_h = serialization<decltype(opdef.stride_h)>::load(iter->second);
            }
        }
        {
            auto&& iter = state.find("stride_w");
            if (iter != state.end()) {
                opdef.stride_w = serialization<decltype(opdef.stride_w)>::load(iter->second);
            }
        }
        {
            auto&& iter = state.find("dilate_h");
            if (iter != state.end()) {
                opdef.dilate_h = serialization<decltype(opdef.dilate_h)>::load(iter->second);
            }
        }
        {
            auto&& iter = state.find("dilate_w");
            if (iter != state.end()) {
                opdef.dilate_w = serialization<decltype(opdef.dilate_w)>::load(iter->second);
            }
        }
        {
            auto&& iter = state.find("sparse");
            if (iter != state.end()) {
                opdef.sparse = serialization<decltype(opdef.sparse)>::load(iter->second);
            }
        }
        {
            auto&& iter = state.find("format");
            if (iter != state.end()) {
                opdef.format = serialization<decltype(opdef.format)>::load(iter->second);
            }
        }
        {
            auto&& iter = state.find("compute_mode");
            if (iter != state.end()) {
                opdef.compute_mode = serialization<decltype(opdef.compute_mode)>::load(iter->second);
            }
        }
        Py_RETURN_NONE;
    }
    static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd(RegionRestrictedConvolutionBackwardData)
int PyOp(RegionRestrictedConvolutionBackwardData)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
    static const char* kwlist[] = {"mode", "pad_h", "pad_w", "stride_h", "stride_w", "dilate_h", "dilate_w", "sparse", "format", "compute_mode", "scope", NULL};
    PyObject *mode = NULL, *pad_h = NULL, *pad_w = NULL, *stride_h = NULL, *stride_w = NULL, *dilate_h = NULL, *dilate_w = NULL, *sparse = NULL, *format = NULL, *compute_mode = NULL, *scope = NULL;
    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOO", const_cast<char**>(kwlist), &mode, &pad_h, &pad_w, &stride_h, &stride_w, &dilate_h, &dilate_w, &sparse, &format, &compute_mode, &scope))
        return -1;
    if (mode) {
        try {
            // TODO: remove this guard which is used for pybind11 implicit conversion
            py::detail::loader_life_support guard{};
            reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().mode =
                    py::cast<decltype(RegionRestrictedConvolutionBackwardData::mode)>(py::handle(mode));
        } CATCH_ALL(-1)
    }
    if (pad_h) {
        try {
            // TODO: remove this guard which is used for pybind11 implicit conversion
            py::detail::loader_life_support guard{};
            reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().pad_h =
                    py::cast<decltype(RegionRestrictedConvolutionBackwardData::pad_h)>(py::handle(pad_h));
        } CATCH_ALL(-1)
    }
    if (pad_w) {
        try {
            // TODO: remove this guard which is used for pybind11 implicit conversion
            py::detail::loader_life_support guard{};
            reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().pad_w =
                    py::cast<decltype(RegionRestrictedConvolutionBackwardData::pad_w)>(py::handle(pad_w));
        } CATCH_ALL(-1)
    }
    if (stride_h) {
        try {
            // TODO: remove this guard which is used for pybind11 implicit conversion
            py::detail::loader_life_support guard{};
            reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().stride_h =
                    py::cast<decltype(RegionRestrictedConvolutionBackwardData::stride_h)>(py::handle(stride_h));
        } CATCH_ALL(-1)
    }
    if (stride_w) {
        try {
            // TODO: remove this guard which is used for pybind11 implicit conversion
            py::detail::loader_life_support guard{};
            reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().stride_w =
                    py::cast<decltype(RegionRestrictedConvolutionBackwardData::stride_w)>(py::handle(stride_w));
        } CATCH_ALL(-1)
    }
    if (dilate_h) {
        try {
            // TODO: remove this guard which is used for pybind11 implicit conversion
            py::detail::loader_life_support guard{};
            reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().dilate_h =
                    py::cast<decltype(RegionRestrictedConvolutionBackwardData::dilate_h)>(py::handle(dilate_h));
        } CATCH_ALL(-1)
    }
    if (dilate_w) {
        try {
            // TODO: remove this guard which is used for pybind11 implicit conversion
            py::detail::loader_life_support guard{};
            reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().dilate_w =
                    py::cast<decltype(RegionRestrictedConvolutionBackwardData::dilate_w)>(py::handle(dilate_w));
        } CATCH_ALL(-1)
    }
    if (sparse) {
        try {
            // TODO: remove this guard which is used for pybind11 implicit conversion
            py::detail::loader_life_support guard{};
            reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().sparse =
                    py::cast<decltype(RegionRestrictedConvolutionBackwardData::sparse)>(py::handle(sparse));
        } CATCH_ALL(-1)
    }
    if (format) {
        try {
            // TODO: remove this guard which is used for pybind11 implicit conversion
            py::detail::loader_life_support guard{};
            reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().format =
                    py::cast<decltype(RegionRestrictedConvolutionBackwardData::format)>(py::handle(format));
        } CATCH_ALL(-1)
    }
    if (compute_mode) {
        try {
            // TODO: remove this guard which is used for pybind11 implicit conversion
            py::detail::loader_life_support guard{};
            reinterpret_cast<PyOp(RegionRestrictedConvolutionBackwardData)*>(self)->inst().compute_mode =
                    py::cast<decltype(RegionRestrictedConvolutionBackwardData::compute_mode)>(py::handle(compute_mode));
        } CATCH_ALL(-1)
    }
    if (scope) {
        try {
            reinterpret_cast<PyOp(OpDef)*>(self)->op
                    ->set_scope(py::cast<std::string>(py::handle(scope)));
        } CATCH_ALL(-1)
    }
    return 0;
}
PyGetSetDef PyOp(RegionRestrictedConvolutionBackwardData)::py_getsetters[] = {
    {const_cast<char*>("mode"), py_get_generic(RegionRestrictedConvolutionBackwardData, mode), py_set_generic(RegionRestrictedConvolutionBackwardData, mode), const_cast<char*>("mode"), NULL},
    {const_cast<char*>("pad_h"), py_get_generic(RegionRestrictedConvolutionBackwardData, pad_h), py_set_generic(RegionRestrictedConvolutionBackwardData, pad_h), const_cast<char*>("pad_h"), NULL},
    {const_cast<char*>("pad_w"), py_get_generic(RegionRestrictedConvolutionBackwardData, pad_w), py_set_generic(RegionRestrictedConvolutionBackwardData, pad_w), const_cast<char*>("pad_w"), NULL},
    {const_cast<char*>("stride_h"), py_get_generic(RegionRestrictedConvolutionBackwardData, stride_h), py_set_generic(RegionRestrictedConvolutionBackwardData, stride_h), const_cast<char*>("stride_h"), NULL},
    {const_cast<char*>("stride_w"), py_get_generic(RegionRestrictedConvolutionBackwardData, stride_w), py_set_generic(RegionRestrictedConvolutionBackwardData, stride_w), const_cast<char*>("stride_w"), NULL},
    {const_cast<char*>("dilate_h"), py_get_generic(RegionRestrictedConvolutionBackwardData, dilate_h), py_set_generic(RegionRestrictedConvolutionBackwardData, dilate_h), const_cast<char*>("dilate_h"), NULL},
    {const_cast<char*>("dilate_w"), py_get_generic(RegionRestrictedConvolutionBackwardData, dilate_w), py_set_generic(RegionRestrictedConvolutionBackwardData, dilate_w), const_cast<char*>("dilate_w"), NULL},
    {const_cast<char*>("sparse"), py_get_generic(RegionRestrictedConvolutionBackwardData, sparse), py_set_generic(RegionRestrictedConvolutionBackwardData, sparse), const_cast<char*>("sparse"), NULL},
    {const_cast<char*>("format"), py_get_generic(RegionRestrictedConvolutionBackwardData, format), py_set_generic(RegionRestrictedConvolutionBackwardData, format), const_cast<char*>("format"), NULL},
    {const_cast<char*>("compute_mode"), py_get_generic(RegionRestrictedConvolutionBackwardData, compute_mode), py_set_generic(RegionRestrictedConvolutionBackwardData, compute_mode), const_cast<char*>("compute_mode"), NULL},
    {NULL} /* Sentinel */
};
PyMethodDef PyOp(RegionRestrictedConvolutionBackwardData)::tp_methods[] = {
    {const_cast<char*>("__getstate__"), PyOp(RegionRestrictedConvolutionBackwardData)::getstate, METH_NOARGS, "RegionRestrictedConvolutionBackwardData getstate"},
    {const_cast<char*>("__setstate__"), PyOp(RegionRestrictedConvolutionBackwardData)::setstate, METH_VARARGS, "RegionRestrictedConvolutionBackwardData setstate"},
    {NULL} /* Sentinel */
};
void _init_py_RegionRestrictedConvolutionBackwardData(py::module m) {
    using py_op = PyOp(RegionRestrictedConvolutionBackwardData);
    auto& py_type = PyOpType(RegionRestrictedConvolutionBackwardData);
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    py_type.tp_name = "megengine.core._imperative_rt.ops.RegionRestrictedConvolutionBackwardData";
    py_type.tp_basicsize = sizeof(PyOp(RegionRestrictedConvolutionBackwardData));
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_doc = "RegionRestrictedConvolutionBackwardData";
    py_type.tp_base = &PyOpType(OpDef);
    py_type.tp_dealloc = py_dealloc_generic<py_op>;
    py_type.tp_new = py_new_generic<py_op>;
    py_type.tp_init = py_op::py_init;
    py_type.tp_methods = py_op::tp_methods;
    py_type.tp_getset = py_op::py_getsetters;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    _init_py_RegionRestrictedConvolutionBackwardData_Mode(py_type);
    _init_py_RegionRestrictedConvolutionBackwardData_Sparse(py_type);
    _init_py_RegionRestrictedConvolutionBackwardData_Format(py_type);
    _init_py_RegionRestrictedConvolutionBackwardData_ComputeMode(py_type);
    PyType_Modified(&py_type);
    m.add_object("RegionRestrictedConvolutionBackwardData", reinterpret_cast<PyObject*>(&py_type));
    mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(RegionRestrictedConvolutionBackwardData::typeinfo(), &py_type).second);
}
template<> struct EnumTrait<Remap::InterpolationMode> {
    static constexpr const char *name = "Remap.InterpolationMode";
    static constexpr std::underlying_type_t<Remap::InterpolationMode> max = 5 - 1;
@@ -18700,6 +19274,8 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
    _init_py_ROIAlign(m); \
    _init_py_ROIPooling(m); \
    _init_py_Reduce(m); \
    _init_py_RegionRestrictedConvolution(m); \
    _init_py_RegionRestrictedConvolutionBackwardData(m); \
    _init_py_Remap(m); \
    _init_py_RemoteRecv(m); \
    _init_py_RemoteSend(m); \
@@ -1517,6 +1517,58 @@ public:
    }
};
class RegionRestrictedConvolution : public OpDefImplBase<RegionRestrictedConvolution> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using Mode = ::megdnn::param::Convolution::Mode;
    using Sparse = ::megdnn::param::Convolution::Sparse;
    using Format = ::megdnn::param::Convolution::Format;
    using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
    Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
    uint32_t pad_h = 0;
    uint32_t pad_w = 0;
    uint32_t stride_h = 1;
    uint32_t stride_w = 1;
    uint32_t dilate_h = 1;
    uint32_t dilate_w = 1;
    Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
    Format format = ::megdnn::param::Convolution::Format::NCHW;
    ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
    RegionRestrictedConvolution() = default;
    RegionRestrictedConvolution(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_) { set_scope(scope_); }
    RegionRestrictedConvolution(::megdnn::param::Convolution packed_param_0): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode) {}
    ::megdnn::param::Convolution param() const {
        return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
    }
};
class RegionRestrictedConvolutionBackwardData : public OpDefImplBase<RegionRestrictedConvolutionBackwardData> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using Mode = ::megdnn::param::Convolution::Mode;
    using Sparse = ::megdnn::param::Convolution::Sparse;
    using Format = ::megdnn::param::Convolution::Format;
    using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
    Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
    uint32_t pad_h = 0;
    uint32_t pad_w = 0;
    uint32_t stride_h = 1;
    uint32_t stride_w = 1;
    uint32_t dilate_h = 1;
    uint32_t dilate_w = 1;
    Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
    Format format = ::megdnn::param::Convolution::Format::NCHW;
    ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
    RegionRestrictedConvolutionBackwardData() = default;
    RegionRestrictedConvolutionBackwardData(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_) { set_scope(scope_); }
    RegionRestrictedConvolutionBackwardData(::megdnn::param::Convolution packed_param_0): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode) {}
    ::megdnn::param::Convolution param() const {
        return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
    }
};
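// Both opdefs above alias their enums from, and pack into, the shared
// ::megdnn::param::Convolution struct, so an instance survives a round trip
// through its packed param. A minimal sketch (illustrative; `a` is any
// configured instance):
//     RegionRestrictedConvolution b{a.param()};  // b's fields equal a's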
class Remap : public OpDefImplBase<Remap> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
@@ -1620,6 +1620,52 @@ ReduceInst
        .def_readwrite("data_type", &Reduce::data_type)
        .def_readwrite("keepdim", &Reduce::keepdim);
py::class_<RegionRestrictedConvolution, std::shared_ptr<RegionRestrictedConvolution>, OpDef> RegionRestrictedConvolutionInst(m, "RegionRestrictedConvolution");

RegionRestrictedConvolutionInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
RegionRestrictedConvolutionInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
RegionRestrictedConvolutionInst.attr("Format") = AdaptivePoolingInst.attr("Format");
RegionRestrictedConvolutionInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");

RegionRestrictedConvolutionInst
        .def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("scope") = {})
        .def_readwrite("mode", &RegionRestrictedConvolution::mode)
        .def_readwrite("pad_h", &RegionRestrictedConvolution::pad_h)
        .def_readwrite("pad_w", &RegionRestrictedConvolution::pad_w)
        .def_readwrite("stride_h", &RegionRestrictedConvolution::stride_h)
        .def_readwrite("stride_w", &RegionRestrictedConvolution::stride_w)
        .def_readwrite("dilate_h", &RegionRestrictedConvolution::dilate_h)
        .def_readwrite("dilate_w", &RegionRestrictedConvolution::dilate_w)
        .def_readwrite("sparse", &RegionRestrictedConvolution::sparse)
        .def_readwrite("format", &RegionRestrictedConvolution::format)
        .def_readwrite("compute_mode", &RegionRestrictedConvolution::compute_mode);

py::class_<RegionRestrictedConvolutionBackwardData, std::shared_ptr<RegionRestrictedConvolutionBackwardData>, OpDef> RegionRestrictedConvolutionBackwardDataInst(m, "RegionRestrictedConvolutionBackwardData");

RegionRestrictedConvolutionBackwardDataInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
RegionRestrictedConvolutionBackwardDataInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
RegionRestrictedConvolutionBackwardDataInst.attr("Format") = AdaptivePoolingInst.attr("Format");
RegionRestrictedConvolutionBackwardDataInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");

RegionRestrictedConvolutionBackwardDataInst
        .def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("scope") = {})
        .def_readwrite("mode", &RegionRestrictedConvolutionBackwardData::mode)
        .def_readwrite("pad_h", &RegionRestrictedConvolutionBackwardData::pad_h)
        .def_readwrite("pad_w", &RegionRestrictedConvolutionBackwardData::pad_w)
        .def_readwrite("stride_h", &RegionRestrictedConvolutionBackwardData::stride_h)
        .def_readwrite("stride_w", &RegionRestrictedConvolutionBackwardData::stride_w)
        .def_readwrite("dilate_h", &RegionRestrictedConvolutionBackwardData::dilate_h)
        .def_readwrite("dilate_w", &RegionRestrictedConvolutionBackwardData::dilate_w)
        .def_readwrite("sparse", &RegionRestrictedConvolutionBackwardData::sparse)
        .def_readwrite("format", &RegionRestrictedConvolutionBackwardData::format)
        .def_readwrite("compute_mode", &RegionRestrictedConvolutionBackwardData::compute_mode);
py::class_<Remap, std::shared_ptr<Remap>, OpDef> RemapInst(m, "Remap");
py::enum_<Remap::InterpolationMode>(RemapInst, "InterpolationMode")
@@ -520,4 +520,9 @@ def MeshGrid: MgbHashableOp<"MeshGrid"> {
    MgbStringAttr:$indexing
);
}

def RegionRestrictedConvolution: MgbHashableOp<"RegionRestrictedConvolution", [ConvolutionParam]>;

def RegionRestrictedConvolutionBackwardData: MgbHashableOp<"RegionRestrictedConvolutionBackwardData", [ConvolutionParam]>;

#endif // MGB_OPS
@@ -25,6 +25,58 @@ using namespace cg::static_infer;
using intl::WorkspaceLimitGetter;

/* ==================== misc impl ==================== */
template <typename MGBOPR, typename DNNOPR>
void mixin::RegionConvBackwardDataMixin::init_output_static_infer_desc_for_bwd_data(
        cg::OperatorNodeBase* self) {
    using namespace cg::static_infer;
    auto&& mgr = self->owner_graph()->static_infer_manager();
    DepVal inp_deps;
    inp_deps.reserve(6);
    for (int i = 0; i < 4; i++) {
        inp_deps.push_back({self->input(i), DepType::SHAPE});
    }
    auto infer_shp = [self](TensorShape& dest, const InpVal& inp) {
        TensorLayout ol{self->output(0)->dtype()};
        mgb_assert(
                self->input(0)->dtype().category() == DTypeCategory::FLOAT &&
                self->input(1)->dtype().category() == DTypeCategory::FLOAT &&
                self->input(2)->dtype().category() == DTypeCategory::INT &&
                self->input(3)->dtype().category() == DTypeCategory::INT,
                "region restricted conv bwd data expects float filter/diff "
                "and int rin/rout");
        static_cast<MGBOPR*>(self)->megdnn_opr()->deduce_layout(
                {inp.val.at(0).shape(), self->input(0)->dtype()},  // filter
                {inp.val.at(1).shape(), self->input(1)->dtype()},  // diff
                {inp.val.at(2).shape(), self->input(2)->dtype()},  // rin
                {inp.val.at(3).shape(), self->input(3)->dtype()},  // rout
                ol                                                 // grad
        );
        dest = ol;
        return true;
    };
    mgr.register_shape_infer(self->output(0), {SourceType::DEP, inp_deps, infer_shp});
    // workspace size
    auto infer_wk = [self](TensorShape& dest, const InpVal& inp) {
        TensorLayout ol{self->output(0)->dtype()};
        dest.ndim = 1;
        dest.shape[0] =
                static_cast<MGBOPR*>(self)->megdnn_opr()->get_workspace_in_bytes(
                        {self->input(0)->shape(), self->input(0)->dtype()},  // filter
                        {self->input(1)->shape(), self->input(1)->dtype()},  // diff
                        {self->input(2)->shape(), self->input(2)->dtype()},  // rin
                        {self->input(3)->shape(), self->input(3)->dtype()},  // rout
                        ol);
        return true;
    };
    inp_deps.push_back({self->output(0), DepType::SHAPE});
    auto workspace_dep_var =
            intl::WorkspaceLimitGetter::register_to_graph(self->owner_graph());
    if (workspace_dep_var)
        inp_deps.push_back({workspace_dep_var, DepType::VALUE});
    mgr.register_shape_infer(self->output(1), {SourceType::DEP, inp_deps, infer_wk});
}
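// The mixin registers two static-infer entries: output(0) (the grad) is
// deduced from the four input shapes via deduce_layout, while output(1) (the
// workspace, encoded as a 1-dim "shape" holding the byte count) additionally
// depends on the deduced grad shape and, when present, the graph-wide
// workspace-limit variable. A concrete instantiation, matching
// init_output_static_infer_desc further below:
//     init_output_static_infer_desc_for_bwd_data<
//             RegionRestrictedConvolutionBackwardData,
//             megdnn::RegionRestrictedConvolutionBackwardData>(this);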
template <class MgbOpr, class MegDNNOpr>
void mixin::ConvolutionBackwardDataMixin::init_output_static_infer_desc_for_bwd_data(
@@ -1535,6 +1587,226 @@ void BatchConvBiasForward::init_output_format() {
    output(0)->format(input(0)->format());
}
/* ========================== RegionRestrictedConvolutionForward
 * ========================== */
IMPL_CONV(RegionRestrictedConvolutionForward);

RegionRestrictedConvolutionForward::RegionRestrictedConvolutionForward(
        VarNode* src, VarNode* filter, VarNode* region_in, VarNode* region_out,
        const Param& param, const OperatorNodeConfig& config)
        : Super(src->owner_graph(), config, "region_restricted_conv_fwd",
                {src, filter, region_in, region_out}) {
    init_megdnn_opr(*this, param);
    add_input({src, filter, region_in, region_out});
}

SymbolVar RegionRestrictedConvolutionForward::make(
        SymbolVar src, SymbolVar filter, SymbolVar region_in, SymbolVar region_out,
        const Param& param, const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<RegionRestrictedConvolutionForward>(
            src.node(), filter.node(), region_in.node(), region_out.node(), param,
            config);
}
void RegionRestrictedConvolutionForward::init_output_dtype() {
    mgb_assert(
            input(0)->dtype().category() == DTypeCategory::FLOAT,
            "input dtype only supports FLOAT, but got input dtype: %s",
            input(0)->dtype().name());
    output(0)->dtype(input(0)->dtype());
}
size_t RegionRestrictedConvolutionForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return megdnn_opr()->get_workspace_in_bytes(
            {input_shapes[0], input(0)->dtype(), input(0)->format()},
            {input_shapes[1], input(1)->dtype(), input(1)->format()},
            {input_shapes[2], input(2)->dtype(), input(2)->format()},
            {input_shapes[3], input(3)->dtype(), input(3)->format()},
            {output_shapes[0], output(0)->dtype(), output(0)->format()});
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RegionRestrictedConvolutionForward) {
    mgb_assert(
            opr.input(0)->dtype().category() == DTypeCategory::FLOAT &&
            opr.input(1)->dtype().category() == DTypeCategory::FLOAT &&
            opr.input(2)->dtype().category() == DTypeCategory::INT &&
            opr.input(3)->dtype().category() == DTypeCategory::INT,
            "grad expects float src/filter and int rin/rout");
    if (wrt_idx == 0) {  // src
        SymbolVar grad = RegionRestrictedConvolutionBackwardData::make(
                opr.input(1),  // filter
                out_grad[0],   // diff
                opr.input(2),  // rin
                opr.input(3),  // rout
                opr.input(0),  // src
                opr.param());
        return grad.node();
    }
    // TODO: CUDA WGRAD UNIMPLEMENTED!
    if (wrt_idx == 1) {  // filter
        SymbolVar grad = RegionRestrictedConvolutionBackwardFilter::make(
                opr.input(0),  // src
                out_grad[0],   // diff
                opr.input(2),  // rin
                opr.input(3),  // rout
                opr.input(1),  // filter
                opr.param());
        return grad.node();
    }
    return nullptr;
}
#endif
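// Grad wiring above: d(src) comes from BackwardData(filter, diff, rin, rout,
// src) and d(filter) from BackwardFilter(src, diff, rin, rout, filter); the
// integer region inputs (wrt_idx 2 and 3) are non-differentiable, so nullptr
// is returned for them.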
/* ========================== RegionRestrictedConvolutionBackwardData
 * ========================== */
IMPL_CONV(RegionRestrictedConvolutionBackwardData);

RegionRestrictedConvolutionBackwardData::RegionRestrictedConvolutionBackwardData(
        VarNode* filter, VarNode* diff, VarNode* region_in, VarNode* region_out,
        VarNode* src, const Param& param, const OperatorNodeConfig& config)
        : Super{filter->owner_graph(),
                config,
                "region_restricted_conv_bwd_data",
                {filter, diff, region_in, region_out}} {
    init_megdnn_opr(*this, param);
    add_input({filter, diff, region_in, region_out});
    if (src)
        add_input({src});
}

SymbolVar RegionRestrictedConvolutionBackwardData::make(
        SymbolVar filter, SymbolVar diff, SymbolVar region_in, SymbolVar region_out,
        SymbolVar src, const Param& param, const OperatorNodeConfig& config) {
    return filter.insert_single_output_opr<RegionRestrictedConvolutionBackwardData>(
            filter.node(), diff.node(), region_in.node(), region_out.node(), src.node(),
            param, config);
}

SymbolVar RegionRestrictedConvolutionBackwardData::make(
        SymbolVar filter, SymbolVar diff, SymbolVar region_in, SymbolVar region_out,
        const Param& param, const OperatorNodeConfig& config) {
    return make(filter, diff, region_in, region_out, {}, param, config);
}

void RegionRestrictedConvolutionBackwardData::init_output_static_infer_desc() {
    init_output_static_infer_desc_for_bwd_data<
            RegionRestrictedConvolutionBackwardData,
            megdnn::RegionRestrictedConvolutionBackwardData>(this);
}

void RegionRestrictedConvolutionBackwardData::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
}

void RegionRestrictedConvolutionBackwardData::scn_do_execute() {
    megdnn_opr()->exec(
            input(0)->dev_tensor().as_megdnn(),  // filter
            input(1)->dev_tensor().as_megdnn(),  // diff
            input(2)->dev_tensor().as_megdnn(),  // rin
            input(3)->dev_tensor().as_megdnn(),  // rout
            output(0)->dev_tensor().as_megdnn(),
            intl::get_megdnn_workspace_from_var(output().back()));
}

cg::OperatorNodeBase::NodeProp* RegionRestrictedConvolutionBackwardData::
        do_make_node_prop() const {
    auto prop = Super::Super::do_make_node_prop();
    if (input().size() == 5) {
        using D = NodeProp::DepType;
        prop->reset_dep_type(
                input(),
                {D::DEV_VALUE, D::DEV_VALUE, D::DEV_VALUE, D::DEV_VALUE, D::SHAPE});
    }
    return prop;
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RegionRestrictedConvolutionBackwardData) {
    if (wrt_idx == 0) {  // filter
        return RegionRestrictedConvolutionBackwardFilter::make(
                       out_grad[0], opr.input(1), opr.input(2), opr.input(3),
                       opr.input(0), opr.param())
                .node();
    }
    if (wrt_idx == 1) {  // diff
        return RegionRestrictedConvolution::make(
                       out_grad[0], opr.input(0), opr.input(2), opr.input(3),
                       opr.param())
                .node();
    }
    return nullptr;
}
#endif
/* ========================== RegionRestrictedConvolutionBackwardFilter
 * ========================== */
IMPL_CONV(RegionRestrictedConvolutionBackwardFilter);

RegionRestrictedConvolutionBackwardFilter::RegionRestrictedConvolutionBackwardFilter(
        VarNode* src, VarNode* diff, VarNode* region_in, VarNode* region_out,
        VarNode* filter, const Param& param, const OperatorNodeConfig& config)
        : Super({src->owner_graph(),
                 config,
                 "region_restricted_conv_bwd_filter",
                 {src, diff, region_in, region_out, filter}},
                4, false) {
    init_megdnn_opr(*this, param);
    add_input({src, diff, region_in, region_out, filter});
}

SymbolVar RegionRestrictedConvolutionBackwardFilter::make(
        SymbolVar src, SymbolVar diff, SymbolVar region_in, SymbolVar region_out,
        SymbolVar filter, const Param& param, const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<RegionRestrictedConvolutionBackwardFilter>(
            src.node(), diff.node(), region_in.node(), region_out.node(), filter.node(),
            param, config);
}

size_t RegionRestrictedConvolutionBackwardFilter::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return megdnn_opr()->get_workspace_in_bytes(
            {input_shapes[0], input(0)->dtype(), input(0)->format()},
            {input_shapes[1], input(1)->dtype(), input(1)->format()},
            {input_shapes[2], input(2)->dtype(), input(2)->format()},
            {input_shapes[3], input(3)->dtype(), input(3)->format()},
            {output_shapes[0], output(0)->dtype(), output(0)->format()});
}

void RegionRestrictedConvolutionBackwardFilter::scn_do_execute() {
    megdnn_opr()->exec(
            input(0)->dev_tensor().as_megdnn(),  // src
            input(1)->dev_tensor().as_megdnn(),  // diff
            input(2)->dev_tensor().as_megdnn(),  // rin
            input(3)->dev_tensor().as_megdnn(),  // rout
            output(0)->dev_tensor().as_megdnn(),
            intl::get_megdnn_workspace_from_var(output().back()));
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RegionRestrictedConvolutionBackwardFilter) {
    if (wrt_idx == 0) {
        return RegionRestrictedConvolutionBackwardData::make(
                       out_grad[0] /*filter*/, opr.input(1) /*diff*/,
                       opr.input(2) /*rin*/, opr.input(3) /*rout*/,
                       opr.input(0) /*src*/, opr.param())
                .node();
    }
    if (wrt_idx == 1) {
        return RegionRestrictedConvolution::make(
                       opr.input(0) /*src*/, out_grad[0] /*filter*/,
                       opr.input(2) /*rin*/, opr.input(3) /*rout*/, opr.param())
                .node();
    }
    return nullptr;
}
#endif
#undef IMPL_CONV

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -431,6 +431,7 @@ struct OprLoadDumpImpl<opr::Convolution3DBackwardFilter, 0>
        MakeConvCallerEmpty<megdnn::Convolution3D>,
        MakeConvCallerEmpty<megdnn::Convolution3D>,
        megdnn::param::Convolution3D> {};

template <>
struct OprLoadDumpImpl<opr::ConvBiasForward, 0>
        : public ConvLoadDumpImpl<
@@ -194,6 +194,30 @@ struct OprLoadDumpImplV2<opr::DeformableConvBackwardFilter, 0>
        MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
        megdnn::Convolution> {};
template <>
struct OprMaker<opr::RegionRestrictedConvolutionBackwardData, 0> {
    using Opr = opr::RegionRestrictedConvolutionBackwardData;
    using Param = Opr::Param;
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        if (inputs.size() == 4) {  // deconv mode
            return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param, config)
                    .node()
                    ->owner_opr();
        } else if (inputs.size() == 5) {  // dgrad mode
            return Opr::make(
                           inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], param,
                           config)
                    .node()
                    ->owner_opr();
        } else {
            return nullptr;
        }
    }
};
}  // namespace serialization

namespace opr {
@@ -220,6 +244,10 @@ SERGE_OPR_V2_NO_CONVERTER(Convolution3D, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardFilter, 0);

MGB_SEREG_OPR(RegionRestrictedConvolutionBackwardData, 0);
MGB_SEREG_OPR(RegionRestrictedConvolution, 4);
MGB_SEREG_OPR(RegionRestrictedConvolutionBackwardFilter, 5);
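// BackwardData is registered with arity 0 so that the custom OprMaker above
// can accept either 4 inputs (deconv mode) or 5 (dgrad mode); the forward and
// filter-grad oprs are registered with fixed arities of 4 and 5 respectively.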
SERGE_OPR_V2_NO_CONVERTER(LocalShareForward, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardFilter, 0);
@@ -18,6 +18,12 @@ protected:
    static void init_output_static_infer_desc_for_bwd_data(cg::OperatorNodeBase* self);
};
class RegionConvBackwardDataMixin : public cg::OperatorNodeMixinBase {
protected:
    template <typename MGBOPR, typename DNNOPR>
    static void init_output_static_infer_desc_for_bwd_data(cg::OperatorNodeBase* self);
};

class WeightPreprocessExecutor : public cg::OperatorNodeMixinBase {
    class PreprocessedFilterExecDep;
@@ -83,6 +89,80 @@ class ConvolutionTestingPeer;
}  // namespace testing
/* ==================== RegionRestrictedConvolutionForward ==================== */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
        RegionRestrictedConvolutionForward,
        intl::MegDNNOprWrapperFwd<megdnn::RegionRestrictedConvolutionForward>) // {
    size_t get_workspace_size_bytes(
            const TensorShapeArray& input_shapes,
            const TensorShapeArray& output_shapes) const override;
    void init_output_dtype() override;

public:
    MGE_WIN_DECLSPEC_FUC RegionRestrictedConvolutionForward(
            VarNode* src, VarNode* filter, VarNode* region_in, VarNode* region_out,
            const Param& param, const OperatorNodeConfig& config);
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar src, SymbolVar filter, SymbolVar region_in, SymbolVar region_out,
            const Param& param, const OperatorNodeConfig& config = {});
};
using RegionRestrictedConvolution = RegionRestrictedConvolutionForward;

/* ==================== RegionRestrictedConvolutionBackwardData ==================== */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
        RegionRestrictedConvolutionBackwardData,
        cg::SingleCNOperatorNodeBaseT<mixin::MegDNNOprHolderImpl<
                megdnn::RegionRestrictedConvolutionBackwardData>>,
        public mixin::RegionConvBackwardDataMixin) // {
    void scn_do_execute() override;
    void init_output_static_infer_desc() override;
    NodeProp* do_make_node_prop() const override;
    void init_output_dtype() override;

public:
    MGE_WIN_DECLSPEC_FUC RegionRestrictedConvolutionBackwardData(
            VarNode* filter, VarNode* diff, VarNode* region_in, VarNode* region_out,
            VarNode* src, const Param& param, const OperatorNodeConfig& config);
    // grad mode
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar filter, SymbolVar diff, SymbolVar region_in, SymbolVar region_out,
            SymbolVar src, const Param& param, const OperatorNodeConfig& config = {});
    // sereg for deconv mode
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar filter, SymbolVar diff, SymbolVar region_in, SymbolVar region_out,
            const Param& param, const OperatorNodeConfig& config = {});
    // user interface for deconv
    MGE_WIN_DECLSPEC_FUC static SymbolVar make_deconv(
            SymbolVar data, SymbolVar filter, SymbolVar region_in, SymbolVar region_out,
            const Param& param = {}, const OperatorNodeConfig& config = {}) {
        return make(filter, data, region_in, region_out, param, config);
    }
};
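// make_deconv is the user-facing entry point for region-masked transposed
// convolution: it merely reorders (data, filter) into the (filter, diff)
// argument order of the 4-input make() above, so deconv and dgrad share one
// opr implementation.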
/* ==================== RegionRestrictedConvolutionBackwardFilter ==================== */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
        RegionRestrictedConvolutionBackwardFilter,
        intl::MegDNNOprWrapperBwd<megdnn::RegionRestrictedConvolutionBackwardFilter>) // {
    size_t get_workspace_size_bytes(
            const TensorShapeArray& input_shapes,
            const TensorShapeArray& output_shapes) const override;
    void scn_do_execute() override;

public:
    MGE_WIN_DECLSPEC_FUC RegionRestrictedConvolutionBackwardFilter(
            VarNode* src, VarNode* diff, VarNode* region_in, VarNode* region_out,
            VarNode* filter, const Param& param, const OperatorNodeConfig& config);
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar src, SymbolVar diff, SymbolVar region_in, SymbolVar region_out,
            SymbolVar filter, const Param& param,
            const OperatorNodeConfig& config = {});
};
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
        ConvolutionForward, intl::ConvolutionForwardBase,
        public mixin::AlgoChooserHelper) // {
@@ -0,0 +1,196 @@
| #include "./legacy_checker.h" | |||||
| #include "megbrain/comp_node_env.h" | |||||
| #include "megbrain/gopt/inference.h" | |||||
| #include "megbrain/opr/basic_arith.h" | |||||
| #include "megbrain/opr/dnn/convolution.h" | |||||
| #include "megbrain/opr/tensor_manip.h" | |||||
| #include "megbrain/serialization/serializer.h" | |||||
| #include "megbrain/test/autocheck.h" | |||||
| #include "megbrain/test/helper.h" | |||||
| #include "megbrain/test/megdnn_helper.h" | |||||
| #include "megdnn/algorithm_cache.h" | |||||
| #include "megdnn/dtype.h" | |||||
| #include "megdnn/oprs/base.h" | |||||
| #include <gmock/gmock.h> | |||||
| #include <cmath> | |||||
| #include <memory> | |||||
| #include <random> | |||||
| using namespace mgb; | |||||
| TEST(TestOprDNN, REGIONCONV_FWD_CPU_WRAPPER) { | |||||
| using Checker = AutoOprChecker<4, 1>; | |||||
| megdnn::RegionRestrictedConvolution::Param param; | |||||
| param.sparse = opr::RegionRestrictedConvolution::Param::Sparse::DENSE; | |||||
| auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||||
| return {opr::RegionRestrictedConvolutionForward::make( | |||||
| inputs[0], inputs[1], inputs[2], inputs[3], param)}; | |||||
| }; | |||||
| Checker::RunOptions option; | |||||
| option.numdiff_eps = 0.1; | |||||
| option.numdiff_max_err = 1e-2; | |||||
| auto mask_gen = [&](HostTensorND& src) { | |||||
| HostTensorGenerator<dtype::Int32, RandomDistribution::CONSTANT> gen(1); | |||||
| src = *gen(src.shape(), src.comp_node()); | |||||
| }; | |||||
| auto float_gen = [&](HostTensorND& src) { | |||||
| HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> gen; | |||||
| src = *gen(src.shape(), src.comp_node()); | |||||
| }; | |||||
| auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||||
| auto opr = | |||||
| megdnn_naive_handle() | |||||
| ->create_operator<megdnn::RegionRestrictedConvolutionForward>(); | |||||
| opr->param() = param; | |||||
| TensorLayout dest_layout; | |||||
| opr->deduce_layout( | |||||
| inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), inp[3]->layout(), | |||||
| dest_layout); | |||||
| std::vector<dt_byte> workspace(opr->get_workspace_in_bytes( | |||||
| inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), inp[3]->layout(), | |||||
| dest_layout)); | |||||
| dest[0].dtype(inp[0]->dtype()) | |||||
| .comp_node(inp[0]->comp_node()) | |||||
| .resize(dest_layout); | |||||
| opr->exec( | |||||
| inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(), | |||||
| inp[3]->as_megdnn(), dest[0].as_megdnn(), | |||||
| {workspace.data(), workspace.size()}); | |||||
| }; | |||||
| Checker(make_graph, fwd, CompNode::load("cpu0")) | |||||
| .set_input_dtype(0, dtype::Float32()) | |||||
| .set_input_dtype(1, dtype::Float32()) | |||||
| .set_input_dtype(2, dtype::Int32()) | |||||
| .set_input_dtype(3, dtype::Int32()) | |||||
| .set_input_generator(0, float_gen) | |||||
| .set_input_generator(1, float_gen) | |||||
| .set_input_generator(2, mask_gen) | |||||
| .set_input_generator(3, mask_gen) | |||||
| .set_input_allow_grad(2, false) | |||||
| .set_input_allow_grad(3, false) | |||||
| // {n,ic,ih,iw}, {oc,ic,fh,fw}, {n,ih,iw}, {n,oh,ow} | |||||
| .run({TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 2, 2}, | |||||
| TensorShape{1, 2, 2}, TensorShape{1, 1, 1}}, | |||||
| option) | |||||
| .run({TensorShape{1, 2, 3, 3}, TensorShape{1, 2, 3, 3}, | |||||
| TensorShape{1, 3, 3}, TensorShape{1, 1, 1}}, | |||||
| option) | |||||
| .run({TensorShape{1, 1, 4, 4}, TensorShape{1, 1, 2, 2}, | |||||
| TensorShape{1, 4, 4}, TensorShape{1, 3, 3}}, | |||||
| option) | |||||
| .run({TensorShape{2, 2, 8, 8}, TensorShape{4, 2, 2, 2}, | |||||
| TensorShape{2, 8, 8}, TensorShape{2, 7, 7}}, | |||||
| option) | |||||
| .run({TensorShape{4, 4, 8, 8}, TensorShape{4, 4, 2, 2}, | |||||
| TensorShape{4, 8, 8}, TensorShape{4, 7, 7}}, | |||||
| option); | |||||
| } | |||||
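// Note: the checker numerically differentiates only the float inputs (src and
// filter); rin/rout are integer masks, hence set_input_allow_grad(2/3, false)
// above. The CUDA variant below exercises Sparse::GROUP with both Int32 and
// Uint8 masks.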
#if MGB_CUDA
TEST(TestOprDNN, REGIONCONV_FWD_GPU_WRAPPER) {
    using Checker = AutoOprChecker<4, 1>;
    megdnn::RegionRestrictedConvolution::Param param;
    param.sparse = opr::RegionRestrictedConvolution::Param::Sparse::GROUP;
    auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
        return {opr::RegionRestrictedConvolutionForward::make(
                inputs[0], inputs[1], inputs[2], inputs[3], param)};
    };
    Checker::RunOptions option;
    option.numdiff_eps = 0.1;
    option.numdiff_max_err = 1e-2;
    auto mask_gen = [&](HostTensorND& src) {
        HostTensorGenerator<dtype::Int32, RandomDistribution::CONSTANT> gen(1);
        src = *gen(src.shape(), src.comp_node());
    };
    auto uint8_mask_gen = [&](HostTensorND& src) {
        HostTensorGenerator<dtype::Uint8, RandomDistribution::CONSTANT> gen(1);
        src = *gen(src.shape(), src.comp_node());
    };
    auto float_gen = [&](HostTensorND& src) {
        HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> gen;
        src = *gen(src.shape(), src.comp_node());
    };
    auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
        auto opr =
                megdnn_naive_handle()
                        ->create_operator<megdnn::RegionRestrictedConvolutionForward>();
        opr->param() = param;
        TensorLayout dest_layout;
        opr->deduce_layout(
                inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), inp[3]->layout(),
                dest_layout);
        std::vector<dt_byte> workspace(opr->get_workspace_in_bytes(
                inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), inp[3]->layout(),
                dest_layout));
        dest[0].dtype(inp[0]->dtype())
                .comp_node(inp[0]->comp_node())
                .resize(dest_layout);
        opr->exec(
                inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(),
                inp[3]->as_megdnn(), dest[0].as_megdnn(),
                {workspace.data(), workspace.size()});
    };
    Checker(make_graph, fwd, CompNode::load("gpu0"))
            .set_input_dtype(0, dtype::Float32())
            .set_input_dtype(1, dtype::Float32())
            .set_input_dtype(2, dtype::Int32())
            .set_input_dtype(3, dtype::Int32())
            .set_input_generator(0, float_gen)
            .set_input_generator(1, float_gen)
            .set_input_generator(2, mask_gen)
            .set_input_generator(3, mask_gen)
            .set_input_allow_grad(2, false)
            .set_input_allow_grad(3, false)
            // {n,ic,ih,iw}, {g,oc/g,ic/g,fh,fw} (group conv), {n,ih,iw}, {n,oh,ow}
            .run({TensorShape{1, 2, 2, 2}, TensorShape{2, 1, 1, 2, 2},
                  TensorShape{1, 2, 2}, TensorShape{1, 1, 1}},
                 option)
            .run({TensorShape{1, 2, 3, 3}, TensorShape{2, 1, 1, 3, 3},
                  TensorShape{1, 3, 3}, TensorShape{1, 1, 1}},
                 option)
            .run({TensorShape{1, 4, 4, 4}, TensorShape{4, 1, 1, 2, 2},
                  TensorShape{1, 4, 4}, TensorShape{1, 3, 3}},
                 option)
            .run({TensorShape{2, 4, 8, 8}, TensorShape{4, 1, 1, 2, 2},
                  TensorShape{2, 8, 8}, TensorShape{2, 7, 7}},
                 option);
    Checker(make_graph, fwd, CompNode::load("gpu0"))
            .set_input_dtype(0, dtype::Float32())
            .set_input_dtype(1, dtype::Float32())
            .set_input_dtype(2, dtype::Uint8())
            .set_input_dtype(3, dtype::Uint8())
            .set_input_generator(0, float_gen)
            .set_input_generator(1, float_gen)
            .set_input_generator(2, uint8_mask_gen)
            .set_input_generator(3, uint8_mask_gen)
            .set_input_allow_grad(2, false)
            .set_input_allow_grad(3, false)
            // {n,ic,ih,iw}, {g,oc/g,ic/g,fh,fw} (group conv), {n,ih,iw}, {n,oh,ow}
            .run({TensorShape{1, 2, 4, 4}, TensorShape{2, 1, 1, 1, 1},
                  TensorShape{1, 4, 4}, TensorShape{1, 4, 4}},
                 option)
            .run({TensorShape{1, 2, 8, 8}, TensorShape{2, 1, 1, 1, 1},
                  TensorShape{1, 8, 8}, TensorShape{1, 8, 8}},
                 option)
            .run({TensorShape{1, 4, 8, 8}, TensorShape{4, 1, 1, 5, 5},
                  TensorShape{1, 8, 8}, TensorShape{1, 4, 4}},
                 option)
            .run({TensorShape{2, 4, 8, 8}, TensorShape{4, 1, 1, 1, 1},
                  TensorShape{2, 8, 8}, TensorShape{2, 8, 8}},
                 option);
}
#endif