
pafpn.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16

from ..builder import NECKS
from .fpn import FPN


@NECKS.register_module()
class PAFPN(FPN):
  10. """Path Aggregation Network for Instance Segmentation.
  11. This is an implementation of the `PAFPN in Path Aggregation Network
  12. <https://arxiv.org/abs/1803.01534>`_.
  13. Args:
  14. in_channels (List[int]): Number of input channels per scale.
  15. out_channels (int): Number of output channels (used at each scale)
  16. num_outs (int): Number of output scales.
  17. start_level (int): Index of the start input backbone level used to
  18. build the feature pyramid. Default: 0.
  19. end_level (int): Index of the end input backbone level (exclusive) to
  20. build the feature pyramid. Default: -1, which means the last level.
  21. add_extra_convs (bool | str): If bool, it decides whether to add conv
  22. layers on top of the original feature maps. Default to False.
  23. If True, it is equivalent to `add_extra_convs='on_input'`.
  24. If str, it specifies the source feature map of the extra convs.
  25. Only the following options are allowed
  26. - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
  27. - 'on_lateral': Last feature map after lateral convs.
  28. - 'on_output': The last output feature map after fpn convs.
  29. relu_before_extra_convs (bool): Whether to apply relu before the extra
  30. conv. Default: False.
  31. no_norm_on_lateral (bool): Whether to apply norm on lateral.
  32. Default: False.
  33. conv_cfg (dict): Config dict for convolution layer. Default: None.
  34. norm_cfg (dict): Config dict for normalization layer. Default: None.
  35. act_cfg (str): Config dict for activation layer in ConvModule.
  36. Default: None.
  37. init_cfg (dict or list[dict], optional): Initialization config dict.
  38. """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 relu_before_extra_convs=False,
                 no_norm_on_lateral=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super(PAFPN, self).__init__(
            in_channels,
            out_channels,
            num_outs,
            start_level,
            end_level,
            add_extra_convs,
            relu_before_extra_convs,
            no_norm_on_lateral,
            conv_cfg,
            norm_cfg,
            act_cfg,
            init_cfg=init_cfg)
        # add extra bottom-up pathway
        self.downsample_convs = nn.ModuleList()
        self.pafpn_convs = nn.ModuleList()
        for i in range(self.start_level + 1, self.backbone_end_level):
            d_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                stride=2,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)
            pafpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)
            self.downsample_convs.append(d_conv)
            self.pafpn_convs.append(pafpn_conv)

    @auto_fp16()
    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] += F.interpolate(
                laterals[i], size=prev_shape, mode='nearest')

        # build outputs
        # part 1: from original levels
        inter_outs = [
            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
        ]

        # part 2: add bottom-up path
        for i in range(0, used_backbone_levels - 1):
            inter_outs[i + 1] += self.downsample_convs[i](inter_outs[i])

        outs = []
        outs.append(inter_outs[0])
        outs.extend([
            self.pafpn_convs[i - 1](inter_outs[i])
            for i in range(1, used_backbone_levels)
        ])

        # part 3: add extra levels
        if self.num_outs > len(outs):
            # use max pool to get more levels on top of outputs
            # (e.g., Faster R-CNN, Mask R-CNN)
            if not self.add_extra_convs:
                for i in range(self.num_outs - used_backbone_levels):
                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
            # add conv layers on top of original feature maps (RetinaNet)
            else:
                if self.add_extra_convs == 'on_input':
                    orig = inputs[self.backbone_end_level - 1]
                    outs.append(self.fpn_convs[used_backbone_levels](orig))
                elif self.add_extra_convs == 'on_lateral':
                    outs.append(self.fpn_convs[used_backbone_levels](
                        laterals[-1]))
                elif self.add_extra_convs == 'on_output':
                    outs.append(self.fpn_convs[used_backbone_levels](outs[-1]))
                else:
                    raise NotImplementedError
                for i in range(used_backbone_levels + 1, self.num_outs):
                    if self.relu_before_extra_convs:
                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
                    else:
                        outs.append(self.fpn_convs[i](outs[-1]))
        return tuple(outs)
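
For reference, a minimal usage sketch (not part of the original file): it instantiates PAFPN directly and pushes dummy feature maps through the forward pass. The channel list and input size are illustrative assumptions for a ResNet-50-style backbone; in a full MMDetection config the same settings would instead appear as neck=dict(type='PAFPN', ...), which works because of the @NECKS.register_module() decorator above.

import torch
from mmdet.models.necks import PAFPN

# Assumed per-scale channels for a ResNet-50-style backbone (C2-C5).
in_channels = [256, 512, 1024, 2048]
neck = PAFPN(
    in_channels=in_channels,
    out_channels=256,
    num_outs=5,
    add_extra_convs='on_output')

# Dummy feature maps at strides 4/8/16/32 for a 224x224 input.
feats = [
    torch.rand(1, c, 224 // s, 224 // s)
    for c, s in zip(in_channels, [4, 8, 16, 32])
]
outs = neck(feats)
print([o.shape for o in outs])
# Expected: 5 levels with 256 channels each; the 5th comes from the extra
# stride-2 conv applied to the last PAFPN output ('on_output').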
