You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

bfp.py 3.8 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch.nn.functional as F
  3. from mmcv.cnn import ConvModule
  4. from mmcv.cnn.bricks import NonLocal2d
  5. from mmcv.runner import BaseModule
  6. from ..builder import NECKS
  7. @NECKS.register_module()
  8. class BFP(BaseModule):
  9. """BFP (Balanced Feature Pyramids)
  10. BFP takes multi-level features as inputs and gather them into a single one,
  11. then refine the gathered feature and scatter the refined results to
  12. multi-level features. This module is used in Libra R-CNN (CVPR 2019), see
  13. the paper `Libra R-CNN: Towards Balanced Learning for Object Detection
  14. <https://arxiv.org/abs/1904.02701>`_ for details.
  15. Args:
  16. in_channels (int): Number of input channels (feature maps of all levels
  17. should have the same channels).
  18. num_levels (int): Number of input feature levels.
  19. conv_cfg (dict): The config dict for convolution layers.
  20. norm_cfg (dict): The config dict for normalization layers.
  21. refine_level (int): Index of integration and refine level of BSF in
  22. multi-level features from bottom to top.
  23. refine_type (str): Type of the refine op, currently support
  24. [None, 'conv', 'non_local'].
  25. init_cfg (dict or list[dict], optional): Initialization config dict.
  26. """
  27. def __init__(self,
  28. in_channels,
  29. num_levels,
  30. refine_level=2,
  31. refine_type=None,
  32. conv_cfg=None,
  33. norm_cfg=None,
  34. init_cfg=dict(
  35. type='Xavier', layer='Conv2d', distribution='uniform')):
  36. super(BFP, self).__init__(init_cfg)
  37. assert refine_type in [None, 'conv', 'non_local']
  38. self.in_channels = in_channels
  39. self.num_levels = num_levels
  40. self.conv_cfg = conv_cfg
  41. self.norm_cfg = norm_cfg
  42. self.refine_level = refine_level
  43. self.refine_type = refine_type
  44. assert 0 <= self.refine_level < self.num_levels
  45. if self.refine_type == 'conv':
  46. self.refine = ConvModule(
  47. self.in_channels,
  48. self.in_channels,
  49. 3,
  50. padding=1,
  51. conv_cfg=self.conv_cfg,
  52. norm_cfg=self.norm_cfg)
  53. elif self.refine_type == 'non_local':
  54. self.refine = NonLocal2d(
  55. self.in_channels,
  56. reduction=1,
  57. use_scale=False,
  58. conv_cfg=self.conv_cfg,
  59. norm_cfg=self.norm_cfg)
  60. def forward(self, inputs):
  61. """Forward function."""
  62. assert len(inputs) == self.num_levels
  63. # step 1: gather multi-level features by resize and average
  64. feats = []
  65. gather_size = inputs[self.refine_level].size()[2:]
  66. for i in range(self.num_levels):
  67. if i < self.refine_level:
  68. gathered = F.adaptive_max_pool2d(
  69. inputs[i], output_size=gather_size)
  70. else:
  71. gathered = F.interpolate(
  72. inputs[i], size=gather_size, mode='nearest')
  73. feats.append(gathered)
  74. bsf = sum(feats) / len(feats)
  75. # step 2: refine gathered features
  76. if self.refine_type is not None:
  77. bsf = self.refine(bsf)
  78. # step 3: scatter refined features to multi-levels by a residual path
  79. outs = []
  80. for i in range(self.num_levels):
  81. out_size = inputs[i].size()[2:]
  82. if i < self.refine_level:
  83. residual = F.interpolate(bsf, size=out_size, mode='nearest')
  84. else:
  85. residual = F.adaptive_max_pool2d(bsf, output_size=out_size)
  86. outs.append(residual + inputs[i])
  87. return tuple(outs)

No Description

Contributors (1)