You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from mmcv.cnn import constant_init, xavier_init
  6. from mmcv.runner import BaseModule, ModuleList
  7. from ..builder import NECKS, build_backbone
  8. from .fpn import FPN
  9. class ASPP(BaseModule):
  10. """ASPP (Atrous Spatial Pyramid Pooling)
  11. This is an implementation of the ASPP module used in DetectoRS
  12. (https://arxiv.org/pdf/2006.02334.pdf)
  13. Args:
  14. in_channels (int): Number of input channels.
  15. out_channels (int): Number of channels produced by this module
  16. dilations (tuple[int]): Dilations of the four branches.
  17. Default: (1, 3, 6, 1)
  18. init_cfg (dict or list[dict], optional): Initialization config dict.
  19. """
  20. def __init__(self,
  21. in_channels,
  22. out_channels,
  23. dilations=(1, 3, 6, 1),
  24. init_cfg=dict(type='Kaiming', layer='Conv2d')):
  25. super().__init__(init_cfg)
  26. assert dilations[-1] == 1
  27. self.aspp = nn.ModuleList()
  28. for dilation in dilations:
  29. kernel_size = 3 if dilation > 1 else 1
  30. padding = dilation if dilation > 1 else 0
  31. conv = nn.Conv2d(
  32. in_channels,
  33. out_channels,
  34. kernel_size=kernel_size,
  35. stride=1,
  36. dilation=dilation,
  37. padding=padding,
  38. bias=True)
  39. self.aspp.append(conv)
  40. self.gap = nn.AdaptiveAvgPool2d(1)
  41. def forward(self, x):
  42. avg_x = self.gap(x)
  43. out = []
  44. for aspp_idx in range(len(self.aspp)):
  45. inp = avg_x if (aspp_idx == len(self.aspp) - 1) else x
  46. out.append(F.relu_(self.aspp[aspp_idx](inp)))
  47. out[-1] = out[-1].expand_as(out[-2])
  48. out = torch.cat(out, dim=1)
  49. return out
  50. @NECKS.register_module()
  51. class RFP(FPN):
  52. """RFP (Recursive Feature Pyramid)
  53. This is an implementation of RFP in `DetectoRS
  54. <https://arxiv.org/pdf/2006.02334.pdf>`_. Different from standard FPN, the
  55. input of RFP should be multi level features along with origin input image
  56. of backbone.
  57. Args:
  58. rfp_steps (int): Number of unrolled steps of RFP.
  59. rfp_backbone (dict): Configuration of the backbone for RFP.
  60. aspp_out_channels (int): Number of output channels of ASPP module.
  61. aspp_dilations (tuple[int]): Dilation rates of four branches.
  62. Default: (1, 3, 6, 1)
  63. init_cfg (dict or list[dict], optional): Initialization config dict.
  64. Default: None
  65. """
  66. def __init__(self,
  67. rfp_steps,
  68. rfp_backbone,
  69. aspp_out_channels,
  70. aspp_dilations=(1, 3, 6, 1),
  71. init_cfg=None,
  72. **kwargs):
  73. assert init_cfg is None, 'To prevent abnormal initialization ' \
  74. 'behavior, init_cfg is not allowed to be set'
  75. super().__init__(init_cfg=init_cfg, **kwargs)
  76. self.rfp_steps = rfp_steps
  77. # Be careful! Pretrained weights cannot be loaded when use
  78. # nn.ModuleList
  79. self.rfp_modules = ModuleList()
  80. for rfp_idx in range(1, rfp_steps):
  81. rfp_module = build_backbone(rfp_backbone)
  82. self.rfp_modules.append(rfp_module)
  83. self.rfp_aspp = ASPP(self.out_channels, aspp_out_channels,
  84. aspp_dilations)
  85. self.rfp_weight = nn.Conv2d(
  86. self.out_channels,
  87. 1,
  88. kernel_size=1,
  89. stride=1,
  90. padding=0,
  91. bias=True)
  92. def init_weights(self):
  93. # Avoid using super().init_weights(), which may alter the default
  94. # initialization of the modules in self.rfp_modules that have missing
  95. # keys in the pretrained checkpoint.
  96. for convs in [self.lateral_convs, self.fpn_convs]:
  97. for m in convs.modules():
  98. if isinstance(m, nn.Conv2d):
  99. xavier_init(m, distribution='uniform')
  100. for rfp_idx in range(self.rfp_steps - 1):
  101. self.rfp_modules[rfp_idx].init_weights()
  102. constant_init(self.rfp_weight, 0)
  103. def forward(self, inputs):
  104. inputs = list(inputs)
  105. assert len(inputs) == len(self.in_channels) + 1 # +1 for input image
  106. img = inputs.pop(0)
  107. # FPN forward
  108. x = super().forward(tuple(inputs))
  109. for rfp_idx in range(self.rfp_steps - 1):
  110. rfp_feats = [x[0]] + list(
  111. self.rfp_aspp(x[i]) for i in range(1, len(x)))
  112. x_idx = self.rfp_modules[rfp_idx].rfp_forward(img, rfp_feats)
  113. # FPN forward
  114. x_idx = super().forward(x_idx)
  115. x_new = []
  116. for ft_idx in range(len(x_idx)):
  117. add_weight = torch.sigmoid(self.rfp_weight(x_idx[ft_idx]))
  118. x_new.append(add_weight * x_idx[ft_idx] +
  119. (1 - add_weight) * x[ft_idx])
  120. x = x_new
  121. return x

No Description

Contributors (3)