
dilated_encoder.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import (ConvModule, caffe2_xavier_init, constant_init, is_norm,
                      normal_init)
from torch.nn import BatchNorm2d

from ..builder import NECKS


class Bottleneck(nn.Module):
    """Bottleneck block for DilatedEncoder used in
    `YOLOF <https://arxiv.org/abs/2103.09460>`_.

    The Bottleneck contains three ConvLayers and one residual connection.

    Args:
        in_channels (int): The number of input channels.
        mid_channels (int): The number of middle output channels.
        dilation (int): Dilation rate.
        norm_cfg (dict): Dictionary to construct and config norm layer.
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 dilation,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super(Bottleneck, self).__init__()
        self.conv1 = ConvModule(
            in_channels, mid_channels, 1, norm_cfg=norm_cfg)
        self.conv2 = ConvModule(
            mid_channels,
            mid_channels,
            3,
            padding=dilation,
            dilation=dilation,
            norm_cfg=norm_cfg)
        self.conv3 = ConvModule(
            mid_channels, in_channels, 1, norm_cfg=norm_cfg)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = out + identity
        return out
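# Illustrative note (not part of the upstream file): a Bottleneck is a
# 1x1 -> dilated 3x3 -> 1x1 ConvModule stack with a residual add, so it
# preserves both the channel count and the spatial size of its input, e.g.
#
#   block = Bottleneck(in_channels=512, mid_channels=128, dilation=2)
#   x = torch.randn(1, 512, 32, 32)    # requires `import torch`
#   assert block(x).shape == x.shape   # still (1, 512, 32, 32)
#
# All channel and spatial sizes above are arbitrary examples.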
@NECKS.register_module()
class DilatedEncoder(nn.Module):
    """Dilated Encoder for `YOLOF <https://arxiv.org/abs/2103.09460>`_.

    This module contains two types of components:
        - the original FPN lateral convolution layer and fpn convolution
          layer, which are 1x1 conv + 3x3 conv
        - the dilated residual block

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        block_mid_channels (int): The number of middle block output channels.
        num_residual_blocks (int): The number of residual blocks.
    """

    def __init__(self, in_channels, out_channels, block_mid_channels,
                 num_residual_blocks):
        super(DilatedEncoder, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.block_mid_channels = block_mid_channels
        self.num_residual_blocks = num_residual_blocks
        self.block_dilations = [2, 4, 6, 8]
        self._init_layers()

    def _init_layers(self):
        self.lateral_conv = nn.Conv2d(
            self.in_channels, self.out_channels, kernel_size=1)
        self.lateral_norm = BatchNorm2d(self.out_channels)
        self.fpn_conv = nn.Conv2d(
            self.out_channels, self.out_channels, kernel_size=3, padding=1)
        self.fpn_norm = BatchNorm2d(self.out_channels)
        encoder_blocks = []
        for i in range(self.num_residual_blocks):
            dilation = self.block_dilations[i]
            encoder_blocks.append(
                Bottleneck(
                    self.out_channels,
                    self.block_mid_channels,
                    dilation=dilation))
        self.dilated_encoder_blocks = nn.Sequential(*encoder_blocks)

    def init_weights(self):
        caffe2_xavier_init(self.lateral_conv)
        caffe2_xavier_init(self.fpn_conv)
        for m in [self.lateral_norm, self.fpn_norm]:
            constant_init(m, 1)
        for m in self.dilated_encoder_blocks.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, mean=0, std=0.01)
            if is_norm(m):
                constant_init(m, 1)

    def forward(self, feature):
        out = self.lateral_norm(self.lateral_conv(feature[-1]))
        out = self.fpn_norm(self.fpn_conv(out))
        # Note the trailing comma: the single output is returned as a 1-tuple.
        return self.dilated_encoder_blocks(out),
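
For reference, the sketch below shows one way to exercise this neck directly; it is not taken from the repository. It assumes mmdetection and mmcv are installed and that the class is importable as mmdet.models.necks.DilatedEncoder (its usual export path); the channel and tensor sizes are arbitrary examples chosen to match the YOLOF defaults, where the input is the C5 feature map of a ResNet-50 backbone.

    import torch

    # Assumed import path; because of @NECKS.register_module(), configs can
    # also reference this class as dict(type='DilatedEncoder', ...).
    from mmdet.models.necks import DilatedEncoder

    neck = DilatedEncoder(
        in_channels=2048,          # e.g. C5 channels of a ResNet-50 backbone
        out_channels=512,
        block_mid_channels=128,
        num_residual_blocks=4)     # at most 4, since block_dilations = [2, 4, 6, 8]
    neck.init_weights()

    c5 = torch.randn(2, 2048, 25, 25)  # dummy batch of C5 feature maps
    outs = neck([c5])                  # forward() reads feature[-1] and returns a 1-tuple
    assert outs[0].shape == (2, 512, 25, 25)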
