You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

csp_darknet.py 10 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import math
  3. import torch
  4. import torch.nn as nn
  5. from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
  6. from mmcv.runner import BaseModule
  7. from torch.nn.modules.batchnorm import _BatchNorm
  8. from ..builder import BACKBONES
  9. from ..utils import CSPLayer
  10. class Focus(nn.Module):
  11. """Focus width and height information into channel space.
  12. Args:
  13. in_channels (int): The input channels of this Module.
  14. out_channels (int): The output channels of this Module.
  15. kernel_size (int): The kernel size of the convolution. Default: 1
  16. stride (int): The stride of the convolution. Default: 1
  17. conv_cfg (dict): Config dict for convolution layer. Default: None,
  18. which means using conv2d.
  19. norm_cfg (dict): Config dict for normalization layer.
  20. Default: dict(type='BN', momentum=0.03, eps=0.001).
  21. act_cfg (dict): Config dict for activation layer.
  22. Default: dict(type='Swish').
  23. """
  24. def __init__(self,
  25. in_channels,
  26. out_channels,
  27. kernel_size=1,
  28. stride=1,
  29. conv_cfg=None,
  30. norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
  31. act_cfg=dict(type='Swish')):
  32. super().__init__()
  33. self.conv = ConvModule(
  34. in_channels * 4,
  35. out_channels,
  36. kernel_size,
  37. stride,
  38. padding=(kernel_size - 1) // 2,
  39. conv_cfg=conv_cfg,
  40. norm_cfg=norm_cfg,
  41. act_cfg=act_cfg)
  42. def forward(self, x):
  43. # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
  44. patch_top_left = x[..., ::2, ::2]
  45. patch_top_right = x[..., ::2, 1::2]
  46. patch_bot_left = x[..., 1::2, ::2]
  47. patch_bot_right = x[..., 1::2, 1::2]
  48. x = torch.cat(
  49. (
  50. patch_top_left,
  51. patch_bot_left,
  52. patch_top_right,
  53. patch_bot_right,
  54. ),
  55. dim=1,
  56. )
  57. return self.conv(x)
  58. class SPPBottleneck(BaseModule):
  59. """Spatial pyramid pooling layer used in YOLOv3-SPP.
  60. Args:
  61. in_channels (int): The input channels of this Module.
  62. out_channels (int): The output channels of this Module.
  63. kernel_sizes (tuple[int]): Sequential of kernel sizes of pooling
  64. layers. Default: (5, 9, 13).
  65. conv_cfg (dict): Config dict for convolution layer. Default: None,
  66. which means using conv2d.
  67. norm_cfg (dict): Config dict for normalization layer.
  68. Default: dict(type='BN').
  69. act_cfg (dict): Config dict for activation layer.
  70. Default: dict(type='Swish').
  71. init_cfg (dict or list[dict], optional): Initialization config dict.
  72. Default: None.
  73. """
  74. def __init__(self,
  75. in_channels,
  76. out_channels,
  77. kernel_sizes=(5, 9, 13),
  78. conv_cfg=None,
  79. norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
  80. act_cfg=dict(type='Swish'),
  81. init_cfg=None):
  82. super().__init__(init_cfg)
  83. mid_channels = in_channels // 2
  84. self.conv1 = ConvModule(
  85. in_channels,
  86. mid_channels,
  87. 1,
  88. stride=1,
  89. conv_cfg=conv_cfg,
  90. norm_cfg=norm_cfg,
  91. act_cfg=act_cfg)
  92. self.poolings = nn.ModuleList([
  93. nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
  94. for ks in kernel_sizes
  95. ])
  96. conv2_channels = mid_channels * (len(kernel_sizes) + 1)
  97. self.conv2 = ConvModule(
  98. conv2_channels,
  99. out_channels,
  100. 1,
  101. conv_cfg=conv_cfg,
  102. norm_cfg=norm_cfg,
  103. act_cfg=act_cfg)
  104. def forward(self, x):
  105. x = self.conv1(x)
  106. x = torch.cat([x] + [pooling(x) for pooling in self.poolings], dim=1)
  107. x = self.conv2(x)
  108. return x
  109. @BACKBONES.register_module()
  110. class CSPDarknet(BaseModule):
  111. """CSP-Darknet backbone used in YOLOv5 and YOLOX.
  112. Args:
  113. arch (str): Architecture of CSP-Darknet, from {P5, P6}.
  114. Default: P5.
  115. deepen_factor (float): Depth multiplier, multiply number of
  116. channels in each layer by this amount. Default: 1.0.
  117. widen_factor (float): Width multiplier, multiply number of
  118. blocks in CSP layer by this amount. Default: 1.0.
  119. out_indices (Sequence[int]): Output from which stages.
  120. Default: (2, 3, 4).
  121. frozen_stages (int): Stages to be frozen (stop grad and set eval
  122. mode). -1 means not freezing any parameters. Default: -1.
  123. use_depthwise (bool): Whether to use depthwise separable convolution.
  124. Default: False.
  125. arch_ovewrite(list): Overwrite default arch settings. Default: None.
  126. spp_kernal_sizes: (tuple[int]): Sequential of kernel sizes of SPP
  127. layers. Default: (5, 9, 13).
  128. conv_cfg (dict): Config dict for convolution layer. Default: None.
  129. norm_cfg (dict): Dictionary to construct and config norm layer.
  130. Default: dict(type='BN', requires_grad=True).
  131. act_cfg (dict): Config dict for activation layer.
  132. Default: dict(type='LeakyReLU', negative_slope=0.1).
  133. norm_eval (bool): Whether to set norm layers to eval mode, namely,
  134. freeze running stats (mean and var). Note: Effect on Batch Norm
  135. and its variants only.
  136. init_cfg (dict or list[dict], optional): Initialization config dict.
  137. Default: None.
  138. Example:
  139. >>> from mmdet.models import CSPDarknet
  140. >>> import torch
  141. >>> self = CSPDarknet(depth=53)
  142. >>> self.eval()
  143. >>> inputs = torch.rand(1, 3, 416, 416)
  144. >>> level_outputs = self.forward(inputs)
  145. >>> for level_out in level_outputs:
  146. ... print(tuple(level_out.shape))
  147. ...
  148. (1, 256, 52, 52)
  149. (1, 512, 26, 26)
  150. (1, 1024, 13, 13)
  151. """
  152. # From left to right:
  153. # in_channels, out_channels, num_blocks, add_identity, use_spp
  154. arch_settings = {
  155. 'P5': [[64, 128, 3, True, False], [128, 256, 9, True, False],
  156. [256, 512, 9, True, False], [512, 1024, 3, False, True]],
  157. 'P6': [[64, 128, 3, True, False], [128, 256, 9, True, False],
  158. [256, 512, 9, True, False], [512, 768, 3, True, False],
  159. [768, 1024, 3, False, True]]
  160. }
  161. def __init__(self,
  162. arch='P5',
  163. deepen_factor=1.0,
  164. widen_factor=1.0,
  165. out_indices=(2, 3, 4),
  166. frozen_stages=-1,
  167. use_depthwise=False,
  168. arch_ovewrite=None,
  169. spp_kernal_sizes=(5, 9, 13),
  170. conv_cfg=None,
  171. norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
  172. act_cfg=dict(type='Swish'),
  173. norm_eval=False,
  174. init_cfg=dict(
  175. type='Kaiming',
  176. layer='Conv2d',
  177. a=math.sqrt(5),
  178. distribution='uniform',
  179. mode='fan_in',
  180. nonlinearity='leaky_relu')):
  181. super().__init__(init_cfg)
  182. arch_setting = self.arch_settings[arch]
  183. if arch_ovewrite:
  184. arch_setting = arch_ovewrite
  185. assert set(out_indices).issubset(
  186. i for i in range(len(arch_setting) + 1))
  187. if frozen_stages not in range(-1, len(arch_setting) + 1):
  188. raise ValueError('frozen_stages must be in range(-1, '
  189. 'len(arch_setting) + 1). But received '
  190. f'{frozen_stages}')
  191. self.out_indices = out_indices
  192. self.frozen_stages = frozen_stages
  193. self.use_depthwise = use_depthwise
  194. self.norm_eval = norm_eval
  195. conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
  196. self.stem = Focus(
  197. 3,
  198. int(arch_setting[0][0] * widen_factor),
  199. kernel_size=3,
  200. conv_cfg=conv_cfg,
  201. norm_cfg=norm_cfg,
  202. act_cfg=act_cfg)
  203. self.layers = ['stem']
  204. for i, (in_channels, out_channels, num_blocks, add_identity,
  205. use_spp) in enumerate(arch_setting):
  206. in_channels = int(in_channels * widen_factor)
  207. out_channels = int(out_channels * widen_factor)
  208. num_blocks = max(round(num_blocks * deepen_factor), 1)
  209. stage = []
  210. conv_layer = conv(
  211. in_channels,
  212. out_channels,
  213. 3,
  214. stride=2,
  215. padding=1,
  216. conv_cfg=conv_cfg,
  217. norm_cfg=norm_cfg,
  218. act_cfg=act_cfg)
  219. stage.append(conv_layer)
  220. if use_spp:
  221. spp = SPPBottleneck(
  222. out_channels,
  223. out_channels,
  224. kernel_sizes=spp_kernal_sizes,
  225. conv_cfg=conv_cfg,
  226. norm_cfg=norm_cfg,
  227. act_cfg=act_cfg)
  228. stage.append(spp)
  229. csp_layer = CSPLayer(
  230. out_channels,
  231. out_channels,
  232. num_blocks=num_blocks,
  233. add_identity=add_identity,
  234. use_depthwise=use_depthwise,
  235. conv_cfg=conv_cfg,
  236. norm_cfg=norm_cfg,
  237. act_cfg=act_cfg)
  238. stage.append(csp_layer)
  239. self.add_module(f'stage{i + 1}', nn.Sequential(*stage))
  240. self.layers.append(f'stage{i + 1}')
  241. def _freeze_stages(self):
  242. if self.frozen_stages >= 0:
  243. for i in range(self.frozen_stages + 1):
  244. m = getattr(self, self.layers[i])
  245. m.eval()
  246. for param in m.parameters():
  247. param.requires_grad = False
  248. def train(self, mode=True):
  249. super(CSPDarknet, self).train(mode)
  250. self._freeze_stages()
  251. if mode and self.norm_eval:
  252. for m in self.modules():
  253. if isinstance(m, _BatchNorm):
  254. m.eval()
  255. def forward(self, x):
  256. outs = []
  257. for i, layer_name in enumerate(self.layers):
  258. layer = getattr(self, layer_name)
  259. x = layer(x)
  260. if i in self.out_indices:
  261. outs.append(x)
  262. return tuple(outs)

No Description

Contributors (3)