GitOrigin-RevId: 8a69608953
tags/v1.11.1
| @@ -335,6 +335,7 @@ def conv_transpose2d( | |||
| bias: Optional[Tensor] = None, | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| output_padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| conv_mode="cross_correlation", | |||
| @@ -352,6 +353,7 @@ def conv_transpose2d( | |||
| stride: stride of the 2D convolution operation. Default: 1 | |||
| padding: size of the paddings added to the input on both sides of its | |||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||
| output_padding: size of paddings appended to output. Default: 0 | |||
| dilation: dilation of the 2D convolution operation. Default: 1 | |||
| groups: number of groups into which the input and output channels are divided, | |||
| so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||
| @@ -374,6 +376,7 @@ def conv_transpose2d( | |||
| stride_h, stride_w = expand_hw(stride) | |||
| pad_h, pad_w = expand_hw(padding) | |||
| output_pad_h, output_pad_w = expand_hw(output_padding) | |||
| dilate_h, dilate_w = expand_hw(dilation) | |||
| compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) | |||
| @@ -389,7 +392,32 @@ def conv_transpose2d( | |||
| compute_mode=compute_mode, | |||
| sparse=sparse_type, | |||
| ) | |||
| (output,) = apply(op, weight, inp) | |||
| if output_pad_h != 0 or output_pad_h != 0: | |||
| assert ( | |||
| output_pad_h < stride[0] | |||
| ), "output_padding[0] shoule be less than stride[0]" | |||
| assert ( | |||
| output_pad_w < stride[1] | |||
| ), "output_padding[1] shoule be less than stride[1]" | |||
| Hout = ( | |||
| (inp.shape[2] - 1) * stride[0] | |||
| - 2 * padding[0] | |||
| + dilation[0] * (weight.shape[2] - 1) | |||
| + output_pad_h | |||
| + 1 | |||
| ) | |||
| Wout = ( | |||
| (inp.shape[3] - 1) * stride[1] | |||
| - 2 * padding[1] | |||
| + dilation[1] * (weight.shape[3] - 1) | |||
| + output_pad_w | |||
| + 1 | |||
| ) | |||
| output_shape = [inp.shape[0], weight.shape[1], Hout, Wout] | |||
| output_shape = astensor1d(output_shape) | |||
| (output,) = apply(op, weight, inp, output_shape) | |||
| else: | |||
| (output,) = apply(op, weight, inp) | |||
| if bias is not None: | |||
| if amp._enabled: | |||
| bias = cast_tensors(bias) | |||
| @@ -528,6 +556,7 @@ def conv_transpose3d( | |||
| bias: Optional[Tensor] = None, | |||
| stride: Union[int, Tuple[int, int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int, int]] = 0, | |||
| output_padding: Union[int, Tuple[int, int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int, int]] = 1, | |||
| groups: int = 1, | |||
| ) -> Tensor: | |||
| @@ -544,6 +573,7 @@ def conv_transpose3d( | |||
| stride: stride of the 3D convolution operation. Default: 1 | |||
| padding: size of the paddings added to the input on all sides of its | |||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||
| output_padding: size of paddings appended to output. Default: 0 | |||
| dilation: dilation of the 3D convolution operation. Default: 1 | |||
| groups: number of groups into which the input and output channels are divided, | |||
| so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||
| @@ -558,6 +588,7 @@ def conv_transpose3d( | |||
| pad = expand_dhw(padding) | |||
| stride = expand_dhw(stride) | |||
| dilate = expand_dhw(dilation) | |||
| output_padding = expand_dhw(output_padding) | |||
| sparse_type = "dense" if groups == 1 else "group" | |||
| op = builtin.Convolution3DBackwardData( | |||
| @@ -573,7 +604,42 @@ def conv_transpose3d( | |||
| strategy=get_execution_strategy(), | |||
| sparse=sparse_type, | |||
| ) | |||
| (output,) = apply(op, weight, inp) | |||
| if output_padding[0] != 0 or output_padding[1] != 0 or output_padding[2] != 0: | |||
| assert ( | |||
| output_padding[0] < stride[0] | |||
| ), "output_padding[0] shoule be less than stride[0]" | |||
| assert ( | |||
| output_padding[1] < stride[1] | |||
| ), "output_padding[1] shoule be less than stride[1]" | |||
| assert ( | |||
| output_padding[2] < stride[2] | |||
| ), "output_padding[2] shoule be less than stride[2]" | |||
| Dout = ( | |||
| (inp.shape[2] - 1) * stride[0] | |||
| - 2 * padding[0] | |||
| + dilation[0] * (weight.shape[2] - 1) | |||
| + output_padding[0] | |||
| + 1 | |||
| ) | |||
| Hout = ( | |||
| (inp.shape[3] - 1) * stride[1] | |||
| - 2 * padding[1] | |||
| + dilation[1] * (weight.shape[3] - 1) | |||
| + output_padding[1] | |||
| + 1 | |||
| ) | |||
| Wout = ( | |||
| (inp.shape[4] - 1) * stride[2] | |||
| - 2 * padding[2] | |||
| + dilation[2] * (weight.shape[4] - 1) | |||
| + output_padding[2] | |||
| + 1 | |||
| ) | |||
| output_shape = [inp.shape[0], weight.shape[1], Dout, Hout, Wout] | |||
| output_shape = astensor1d(output_shape) | |||
| (output,) = apply(op, weight, inp, output_shape) | |||
| else: | |||
| (output,) = apply(op, weight, inp) | |||
| if bias is not None: | |||
| output += bias | |||
| return output | |||
| @@ -134,6 +134,7 @@ def conv_transpose2d( | |||
| dtype=None, | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| output_padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| conv_mode="cross_correlation", | |||
| @@ -156,6 +157,7 @@ def conv_transpose2d( | |||
| ) | |||
| pad_h, pad_w = _pair(padding) | |||
| output_pad_h, output_pad_w = _pair(output_padding) | |||
| stride_h, stride_w = _pair_nonzero(stride) | |||
| dilate_h, dilate_w = _pair_nonzero(dilation) | |||
| compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) | |||
| @@ -173,5 +175,30 @@ def conv_transpose2d( | |||
| compute_mode=compute_mode, | |||
| mode=conv_mode, | |||
| ) | |||
| (output,) = apply(op, weight, inp) | |||
| if output_pad_h != 0 or output_pad_h != 0: | |||
| assert ( | |||
| output_pad_h < stride[0] | |||
| ), "output_padding[0] shoule be less than stride[0]" | |||
| assert ( | |||
| output_pad_w < stride[1] | |||
| ), "output_padding[1] shoule be less than stride[1]" | |||
| Hout = ( | |||
| (inp.shape[2] - 1) * stride[0] | |||
| - 2 * padding[0] | |||
| + dilation[0] * (weight.shape[2] - 1) | |||
| + output_pad_h | |||
| + 1 | |||
| ) | |||
| Wout = ( | |||
| (inp.shape[3] - 1) * stride[1] | |||
| - 2 * padding[1] | |||
| + dilation[1] * (weight.shape[3] - 1) | |||
| + output_pad_w | |||
| + 1 | |||
| ) | |||
| output_shape = [inp.shape[0], weight.shape[1], Hout, Wout] | |||
| output_shape = Tensor(output_shape) | |||
| (output,) = apply(op, weight, inp, output_shape) | |||
| else: | |||
| (output,) = apply(op, weight, inp) | |||
| return output | |||
| @@ -30,6 +30,7 @@ class _ConvNd(Module): | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| stride: Union[int, Tuple[int, int]], | |||
| padding: Union[int, Tuple[int, int]], | |||
| output_padding: Union[int, Tuple[int, int]], | |||
| dilation: Union[int, Tuple[int, int]], | |||
| groups: int, | |||
| bias: bool = True, | |||
| @@ -45,6 +46,7 @@ class _ConvNd(Module): | |||
| self.kernel_size = kernel_size | |||
| self.stride = stride | |||
| self.padding = padding | |||
| self.output_padding = output_padding | |||
| self.dilation = dilation | |||
| self.groups = groups | |||
| @@ -178,6 +180,7 @@ class Conv1d(_ConvNd): | |||
| kernel_size, | |||
| stride, | |||
| padding, | |||
| 0, | |||
| dilation, | |||
| groups, | |||
| bias, | |||
| @@ -352,6 +355,7 @@ class Conv2d(_ConvNd): | |||
| kernel_size, | |||
| stride, | |||
| padding, | |||
| 0, | |||
| dilation, | |||
| groups, | |||
| bias, | |||
| @@ -505,6 +509,7 @@ class Conv3d(_ConvNd): | |||
| kernel_size, | |||
| stride, | |||
| padding, | |||
| 0, | |||
| dilation, | |||
| groups, | |||
| bias, | |||
| @@ -572,6 +577,7 @@ class ConvTranspose2d(_ConvNd): | |||
| stride: stride of the 2D convolution operation. Default: 1 | |||
| padding: size of the paddings added to the input on both sides of its | |||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||
| output_padding: size of paddings appended to output. Default: 0 | |||
| dilation: dilation of the 2D convolution operation. Default: 1 | |||
| groups: number of groups into which the input and output channels are divided, | |||
| so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||
| @@ -591,6 +597,8 @@ class ConvTranspose2d(_ConvNd): | |||
| * ``bias`` usually has shape ``(1, out_channels, *1)`` | |||
| """ | |||
| output_padding = 0 | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| @@ -598,6 +606,7 @@ class ConvTranspose2d(_ConvNd): | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| output_padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| bias: bool = True, | |||
| @@ -608,6 +617,7 @@ class ConvTranspose2d(_ConvNd): | |||
| kernel_size = _pair_nonzero(kernel_size) | |||
| stride = _pair_nonzero(stride) | |||
| padding = _pair(padding) | |||
| output_padding = _pair(output_padding) | |||
| dilation = _pair_nonzero(dilation) | |||
| self.conv_mode = conv_mode | |||
| self.compute_mode = compute_mode | |||
| @@ -617,6 +627,7 @@ class ConvTranspose2d(_ConvNd): | |||
| kernel_size, | |||
| stride, | |||
| padding, | |||
| output_padding, | |||
| dilation, | |||
| groups, | |||
| bias, | |||
| @@ -656,6 +667,7 @@ class ConvTranspose2d(_ConvNd): | |||
| bias, | |||
| self.stride, | |||
| self.padding, | |||
| self.output_padding, | |||
| self.dilation, | |||
| self.groups, | |||
| self.conv_mode, | |||
| @@ -817,6 +829,7 @@ class DeformableConv2d(_ConvNd): | |||
| kernel_size, | |||
| stride, | |||
| padding, | |||
| 0, | |||
| dilation, | |||
| groups, | |||
| bias, | |||
| @@ -889,6 +902,7 @@ class ConvTranspose3d(_ConvNd): | |||
| stride: stride of the 3D convolution operation. Default: 1 | |||
| padding: size of the paddings added to the input on all sides of its | |||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||
| output_padding: size of paddings appended to output. Default: 0 | |||
| dilation: dilation of the 3D convolution operation. Default: 1 | |||
| groups: number of groups into which the input and output channels are divided, | |||
| so as to perform a ``grouped convolution``. When ``groups`` is not 1, | |||
| @@ -902,6 +916,8 @@ class ConvTranspose3d(_ConvNd): | |||
| * ``bias`` usually has shape ``(1, out_channels, *1)`` | |||
| """ | |||
| output_padding = 0 | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| @@ -909,6 +925,7 @@ class ConvTranspose3d(_ConvNd): | |||
| kernel_size: Union[int, Tuple[int, int, int]], | |||
| stride: Union[int, Tuple[int, int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int, int]] = 0, | |||
| output_padding: Union[int, Tuple[int, int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int, int]] = 1, | |||
| groups: int = 1, | |||
| bias: bool = True, | |||
| @@ -923,6 +940,7 @@ class ConvTranspose3d(_ConvNd): | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| padding=padding, | |||
| output_padding=output_padding, | |||
| dilation=dilation, | |||
| groups=groups, | |||
| bias=bias, | |||
| @@ -956,5 +974,11 @@ class ConvTranspose3d(_ConvNd): | |||
| def forward(self, inp): | |||
| return conv_transpose3d( | |||
| inp, self.weight, self.bias, self.stride, self.padding, self.dilation, | |||
| inp, | |||
| self.weight, | |||
| self.bias, | |||
| self.stride, | |||
| self.padding, | |||
| self.output_padding, | |||
| self.dilation, | |||
| ) | |||
| @@ -74,6 +74,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule): | |||
| float_module.kernel_size, | |||
| float_module.stride, | |||
| float_module.padding, | |||
| float_module.output_padding, | |||
| float_module.dilation, | |||
| float_module.groups, | |||
| float_module.bias is not None, | |||
| @@ -138,6 +138,8 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): | |||
| dtype: data type of the output, should be qint8. | |||
| """ | |||
| output_padding = 0 | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| @@ -145,6 +147,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| output_padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| bias: bool = True, | |||
| @@ -159,6 +162,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| padding=padding, | |||
| output_padding=output_padding, | |||
| dilation=dilation, | |||
| groups=groups, | |||
| bias=bias, | |||
| @@ -180,6 +184,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): | |||
| qat_module.kernel_size, | |||
| qat_module.stride, | |||
| qat_module.padding, | |||
| qat_module.output_padding, | |||
| qat_module.dilation, | |||
| qat_module.groups, | |||
| qat_module.bias is not None, | |||
| @@ -212,6 +217,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): | |||
| dtype=self.output_dtype, | |||
| stride=self.stride, | |||
| padding=self.padding, | |||
| output_padding=self.output_padding, | |||
| dilation=self.dilation, | |||
| groups=self.groups, | |||
| conv_mode=self.conv_mode, | |||
| @@ -18,7 +18,8 @@ from megengine.core._trace_option import use_symbolic_shape | |||
| from megengine.core.autodiff.grad import Grad | |||
| from megengine.core.tensor.utils import make_shape_tuple | |||
| from megengine.device import get_device_count | |||
| from megengine.module import LayerNorm | |||
| from megengine.jit.tracing import trace | |||
| from megengine.module import ConvTranspose2d, ConvTranspose3d, LayerNorm | |||
| _assert_allclose = partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6) | |||
| @@ -1374,3 +1375,37 @@ def test_local_conv2d(stride, padding, dilation, ksize, groups): | |||
| ) | |||
| ref = local_conv2d_np(data, weight, stride, padding, dilation) | |||
| np.testing.assert_almost_equal(output.numpy(), ref, 5) | |||
| def test_conv_transpose2d(): | |||
| m = ConvTranspose2d( | |||
| 16, 33, (3, 5), output_padding=(1, 2), stride=(2, 3), padding=(4, 2) | |||
| ) | |||
| @trace(symbolic=True) | |||
| def fwd(inp: Tensor): | |||
| return m(inp) | |||
| input = Tensor(np.random.rand(20, 16, 50, 100)) | |||
| output = fwd(input) | |||
| output_shape = Tensor(output.shape) | |||
| np.testing.assert_equal( | |||
| output_shape.numpy(), np.array([20, 33, 94, 300], dtype=np.int32) | |||
| ) | |||
| def test_conv_transpose3d(): | |||
| m = ConvTranspose3d( | |||
| 16, 33, (3, 5, 2), output_padding=(2, 1, 1), stride=(3, 2, 2), padding=(0, 4, 2) | |||
| ) | |||
| @trace(symbolic=True) | |||
| def fwd(inp: Tensor): | |||
| return m(inp) | |||
| input = Tensor(np.random.rand(20, 16, 10, 50, 100)) | |||
| output = fwd(input) | |||
| output_shape = Tensor(output.shape) | |||
| np.testing.assert_equal( | |||
| output_shape.numpy(), np.array([20, 33, 32, 96, 197], dtype=np.int32) | |||
| ) | |||
| @@ -5,6 +5,7 @@ | |||
| #include "../op_trait.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
| #include "megbrain/opr/tensor_gen.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| @@ -152,8 +153,11 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
| } else { | |||
| mgb_assert(inputs.size() == 3); | |||
| auto* src_for_shape = | |||
| opr::Alloc::make(inputs[2], inputs[0]->dtype(), {}).node(); | |||
| return opr::ConvolutionBackwardData::make( | |||
| inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); | |||
| inputs[0], inputs[1], src_for_shape, conv.param(), conv.policy(), | |||
| config); | |||
| } | |||
| } | |||
| @@ -168,6 +172,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| if (filter.ndim && diff.ndim) { | |||
| // deduce_layout won't override existing dtype | |||
| dnn_opr.opr().deduce_layout(filter, diff, output_layout); | |||
| if (inputs.size() == 3) { | |||
| if (!inputs[2].value.empty()) { | |||
| cg::copy_tensor_value_to_shape(output_layout, inputs[2].value); | |||
| output_layout.init_contiguous_stride(); | |||
| } else { | |||
| output_layout.ndim = 0; | |||
| } | |||
| } | |||
| } | |||
| return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0}; | |||
| } | |||
| @@ -185,8 +197,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| return output_descs[0].layout; | |||
| } else { | |||
| TensorLayout out_layout{inputs[0]->dtype()}; | |||
| dnn_opr.op()->deduce_layout( | |||
| inputs[0]->layout(), inputs[1]->layout(), out_layout); | |||
| if (inputs.size() == 3) { | |||
| cg::copy_tensor_value_to_shape( | |||
| out_layout, inputs[2]->get_value().proxy_to_default_cpu()); | |||
| out_layout.init_contiguous_stride(); | |||
| } | |||
| return out_layout; | |||
| } | |||
| }(); | |||
| @@ -263,50 +278,74 @@ namespace convolution3d_backward_data { | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||
| mgb_assert( | |||
| inputs.size() == 2, | |||
| "inputs num of conv_transpose3d should be 2 but you give %zu", | |||
| inputs.size() == 2 || inputs.size() == 3, | |||
| "inputs num of conv_transpose3d should be 2 or 3 but you give %zu", | |||
| inputs.size()); | |||
| auto&& op_def = def.cast_final_safe<Convolution3DBackwardData>(); | |||
| auto&& weight = inputs[0]; | |||
| auto&& conv3dbwd = def.cast_final_safe<Convolution3DBackwardData>(); | |||
| DnnOprHelper<megdnn::Convolution3DBackwardData> dnn_opr(conv3dbwd.param()); | |||
| auto&& filter = inputs[0]; | |||
| auto&& diff = inputs[1]; | |||
| if (!(weight.layout.ndim && diff.layout.ndim)) { | |||
| return {{{TensorLayout{weight.layout.dtype}, weight.comp_node}}, false}; | |||
| if (!(filter.layout.ndim && diff.layout.ndim)) { | |||
| return {{{TensorLayout{filter.layout.dtype}, filter.comp_node}}, false}; | |||
| } | |||
| DnnOprHelper<megdnn::Convolution3DBackwardData> dnn_opr(op_def.param()); | |||
| auto oup_layout = dnn_opr.deduce_layout(weight.layout, diff.layout); | |||
| return {{{oup_layout, weight.comp_node}}, true}; | |||
| TensorLayout output_layout = dnn_opr.deduce_layout(filter.layout, diff.layout); | |||
| if (filter.layout.ndim && diff.layout.ndim) { | |||
| if (inputs.size() == 3) { | |||
| if (!inputs[2].value.empty()) { | |||
| cg::copy_tensor_value_to_shape(output_layout, inputs[2].value); | |||
| output_layout.init_contiguous_stride(); | |||
| } else { | |||
| output_layout.ndim = 0; | |||
| } | |||
| } | |||
| } | |||
| return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0}; | |||
| } | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
| auto&& conv = def.cast_final_safe<Convolution3DBackwardData>(); | |||
| auto cn = inputs[0]->comp_node(); | |||
| auto&& wlayout = inputs[0]->layout(); | |||
| auto&& dlayout = inputs[1]->layout(); | |||
| DnnOprCaller<megdnn::Convolution3DBackwardData> dnn_op( | |||
| cn, conv.param(), conv.policy()); | |||
| auto oup_layout = [&] { | |||
| auto&& conv3dbwd = def.cast_final_safe<Convolution3DBackwardData>(); | |||
| CompNode cn = inputs[0]->comp_node(); | |||
| DnnOprCaller<megdnn::Convolution3DBackwardData> dnn_opr( | |||
| cn, conv3dbwd.param(), conv3dbwd.policy()); | |||
| auto out_layout = [&] { | |||
| if (validated) { | |||
| return output_descs[0].layout; | |||
| } else { | |||
| return dnn_op.deduce_layout(wlayout, dlayout); | |||
| TensorLayout out_layout{inputs[0]->dtype()}; | |||
| dnn_opr.op()->deduce_layout( | |||
| inputs[0]->layout(), inputs[1]->layout(), out_layout); | |||
| if (inputs.size() == 3) { | |||
| cg::copy_tensor_value_to_shape( | |||
| out_layout, inputs[2]->get_value().proxy_to_default_cpu()); | |||
| out_layout.init_contiguous_stride(); | |||
| } | |||
| return out_layout; | |||
| } | |||
| }(); | |||
| auto oup = Tensor::make(oup_layout, cn); | |||
| dnn_op.exec_fastrun(inputs[0], inputs[1], oup); | |||
| return {oup}; | |||
| auto out = Tensor::make(out_layout, cn); | |||
| dnn_opr.exec_fastrun(inputs[0], inputs[1], out); | |||
| return {out}; | |||
| } | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& conv = static_cast<const Convolution3DBackwardData&>(def); | |||
| OperatorNodeConfig config{conv.make_name()}; | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::Convolution3DBackwardData::make( | |||
| inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
| if (inputs.size() == 2) { | |||
| return opr::Convolution3DBackwardData::make( | |||
| inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
| } else { | |||
| mgb_assert(inputs.size() == 3); | |||
| // The output shape is calculated in advance and given as input | |||
| auto* src_for_shape = | |||
| opr::Alloc::make(inputs[2], inputs[0]->dtype(), {}).node(); | |||
| return opr::Convolution3DBackwardData::make( | |||
| inputs[0], inputs[1], src_for_shape, conv.param(), conv.policy(), | |||
| config); | |||
| } | |||
| } | |||
| OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData) | |||