# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.utils import _pair

from detectron2.layers.wrappers import _NewEmptyTensorOp


class TridentConv(nn.Module):
    """
    Multi-branch convolution used by TridentNet ("Scale-Aware Trident Networks
    for Object Detection", Li et al., 2019): every branch shares the same
    weights but applies its own padding and dilation, so each branch covers a
    different receptive field.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        paddings=0,
        dilations=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        super(TridentConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        self.paddings = [_pair(padding) for padding in paddings]
        self.dilations = [_pair(dilation) for dilation in dilations]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        # Exactly one padding and one dilation must be provided per branch.
        assert len({self.num_branch, len(self.paddings), len(self.dilations)}) == 1

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        # At test time with a fixed test_branch_idx, only a single branch runs.
        num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
        assert len(inputs) == num_branch

        if inputs[0].numel() == 0:
            # Empty batch: derive each output shape analytically with the
            # standard convolution size formula and return empty tensors.
            if self.training or self.test_branch_idx == -1:
                paddings, dilations = self.paddings, self.dilations
            else:
                paddings = [self.paddings[self.test_branch_idx]]
                dilations = [self.dilations[self.test_branch_idx]]
            outputs = []
            for input, padding, dilation in zip(inputs, paddings, dilations):
                output_shape = [
                    (i + 2 * p - (di * (k - 1) + 1)) // s + 1
                    for i, p, di, k, s in zip(
                        input.shape[-2:], padding, dilation, self.kernel_size, self.stride
                    )
                ]
                output_shape = [input.shape[0], self.weight.shape[0]] + output_shape
                outputs.append(_NewEmptyTensorOp.apply(input, output_shape))
            return outputs

        if self.training or self.test_branch_idx == -1:
            # Run every branch: shared weights, per-branch padding and dilation.
            outputs = [
                F.conv2d(input, self.weight, self.bias, self.stride, padding, dilation, self.groups)
                for input, dilation, padding in zip(inputs, self.dilations, self.paddings)
            ]
        else:
            # Inference with a single selected branch.
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    self.stride,
                    self.paddings[self.test_branch_idx],
                    self.dilations[self.test_branch_idx],
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(x) for x in outputs]
        if self.activation is not None:
            outputs = [self.activation(x) for x in outputs]
        return outputs

    def extra_repr(self):
        tmpstr = "in_channels=" + str(self.in_channels)
        tmpstr += ", out_channels=" + str(self.out_channels)
        tmpstr += ", kernel_size=" + str(self.kernel_size)
        tmpstr += ", num_branch=" + str(self.num_branch)
        tmpstr += ", test_branch_idx=" + str(self.test_branch_idx)
        tmpstr += ", stride=" + str(self.stride)
        tmpstr += ", paddings=" + str(self.paddings)
        tmpstr += ", dilations=" + str(self.dilations)
        tmpstr += ", groups=" + str(self.groups)
        tmpstr += ", bias=" + str(self.with_bias)
        return tmpstr
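

if __name__ == "__main__":
    # Minimal usage sketch (an editorial addition, not part of the original
    # file): a 3-branch TridentConv with the paper's typical dilations
    # (1, 2, 3) and matching paddings, so every branch preserves the input's
    # spatial size.
    conv = TridentConv(
        in_channels=16,
        out_channels=32,
        kernel_size=3,
        paddings=[1, 2, 3],
        dilations=[1, 2, 3],
        num_branch=3,
    )
    x = torch.randn(2, 16, 14, 14)
    # In training mode forward() expects one input per branch and returns one
    # output per branch; here the same tensor is fed to all three branches.
    outputs = conv([x, x, x])
    assert all(out.shape == (2, 32, 14, 14) for out in outputs)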