You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

nas_fpn.py 6.6 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch.nn as nn
  3. from mmcv.cnn import ConvModule
  4. from mmcv.ops.merge_cells import GlobalPoolingCell, SumCell
  5. from mmcv.runner import BaseModule, ModuleList
  6. from ..builder import NECKS
  7. @NECKS.register_module()
  8. class NASFPN(BaseModule):
  9. """NAS-FPN.
  10. Implementation of `NAS-FPN: Learning Scalable Feature Pyramid Architecture
  11. for Object Detection <https://arxiv.org/abs/1904.07392>`_
  12. Args:
  13. in_channels (List[int]): Number of input channels per scale.
  14. out_channels (int): Number of output channels (used at each scale)
  15. num_outs (int): Number of output scales.
  16. stack_times (int): The number of times the pyramid architecture will
  17. be stacked.
  18. start_level (int): Index of the start input backbone level used to
  19. build the feature pyramid. Default: 0.
  20. end_level (int): Index of the end input backbone level (exclusive) to
  21. build the feature pyramid. Default: -1, which means the last level.
  22. add_extra_convs (bool): It decides whether to add conv
  23. layers on top of the original feature maps. Default to False.
  24. If True, its actual mode is specified by `extra_convs_on_inputs`.
  25. init_cfg (dict or list[dict], optional): Initialization config dict.
  26. """
  27. def __init__(self,
  28. in_channels,
  29. out_channels,
  30. num_outs,
  31. stack_times,
  32. start_level=0,
  33. end_level=-1,
  34. add_extra_convs=False,
  35. norm_cfg=None,
  36. init_cfg=dict(type='Caffe2Xavier', layer='Conv2d')):
  37. super(NASFPN, self).__init__(init_cfg)
  38. assert isinstance(in_channels, list)
  39. self.in_channels = in_channels
  40. self.out_channels = out_channels
  41. self.num_ins = len(in_channels) # num of input feature levels
  42. self.num_outs = num_outs # num of output feature levels
  43. self.stack_times = stack_times
  44. self.norm_cfg = norm_cfg
  45. if end_level == -1:
  46. self.backbone_end_level = self.num_ins
  47. assert num_outs >= self.num_ins - start_level
  48. else:
  49. # if end_level < inputs, no extra level is allowed
  50. self.backbone_end_level = end_level
  51. assert end_level <= len(in_channels)
  52. assert num_outs == end_level - start_level
  53. self.start_level = start_level
  54. self.end_level = end_level
  55. self.add_extra_convs = add_extra_convs
  56. # add lateral connections
  57. self.lateral_convs = nn.ModuleList()
  58. for i in range(self.start_level, self.backbone_end_level):
  59. l_conv = ConvModule(
  60. in_channels[i],
  61. out_channels,
  62. 1,
  63. norm_cfg=norm_cfg,
  64. act_cfg=None)
  65. self.lateral_convs.append(l_conv)
  66. # add extra downsample layers (stride-2 pooling or conv)
  67. extra_levels = num_outs - self.backbone_end_level + self.start_level
  68. self.extra_downsamples = nn.ModuleList()
  69. for i in range(extra_levels):
  70. extra_conv = ConvModule(
  71. out_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
  72. self.extra_downsamples.append(
  73. nn.Sequential(extra_conv, nn.MaxPool2d(2, 2)))
  74. # add NAS FPN connections
  75. self.fpn_stages = ModuleList()
  76. for _ in range(self.stack_times):
  77. stage = nn.ModuleDict()
  78. # gp(p6, p4) -> p4_1
  79. stage['gp_64_4'] = GlobalPoolingCell(
  80. in_channels=out_channels,
  81. out_channels=out_channels,
  82. out_norm_cfg=norm_cfg)
  83. # sum(p4_1, p4) -> p4_2
  84. stage['sum_44_4'] = SumCell(
  85. in_channels=out_channels,
  86. out_channels=out_channels,
  87. out_norm_cfg=norm_cfg)
  88. # sum(p4_2, p3) -> p3_out
  89. stage['sum_43_3'] = SumCell(
  90. in_channels=out_channels,
  91. out_channels=out_channels,
  92. out_norm_cfg=norm_cfg)
  93. # sum(p3_out, p4_2) -> p4_out
  94. stage['sum_34_4'] = SumCell(
  95. in_channels=out_channels,
  96. out_channels=out_channels,
  97. out_norm_cfg=norm_cfg)
  98. # sum(p5, gp(p4_out, p3_out)) -> p5_out
  99. stage['gp_43_5'] = GlobalPoolingCell(with_out_conv=False)
  100. stage['sum_55_5'] = SumCell(
  101. in_channels=out_channels,
  102. out_channels=out_channels,
  103. out_norm_cfg=norm_cfg)
  104. # sum(p7, gp(p5_out, p4_2)) -> p7_out
  105. stage['gp_54_7'] = GlobalPoolingCell(with_out_conv=False)
  106. stage['sum_77_7'] = SumCell(
  107. in_channels=out_channels,
  108. out_channels=out_channels,
  109. out_norm_cfg=norm_cfg)
  110. # gp(p7_out, p5_out) -> p6_out
  111. stage['gp_75_6'] = GlobalPoolingCell(
  112. in_channels=out_channels,
  113. out_channels=out_channels,
  114. out_norm_cfg=norm_cfg)
  115. self.fpn_stages.append(stage)
  116. def forward(self, inputs):
  117. """Forward function."""
  118. # build P3-P5
  119. feats = [
  120. lateral_conv(inputs[i + self.start_level])
  121. for i, lateral_conv in enumerate(self.lateral_convs)
  122. ]
  123. # build P6-P7 on top of P5
  124. for downsample in self.extra_downsamples:
  125. feats.append(downsample(feats[-1]))
  126. p3, p4, p5, p6, p7 = feats
  127. for stage in self.fpn_stages:
  128. # gp(p6, p4) -> p4_1
  129. p4_1 = stage['gp_64_4'](p6, p4, out_size=p4.shape[-2:])
  130. # sum(p4_1, p4) -> p4_2
  131. p4_2 = stage['sum_44_4'](p4_1, p4, out_size=p4.shape[-2:])
  132. # sum(p4_2, p3) -> p3_out
  133. p3 = stage['sum_43_3'](p4_2, p3, out_size=p3.shape[-2:])
  134. # sum(p3_out, p4_2) -> p4_out
  135. p4 = stage['sum_34_4'](p3, p4_2, out_size=p4.shape[-2:])
  136. # sum(p5, gp(p4_out, p3_out)) -> p5_out
  137. p5_tmp = stage['gp_43_5'](p4, p3, out_size=p5.shape[-2:])
  138. p5 = stage['sum_55_5'](p5, p5_tmp, out_size=p5.shape[-2:])
  139. # sum(p7, gp(p5_out, p4_2)) -> p7_out
  140. p7_tmp = stage['gp_54_7'](p5, p4_2, out_size=p7.shape[-2:])
  141. p7 = stage['sum_77_7'](p7, p7_tmp, out_size=p7.shape[-2:])
  142. # gp(p7_out, p5_out) -> p6_out
  143. p6 = stage['gp_75_6'](p7, p5, out_size=p6.shape[-2:])
  144. return p3, p4, p5, p6, p7

No Description

Contributors (2)