# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16

from ..builder import NECKS


@NECKS.register_module()
class FPN(BaseModule):
    r"""Feature Pyramid Network.

    This is an implementation of paper `Feature Pyramid Networks for Object
    Detection <https://arxiv.org/abs/1612.03144>`_.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        add_extra_convs (bool | str): If bool, it decides whether to add conv
            layers on top of the original feature maps. Default: False.
            If True, it is equivalent to `add_extra_convs='on_input'`.
            If str, it specifies the source feature map of the extra convs.
            Only the following options are allowed:

            - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
            - 'on_lateral': Last feature map after lateral convs.
            - 'on_output': The last output feature map after fpn convs.
        relu_before_extra_convs (bool): Whether to apply relu before the extra
            conv. Default: False.
        no_norm_on_lateral (bool): Whether to apply norm on lateral.
            Default: False.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: None.
        upsample_cfg (dict): Config dict for interpolate layer.
            Default: `dict(mode='nearest')`.
        init_cfg (dict or list[dict], optional): Initialization config dict.

    Example:
        >>> import torch
        >>> in_channels = [2, 3, 5, 7]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = FPN(in_channels, 11, len(in_channels)).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 11, 340, 340])
        outputs[1].shape = torch.Size([1, 11, 170, 170])
        outputs[2].shape = torch.Size([1, 11, 84, 84])
        outputs[3].shape = torch.Size([1, 11, 43, 43])
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 relu_before_extra_convs=False,
                 no_norm_on_lateral=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 upsample_cfg=dict(mode='nearest'),
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super(FPN, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.relu_before_extra_convs = relu_before_extra_convs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.fp16_enabled = False
        self.upsample_cfg = upsample_cfg.copy()

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs
        assert isinstance(add_extra_convs, (str, bool))
        if isinstance(add_extra_convs, str):
            # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
            assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
        elif add_extra_convs:  # True
            self.add_extra_convs = 'on_input'

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()

        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                act_cfg=act_cfg,
                inplace=False)
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)

            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # add extra conv layers (e.g., RetinaNet)
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        if self.add_extra_convs and extra_levels >= 1:
            for i in range(extra_levels):
                if i == 0 and self.add_extra_convs == 'on_input':
                    in_channels = self.in_channels[self.backbone_end_level - 1]
                else:
                    in_channels = out_channels
                extra_fpn_conv = ConvModule(
                    in_channels,
                    out_channels,
                    3,
                    stride=2,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    inplace=False)
                self.fpn_convs.append(extra_fpn_conv)

    @auto_fp16()
    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            # In some cases, fixing `scale_factor` (e.g. 2) is preferred, but
            # it cannot co-exist with `size` in `F.interpolate`.
            if 'scale_factor' in self.upsample_cfg:
                laterals[i - 1] += F.interpolate(laterals[i],
                                                 **self.upsample_cfg)
            else:
                prev_shape = laterals[i - 1].shape[2:]
                laterals[i - 1] += F.interpolate(
                    laterals[i], size=prev_shape, **self.upsample_cfg)

        # build outputs
        # part 1: from original levels
        outs = [
            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
        ]
        # part 2: add extra levels
        if self.num_outs > len(outs):
            # use max pool to get more levels on top of outputs
            # (e.g., Faster R-CNN, Mask R-CNN)
            if not self.add_extra_convs:
                for i in range(self.num_outs - used_backbone_levels):
                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
            # add conv layers on top of original feature maps (RetinaNet)
            else:
                if self.add_extra_convs == 'on_input':
                    extra_source = inputs[self.backbone_end_level - 1]
                elif self.add_extra_convs == 'on_lateral':
                    extra_source = laterals[-1]
                elif self.add_extra_convs == 'on_output':
                    extra_source = outs[-1]
                else:
                    raise NotImplementedError
                outs.append(self.fpn_convs[used_backbone_levels](extra_source))
                for i in range(used_backbone_levels + 1, self.num_outs):
                    if self.relu_before_extra_convs:
                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
                    else:
                        outs.append(self.fpn_convs[i](outs[-1]))
        return tuple(outs)
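

# Usage sketch (an illustrative addition, not part of the original mmdetection
# file): it demonstrates the RetinaNet-style configuration described in the
# docstring, where `add_extra_convs='on_input'` builds the extra pyramid
# levels from the last backbone feature map. The channel counts and spatial
# sizes below are made-up example values, and running the block assumes the
# relative imports above resolve (i.e. the module sits inside an installed
# `mmdet` package).
if __name__ == '__main__':
    import torch

    neck = FPN(
        in_channels=[256, 512, 1024, 2048],  # e.g. ResNet-50 C2-C5 channels
        out_channels=256,
        num_outs=5,
        start_level=1,  # skip C2 and build P3-P7
        add_extra_convs='on_input').eval()
    feats = [
        torch.rand(1, c, s, s)
        for c, s in zip([256, 512, 1024, 2048], [152, 76, 38, 19])
    ]
    outs = neck(feats)
    # Expect five 256-channel maps: 76, 38 and 19 from the lateral/top-down
    # path, then 10 and 5 from the two stride-2 extra convs.
    for i, o in enumerate(outs):
        print(f'outs[{i}].shape = {o.shape}')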
