You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

yolox_pafpn.py 5.6 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import math
  3. import torch
  4. import torch.nn as nn
  5. from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
  6. from mmcv.runner import BaseModule
  7. from ..builder import NECKS
  8. from ..utils import CSPLayer
  9. @NECKS.register_module()
  10. class YOLOXPAFPN(BaseModule):
  11. """Path Aggregation Network used in YOLOX.
  12. Args:
  13. in_channels (List[int]): Number of input channels per scale.
  14. out_channels (int): Number of output channels (used at each scale)
  15. num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 3
  16. use_depthwise (bool): Whether to depthwise separable convolution in
  17. blocks. Default: False
  18. upsample_cfg (dict): Config dict for interpolate layer.
  19. Default: `dict(scale_factor=2, mode='nearest')`
  20. conv_cfg (dict, optional): Config dict for convolution layer.
  21. Default: None, which means using conv2d.
  22. norm_cfg (dict): Config dict for normalization layer.
  23. Default: dict(type='BN')
  24. act_cfg (dict): Config dict for activation layer.
  25. Default: dict(type='Swish')
  26. init_cfg (dict or list[dict], optional): Initialization config dict.
  27. Default: None.
  28. """
  29. def __init__(self,
  30. in_channels,
  31. out_channels,
  32. num_csp_blocks=3,
  33. use_depthwise=False,
  34. upsample_cfg=dict(scale_factor=2, mode='nearest'),
  35. conv_cfg=None,
  36. norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
  37. act_cfg=dict(type='Swish'),
  38. init_cfg=dict(
  39. type='Kaiming',
  40. layer='Conv2d',
  41. a=math.sqrt(5),
  42. distribution='uniform',
  43. mode='fan_in',
  44. nonlinearity='leaky_relu')):
  45. super(YOLOXPAFPN, self).__init__(init_cfg)
  46. self.in_channels = in_channels
  47. self.out_channels = out_channels
  48. conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
  49. # build top-down blocks
  50. self.upsample = nn.Upsample(**upsample_cfg)
  51. self.reduce_layers = nn.ModuleList()
  52. self.top_down_blocks = nn.ModuleList()
  53. for idx in range(len(in_channels) - 1, 0, -1):
  54. self.reduce_layers.append(
  55. ConvModule(
  56. in_channels[idx],
  57. in_channels[idx - 1],
  58. 1,
  59. conv_cfg=conv_cfg,
  60. norm_cfg=norm_cfg,
  61. act_cfg=act_cfg))
  62. self.top_down_blocks.append(
  63. CSPLayer(
  64. in_channels[idx - 1] * 2,
  65. in_channels[idx - 1],
  66. num_blocks=num_csp_blocks,
  67. add_identity=False,
  68. use_depthwise=use_depthwise,
  69. conv_cfg=conv_cfg,
  70. norm_cfg=norm_cfg,
  71. act_cfg=act_cfg))
  72. # build bottom-up blocks
  73. self.downsamples = nn.ModuleList()
  74. self.bottom_up_blocks = nn.ModuleList()
  75. for idx in range(len(in_channels) - 1):
  76. self.downsamples.append(
  77. conv(
  78. in_channels[idx],
  79. in_channels[idx],
  80. 3,
  81. stride=2,
  82. padding=1,
  83. conv_cfg=conv_cfg,
  84. norm_cfg=norm_cfg,
  85. act_cfg=act_cfg))
  86. self.bottom_up_blocks.append(
  87. CSPLayer(
  88. in_channels[idx] * 2,
  89. in_channels[idx + 1],
  90. num_blocks=num_csp_blocks,
  91. add_identity=False,
  92. use_depthwise=use_depthwise,
  93. conv_cfg=conv_cfg,
  94. norm_cfg=norm_cfg,
  95. act_cfg=act_cfg))
  96. self.out_convs = nn.ModuleList()
  97. for i in range(len(in_channels)):
  98. self.out_convs.append(
  99. ConvModule(
  100. in_channels[i],
  101. out_channels,
  102. 1,
  103. conv_cfg=conv_cfg,
  104. norm_cfg=norm_cfg,
  105. act_cfg=act_cfg))
  106. def forward(self, inputs):
  107. """
  108. Args:
  109. inputs (tuple[Tensor]): input features.
  110. Returns:
  111. tuple[Tensor]: YOLOXPAFPN features.
  112. """
  113. assert len(inputs) == len(self.in_channels)
  114. # top-down path
  115. inner_outs = [inputs[-1]]
  116. for idx in range(len(self.in_channels) - 1, 0, -1):
  117. feat_heigh = inner_outs[0]
  118. feat_low = inputs[idx - 1]
  119. feat_heigh = self.reduce_layers[len(self.in_channels) - 1 - idx](
  120. feat_heigh)
  121. inner_outs[0] = feat_heigh
  122. upsample_feat = self.upsample(feat_heigh)
  123. inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
  124. torch.cat([upsample_feat, feat_low], 1))
  125. inner_outs.insert(0, inner_out)
  126. # bottom-up path
  127. outs = [inner_outs[0]]
  128. for idx in range(len(self.in_channels) - 1):
  129. feat_low = outs[-1]
  130. feat_height = inner_outs[idx + 1]
  131. downsample_feat = self.downsamples[idx](feat_low)
  132. out = self.bottom_up_blocks[idx](
  133. torch.cat([downsample_feat, feat_height], 1))
  134. outs.append(out)
  135. # out convs
  136. for idx, conv in enumerate(self.out_convs):
  137. outs[idx] = conv(outs[idx])
  138. return tuple(outs)

No Description

Contributors (2)