You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_necks.py 13 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from torch.nn.modules.batchnorm import _BatchNorm
  5. from mmdet.models.necks import (FPN, YOLOXPAFPN, ChannelMapper, CTResNetNeck,
  6. DilatedEncoder, SSDNeck, YOLOV3Neck)
  7. def test_fpn():
  8. """Tests fpn."""
  9. s = 64
  10. in_channels = [8, 16, 32, 64]
  11. feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8]
  12. out_channels = 8
  13. # `num_outs` is not equal to len(in_channels) - start_level
  14. with pytest.raises(AssertionError):
  15. FPN(in_channels=in_channels,
  16. out_channels=out_channels,
  17. start_level=1,
  18. num_outs=2)
  19. # `end_level` is larger than len(in_channels) - 1
  20. with pytest.raises(AssertionError):
  21. FPN(in_channels=in_channels,
  22. out_channels=out_channels,
  23. start_level=1,
  24. end_level=4,
  25. num_outs=2)
  26. # `num_outs` is not equal to end_level - start_level
  27. with pytest.raises(AssertionError):
  28. FPN(in_channels=in_channels,
  29. out_channels=out_channels,
  30. start_level=1,
  31. end_level=3,
  32. num_outs=1)
  33. # Invalid `add_extra_convs` option
  34. with pytest.raises(AssertionError):
  35. FPN(in_channels=in_channels,
  36. out_channels=out_channels,
  37. start_level=1,
  38. add_extra_convs='on_xxx',
  39. num_outs=5)
  40. fpn_model = FPN(
  41. in_channels=in_channels,
  42. out_channels=out_channels,
  43. start_level=1,
  44. add_extra_convs=True,
  45. num_outs=5)
  46. # FPN expects a multiple levels of features per image
  47. feats = [
  48. torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
  49. for i in range(len(in_channels))
  50. ]
  51. outs = fpn_model(feats)
  52. assert fpn_model.add_extra_convs == 'on_input'
  53. assert len(outs) == fpn_model.num_outs
  54. for i in range(fpn_model.num_outs):
  55. outs[i].shape[1] == out_channels
  56. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  57. # Tests for fpn with no extra convs (pooling is used instead)
  58. fpn_model = FPN(
  59. in_channels=in_channels,
  60. out_channels=out_channels,
  61. start_level=1,
  62. add_extra_convs=False,
  63. num_outs=5)
  64. outs = fpn_model(feats)
  65. assert len(outs) == fpn_model.num_outs
  66. assert not fpn_model.add_extra_convs
  67. for i in range(fpn_model.num_outs):
  68. outs[i].shape[1] == out_channels
  69. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  70. # Tests for fpn with lateral bns
  71. fpn_model = FPN(
  72. in_channels=in_channels,
  73. out_channels=out_channels,
  74. start_level=1,
  75. add_extra_convs=True,
  76. no_norm_on_lateral=False,
  77. norm_cfg=dict(type='BN', requires_grad=True),
  78. num_outs=5)
  79. outs = fpn_model(feats)
  80. assert len(outs) == fpn_model.num_outs
  81. assert fpn_model.add_extra_convs == 'on_input'
  82. for i in range(fpn_model.num_outs):
  83. outs[i].shape[1] == out_channels
  84. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  85. bn_exist = False
  86. for m in fpn_model.modules():
  87. if isinstance(m, _BatchNorm):
  88. bn_exist = True
  89. assert bn_exist
  90. # Bilinear upsample
  91. fpn_model = FPN(
  92. in_channels=in_channels,
  93. out_channels=out_channels,
  94. start_level=1,
  95. add_extra_convs=True,
  96. upsample_cfg=dict(mode='bilinear', align_corners=True),
  97. num_outs=5)
  98. fpn_model(feats)
  99. outs = fpn_model(feats)
  100. assert len(outs) == fpn_model.num_outs
  101. assert fpn_model.add_extra_convs == 'on_input'
  102. for i in range(fpn_model.num_outs):
  103. outs[i].shape[1] == out_channels
  104. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  105. # Scale factor instead of fixed upsample size upsample
  106. fpn_model = FPN(
  107. in_channels=in_channels,
  108. out_channels=out_channels,
  109. start_level=1,
  110. add_extra_convs=True,
  111. upsample_cfg=dict(scale_factor=2),
  112. num_outs=5)
  113. outs = fpn_model(feats)
  114. assert len(outs) == fpn_model.num_outs
  115. for i in range(fpn_model.num_outs):
  116. outs[i].shape[1] == out_channels
  117. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  118. # Extra convs source is 'inputs'
  119. fpn_model = FPN(
  120. in_channels=in_channels,
  121. out_channels=out_channels,
  122. add_extra_convs='on_input',
  123. start_level=1,
  124. num_outs=5)
  125. assert fpn_model.add_extra_convs == 'on_input'
  126. outs = fpn_model(feats)
  127. assert len(outs) == fpn_model.num_outs
  128. for i in range(fpn_model.num_outs):
  129. outs[i].shape[1] == out_channels
  130. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  131. # Extra convs source is 'laterals'
  132. fpn_model = FPN(
  133. in_channels=in_channels,
  134. out_channels=out_channels,
  135. add_extra_convs='on_lateral',
  136. start_level=1,
  137. num_outs=5)
  138. assert fpn_model.add_extra_convs == 'on_lateral'
  139. outs = fpn_model(feats)
  140. assert len(outs) == fpn_model.num_outs
  141. for i in range(fpn_model.num_outs):
  142. outs[i].shape[1] == out_channels
  143. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  144. # Extra convs source is 'outputs'
  145. fpn_model = FPN(
  146. in_channels=in_channels,
  147. out_channels=out_channels,
  148. add_extra_convs='on_output',
  149. start_level=1,
  150. num_outs=5)
  151. assert fpn_model.add_extra_convs == 'on_output'
  152. outs = fpn_model(feats)
  153. assert len(outs) == fpn_model.num_outs
  154. for i in range(fpn_model.num_outs):
  155. outs[i].shape[1] == out_channels
  156. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  157. def test_channel_mapper():
  158. """Tests ChannelMapper."""
  159. s = 64
  160. in_channels = [8, 16, 32, 64]
  161. feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8]
  162. out_channels = 8
  163. kernel_size = 3
  164. feats = [
  165. torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
  166. for i in range(len(in_channels))
  167. ]
  168. # in_channels must be a list
  169. with pytest.raises(AssertionError):
  170. channel_mapper = ChannelMapper(
  171. in_channels=10, out_channels=out_channels, kernel_size=kernel_size)
  172. # the length of channel_mapper's inputs must be equal to the length of
  173. # in_channels
  174. with pytest.raises(AssertionError):
  175. channel_mapper = ChannelMapper(
  176. in_channels=in_channels[:-1],
  177. out_channels=out_channels,
  178. kernel_size=kernel_size)
  179. channel_mapper(feats)
  180. channel_mapper = ChannelMapper(
  181. in_channels=in_channels,
  182. out_channels=out_channels,
  183. kernel_size=kernel_size)
  184. outs = channel_mapper(feats)
  185. assert len(outs) == len(feats)
  186. for i in range(len(feats)):
  187. outs[i].shape[1] == out_channels
  188. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  189. def test_dilated_encoder():
  190. in_channels = 16
  191. out_channels = 32
  192. out_shape = 34
  193. dilated_encoder = DilatedEncoder(in_channels, out_channels, 16, 2)
  194. feat = [torch.rand(1, in_channels, 34, 34)]
  195. out_feat = dilated_encoder(feat)[0]
  196. assert out_feat.shape == (1, out_channels, out_shape, out_shape)
  197. def test_ct_resnet_neck():
  198. # num_filters/num_kernels must be a list
  199. with pytest.raises(TypeError):
  200. CTResNetNeck(
  201. in_channel=10, num_deconv_filters=10, num_deconv_kernels=4)
  202. # num_filters/num_kernels must be same length
  203. with pytest.raises(AssertionError):
  204. CTResNetNeck(
  205. in_channel=10,
  206. num_deconv_filters=(10, 10),
  207. num_deconv_kernels=(4, ))
  208. in_channels = 16
  209. num_filters = (8, 8)
  210. num_kernels = (4, 4)
  211. feat = torch.rand(1, 16, 4, 4)
  212. ct_resnet_neck = CTResNetNeck(
  213. in_channel=in_channels,
  214. num_deconv_filters=num_filters,
  215. num_deconv_kernels=num_kernels,
  216. use_dcn=False)
  217. # feat must be list or tuple
  218. with pytest.raises(AssertionError):
  219. ct_resnet_neck(feat)
  220. out_feat = ct_resnet_neck([feat])[0]
  221. assert out_feat.shape == (1, num_filters[-1], 16, 16)
  222. if torch.cuda.is_available():
  223. # test dcn
  224. ct_resnet_neck = CTResNetNeck(
  225. in_channel=in_channels,
  226. num_deconv_filters=num_filters,
  227. num_deconv_kernels=num_kernels)
  228. ct_resnet_neck = ct_resnet_neck.cuda()
  229. feat = feat.cuda()
  230. out_feat = ct_resnet_neck([feat])[0]
  231. assert out_feat.shape == (1, num_filters[-1], 16, 16)
  232. def test_yolov3_neck():
  233. # num_scales, in_channels, out_channels must be same length
  234. with pytest.raises(AssertionError):
  235. YOLOV3Neck(num_scales=3, in_channels=[16, 8, 4], out_channels=[8, 4])
  236. # len(feats) must equal to num_scales
  237. with pytest.raises(AssertionError):
  238. neck = YOLOV3Neck(
  239. num_scales=3, in_channels=[16, 8, 4], out_channels=[8, 4, 2])
  240. feats = (torch.rand(1, 4, 16, 16), torch.rand(1, 8, 16, 16))
  241. neck(feats)
  242. # test normal channels
  243. s = 32
  244. in_channels = [16, 8, 4]
  245. out_channels = [8, 4, 2]
  246. feat_sizes = [s // 2**i for i in range(len(in_channels) - 1, -1, -1)]
  247. feats = [
  248. torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
  249. for i in range(len(in_channels) - 1, -1, -1)
  250. ]
  251. neck = YOLOV3Neck(
  252. num_scales=3, in_channels=in_channels, out_channels=out_channels)
  253. outs = neck(feats)
  254. assert len(outs) == len(feats)
  255. for i in range(len(outs)):
  256. assert outs[i].shape == \
  257. (1, out_channels[i], feat_sizes[i], feat_sizes[i])
  258. # test more flexible setting
  259. s = 32
  260. in_channels = [32, 8, 16]
  261. out_channels = [19, 21, 5]
  262. feat_sizes = [s // 2**i for i in range(len(in_channels) - 1, -1, -1)]
  263. feats = [
  264. torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
  265. for i in range(len(in_channels) - 1, -1, -1)
  266. ]
  267. neck = YOLOV3Neck(
  268. num_scales=3, in_channels=in_channels, out_channels=out_channels)
  269. outs = neck(feats)
  270. assert len(outs) == len(feats)
  271. for i in range(len(outs)):
  272. assert outs[i].shape == \
  273. (1, out_channels[i], feat_sizes[i], feat_sizes[i])
  274. def test_ssd_neck():
  275. # level_strides/level_paddings must be same length
  276. with pytest.raises(AssertionError):
  277. SSDNeck(
  278. in_channels=[8, 16],
  279. out_channels=[8, 16, 32],
  280. level_strides=[2],
  281. level_paddings=[2, 1])
  282. # length of out_channels must larger than in_channels
  283. with pytest.raises(AssertionError):
  284. SSDNeck(
  285. in_channels=[8, 16],
  286. out_channels=[8],
  287. level_strides=[2],
  288. level_paddings=[2])
  289. # len(out_channels) - len(in_channels) must equal to len(level_strides)
  290. with pytest.raises(AssertionError):
  291. SSDNeck(
  292. in_channels=[8, 16],
  293. out_channels=[4, 16, 64],
  294. level_strides=[2, 2],
  295. level_paddings=[2, 2])
  296. # in_channels must be same with out_channels[:len(in_channels)]
  297. with pytest.raises(AssertionError):
  298. SSDNeck(
  299. in_channels=[8, 16],
  300. out_channels=[4, 16, 64],
  301. level_strides=[2],
  302. level_paddings=[2])
  303. ssd_neck = SSDNeck(
  304. in_channels=[4],
  305. out_channels=[4, 8, 16],
  306. level_strides=[2, 1],
  307. level_paddings=[1, 0])
  308. feats = (torch.rand(1, 4, 16, 16), )
  309. outs = ssd_neck(feats)
  310. assert outs[0].shape == (1, 4, 16, 16)
  311. assert outs[1].shape == (1, 8, 8, 8)
  312. assert outs[2].shape == (1, 16, 6, 6)
  313. # test SSD-Lite Neck
  314. ssd_neck = SSDNeck(
  315. in_channels=[4, 8],
  316. out_channels=[4, 8, 16],
  317. level_strides=[1],
  318. level_paddings=[1],
  319. l2_norm_scale=None,
  320. use_depthwise=True,
  321. norm_cfg=dict(type='BN'),
  322. act_cfg=dict(type='ReLU6'))
  323. assert not hasattr(ssd_neck, 'l2_norm')
  324. from mmcv.cnn.bricks import DepthwiseSeparableConvModule
  325. assert isinstance(ssd_neck.extra_layers[0][-1],
  326. DepthwiseSeparableConvModule)
  327. feats = (torch.rand(1, 4, 8, 8), torch.rand(1, 8, 8, 8))
  328. outs = ssd_neck(feats)
  329. assert outs[0].shape == (1, 4, 8, 8)
  330. assert outs[1].shape == (1, 8, 8, 8)
  331. assert outs[2].shape == (1, 16, 8, 8)
  332. def test_yolox_pafpn():
  333. s = 64
  334. in_channels = [8, 16, 32, 64]
  335. feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8]
  336. out_channels = 24
  337. feats = [
  338. torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
  339. for i in range(len(in_channels))
  340. ]
  341. neck = YOLOXPAFPN(in_channels=in_channels, out_channels=out_channels)
  342. outs = neck(feats)
  343. assert len(outs) == len(feats)
  344. for i in range(len(feats)):
  345. assert outs[i].shape[1] == out_channels
  346. assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  347. # test depth-wise
  348. neck = YOLOXPAFPN(
  349. in_channels=in_channels, out_channels=out_channels, use_depthwise=True)
  350. from mmcv.cnn.bricks import DepthwiseSeparableConvModule
  351. assert isinstance(neck.downsamples[0], DepthwiseSeparableConvModule)
  352. outs = neck(feats)
  353. assert len(outs) == len(feats)
  354. for i in range(len(feats)):
  355. assert outs[i].shape[1] == out_channels
  356. assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)

No Description

Contributors (2)