# Copyright (c) OpenMMLab. All rights reserved.
"""pytest tests/test_forward.py."""
import copy
from os.path import dirname, exists, join

import numpy as np
import pytest
import torch


def _get_config_directory():
    """Find the predefined detector config directory."""
    try:
        # Assume we are running in the source mmdetection repo
        repo_dpath = dirname(dirname(dirname(__file__)))
    except NameError:
        # For IPython development when this __file__ is not defined
        import mmdet
        repo_dpath = dirname(dirname(mmdet.__file__))
    config_dpath = join(repo_dpath, 'configs')
    if not exists(config_dpath):
        raise Exception('Cannot find config path')
    return config_dpath


def _get_config_module(fname):
    """Load a configuration as a Python module."""
    from mmcv import Config
    config_dpath = _get_config_directory()
    config_fpath = join(config_dpath, fname)
    config_mod = Config.fromfile(config_fpath)
    return config_mod
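

# Example usage (a quick sketch; the path is one of the configs exercised
# by the tests below):
#
#   cfg = _get_config_module('retinanet/retinanet_r50_fpn_1x_coco.py')
#   assert cfg.model.type == 'RetinaNet'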


def _get_detector_cfg(fname):
    """Grab configs necessary to create a detector.

    These are deep copied to allow for safe modification of parameters
    without influencing other tests.
    """
    config = _get_config_module(fname)
    model = copy.deepcopy(config.model)
    return model


def _replace_r50_with_r18(model):
    """Replace ResNet50 with ResNet18 in config."""
    model = copy.deepcopy(model)
    if model.backbone.type == 'ResNet':
        model.backbone.depth = 18
        model.backbone.base_channels = 2
        model.neck.in_channels = [2, 4, 8, 16]
    return model
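

# Note: with base_channels=2 the four ResNet stages output 2, 4, 8 and 16
# channels (each stage doubles the width), which is why neck.in_channels is
# rewritten to [2, 4, 8, 16] above; the tiny backbone keeps these forward
# tests cheap enough to run on CPU.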


def test_sparse_rcnn_forward():
    config_path = 'sparse_rcnn/sparse_rcnn_r50_fpn_1x_coco.py'
    model = _get_detector_cfg(config_path)
    model = _replace_r50_with_r18(model)
    model.backbone.init_cfg = None
    from mmdet.models import build_detector
    detector = build_detector(model)
    detector.init_weights()
    input_shape = (1, 3, 100, 100)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[5])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    # Test forward train with a non-empty truth batch
    detector.train()
    gt_bboxes = list(mm_inputs['gt_bboxes'])
    gt_labels = list(mm_inputs['gt_labels'])
    losses = detector.forward(
        imgs,
        img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = detector._parse_losses(losses)
    assert float(loss.item()) > 0
    detector.forward_dummy(imgs)

    # Test forward train with an empty truth batch
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = list(mm_inputs['gt_bboxes'])
    gt_labels = list(mm_inputs['gt_labels'])
    losses = detector.forward(
        imgs,
        img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = detector._parse_losses(losses)
    assert float(loss.item()) > 0

    # Test forward test
    detector.eval()
    with torch.no_grad():
        img_list = [g[None, :] for g in imgs]
        batch_results = []
        for one_img, one_meta in zip(img_list, img_metas):
            result = detector.forward([one_img], [[one_meta]],
                                      rescale=True,
                                      return_loss=False)
            batch_results.append(result)

    # test empty proposals in roi_head
    with torch.no_grad():
        # test no proposal in the whole batch
        detector.roi_head.simple_test([imgs[0][None, :]],
                                      torch.empty((1, 0, 4)),
                                      torch.empty((1, 100, 4)),
                                      [img_metas[0]], torch.ones((1, 4)))


def test_rpn_forward():
    model = _get_detector_cfg('rpn/rpn_r50_fpn_1x_coco.py')
    model = _replace_r50_with_r18(model)
    model.backbone.init_cfg = None
    from mmdet.models import build_detector
    detector = build_detector(model)
    input_shape = (1, 3, 100, 100)
    mm_inputs = _demo_mm_inputs(input_shape)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    # Test forward train
    gt_bboxes = mm_inputs['gt_bboxes']
    losses = detector.forward(
        imgs, img_metas, gt_bboxes=gt_bboxes, return_loss=True)
    assert isinstance(losses, dict)

    # Test forward test
    with torch.no_grad():
        img_list = [g[None, :] for g in imgs]
        batch_results = []
        for one_img, one_meta in zip(img_list, img_metas):
            result = detector.forward([one_img], [[one_meta]],
                                      return_loss=False)
            batch_results.append(result)


@pytest.mark.parametrize(
    'cfg_file',
    [
        'reppoints/reppoints_moment_r50_fpn_1x_coco.py',
        'retinanet/retinanet_r50_fpn_1x_coco.py',
        'guided_anchoring/ga_retinanet_r50_fpn_1x_coco.py',
        'ghm/retinanet_ghm_r50_fpn_1x_coco.py',
        'fcos/fcos_center_r50_caffe_fpn_gn-head_1x_coco.py',
        'foveabox/fovea_align_r50_fpn_gn-head_4x4_2x_coco.py',
        # 'free_anchor/retinanet_free_anchor_r50_fpn_1x_coco.py',
        # 'atss/atss_r50_fpn_1x_coco.py',  # not ready for topk
        'yolo/yolov3_mobilenetv2_320_300e_coco.py',
        'yolox/yolox_tiny_8x8_300e_coco.py'
    ])
def test_single_stage_forward_gpu(cfg_file):
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')
    model = _get_detector_cfg(cfg_file)
    model = _replace_r50_with_r18(model)
    model.backbone.init_cfg = None
    from mmdet.models import build_detector
    detector = build_detector(model)
    input_shape = (2, 3, 128, 128)
    mm_inputs = _demo_mm_inputs(input_shape)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    detector = detector.cuda()
    imgs = imgs.cuda()

    # Test forward train
    gt_bboxes = [b.cuda() for b in mm_inputs['gt_bboxes']]
    gt_labels = [g.cuda() for g in mm_inputs['gt_labels']]
    losses = detector.forward(
        imgs,
        img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        return_loss=True)
    assert isinstance(losses, dict)

    # Test forward test
    detector.eval()
    with torch.no_grad():
        img_list = [g[None, :] for g in imgs]
        batch_results = []
        for one_img, one_meta in zip(img_list, img_metas):
            result = detector.forward([one_img], [[one_meta]],
                                      return_loss=False)
            batch_results.append(result)


def test_faster_rcnn_ohem_forward():
    model = _get_detector_cfg(
        'faster_rcnn/faster_rcnn_r50_fpn_ohem_1x_coco.py')
    model = _replace_r50_with_r18(model)
    model.backbone.init_cfg = None
    from mmdet.models import build_detector
    detector = build_detector(model)
    input_shape = (1, 3, 100, 100)

    # Test forward train with a non-empty truth batch
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[10])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    losses = detector.forward(
        imgs,
        img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = detector._parse_losses(losses)
    assert float(loss.item()) > 0

    # Test forward train with an empty truth batch
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    losses = detector.forward(
        imgs,
        img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = detector._parse_losses(losses)
    assert float(loss.item()) > 0

    # Test RoI forward train with empty proposals
    feature = detector.extract_feat(imgs[0][None, :])
    losses = detector.roi_head.forward_train(
        feature,
        img_metas, [torch.empty((0, 5))],
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels)
    assert isinstance(losses, dict)
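

# Note: the (0, 5) placeholder above matches the RPN proposal format,
# (x1, y1, x2, y2, score); the two-stage tests below feed plain (n, 4)
# boxes as test-time proposals instead.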


@pytest.mark.parametrize(
    'cfg_file',
    [
        # 'cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py',
        'mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py',
        # 'grid_rcnn/grid_rcnn_r50_fpn_gn-head_2x_coco.py',
        # 'ms_rcnn/ms_rcnn_r50_fpn_1x_coco.py',
        # 'htc/htc_r50_fpn_1x_coco.py',
        # 'panoptic_fpn/panoptic_fpn_r50_fpn_1x_coco.py',
        # 'scnet/scnet_r50_fpn_20e_coco.py',
        # 'seesaw_loss/mask_rcnn_r50_fpn_random_seesaw_loss_normed_mask_mstrain_2x_lvis_v1.py'  # noqa: E501
    ])
def test_two_stage_forward(cfg_file):
    models_with_semantic = [
        'htc/htc_r50_fpn_1x_coco.py',
        'panoptic_fpn/panoptic_fpn_r50_fpn_1x_coco.py',
        'scnet/scnet_r50_fpn_20e_coco.py',
    ]
    if cfg_file in models_with_semantic:
        with_semantic = True
    else:
        with_semantic = False

    model = _get_detector_cfg(cfg_file)
    model = _replace_r50_with_r18(model)
    model.backbone.init_cfg = None

    # Save cost
    if cfg_file in [
            'seesaw_loss/mask_rcnn_r50_fpn_random_seesaw_loss_normed_mask_mstrain_2x_lvis_v1.py'  # noqa: E501
    ]:
        model.roi_head.bbox_head.num_classes = 80
        model.roi_head.bbox_head.loss_cls.num_classes = 80
        model.roi_head.mask_head.num_classes = 80
        model.test_cfg.rcnn.score_thr = 0.05
        model.test_cfg.rcnn.max_per_img = 100

    from mmdet.models import build_detector
    detector = build_detector(model)
    input_shape = (1, 3, 128, 128)

    # Test forward train with a non-empty truth batch
    mm_inputs = _demo_mm_inputs(
        input_shape, num_items=[10], with_semantic=with_semantic)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    losses = detector.forward(imgs, img_metas, return_loss=True, **mm_inputs)
    assert isinstance(losses, dict)
    loss, _ = detector._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # Test forward train with an empty truth batch
    mm_inputs = _demo_mm_inputs(
        input_shape, num_items=[0], with_semantic=with_semantic)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    losses = detector.forward(imgs, img_metas, return_loss=True, **mm_inputs)
    assert isinstance(losses, dict)
    loss, _ = detector._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # Test RoI forward train with empty proposals
    if cfg_file in ['panoptic_fpn/panoptic_fpn_r50_fpn_1x_coco.py']:
        mm_inputs.pop('gt_semantic_seg')
    feature = detector.extract_feat(imgs[0][None, :])
    losses = detector.roi_head.forward_train(feature, img_metas,
                                             [torch.empty((0, 5))],
                                             **mm_inputs)
    assert isinstance(losses, dict)

    # Test forward test
    with torch.no_grad():
        img_list = [g[None, :] for g in imgs]
        batch_results = []
        for one_img, one_meta in zip(img_list, img_metas):
            result = detector.forward([one_img], [[one_meta]],
                                      return_loss=False)
            batch_results.append(result)

    cascade_models = [
        'cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py',
        'htc/htc_r50_fpn_1x_coco.py',
        'scnet/scnet_r50_fpn_20e_coco.py',
    ]

    # test empty proposals in roi_head
    with torch.no_grad():
        # test no proposal in the whole batch
        detector.simple_test(
            imgs[0][None, :], [img_metas[0]], proposals=[torch.empty((0, 4))])

        # test no proposal of aug
        features = detector.extract_feats([imgs[0][None, :]] * 2)
        detector.roi_head.aug_test(features, [torch.empty((0, 4))] * 2,
                                   [[img_metas[0]]] * 2)

        # test rcnn_test_cfg is None
        if cfg_file not in cascade_models:
            feature = detector.extract_feat(imgs[0][None, :])
            bboxes, scores = detector.roi_head.simple_test_bboxes(
                feature, [img_metas[0]], [torch.empty((0, 4))], None)
            assert all([bbox.shape == torch.Size((0, 4)) for bbox in bboxes])
            assert all([
                score.shape == torch.Size(
                    (0, detector.roi_head.bbox_head.fc_cls.out_features))
                for score in scores
            ])

        # test no proposal in some of the images
        x1y1 = torch.randint(1, 100, (10, 2)).float()
        # x2y2 must be greater than x1y1
        x2y2 = x1y1 + torch.randint(1, 100, (10, 2))
        detector.simple_test(
            imgs[0][None, :].repeat(2, 1, 1, 1), [img_metas[0]] * 2,
            proposals=[torch.empty((0, 4)),
                       torch.cat([x1y1, x2y2], dim=-1)])

        # test no proposal of aug
        detector.roi_head.aug_test(
            features, [torch.cat([x1y1, x2y2], dim=-1),
                       torch.empty((0, 4))], [[img_metas[0]]] * 2)

        # test rcnn_test_cfg is None
        if cfg_file not in cascade_models:
            feature = detector.extract_feat(imgs[0][None, :].repeat(
                2, 1, 1, 1))
            bboxes, scores = detector.roi_head.simple_test_bboxes(
                feature, [img_metas[0]] * 2,
                [torch.empty((0, 4)),
                 torch.cat([x1y1, x2y2], dim=-1)], None)
            assert bboxes[0].shape == torch.Size((0, 4))
            assert scores[0].shape == torch.Size(
                (0, detector.roi_head.bbox_head.fc_cls.out_features))
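

# Note on the shape checks above: in mmdet bbox heads, fc_cls.out_features is
# num_classes + 1 (an extra background class), so even with empty proposals
# the score tensors keep a well-defined class dimension.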


@pytest.mark.parametrize(
    'cfg_file', ['ghm/retinanet_ghm_r50_fpn_1x_coco.py', 'ssd/ssd300_coco.py'])
def test_single_stage_forward_cpu(cfg_file):
    model = _get_detector_cfg(cfg_file)
    model = _replace_r50_with_r18(model)
    model.backbone.init_cfg = None
    from mmdet.models import build_detector
    detector = build_detector(model)
    input_shape = (1, 3, 300, 300)
    mm_inputs = _demo_mm_inputs(input_shape)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    # Test forward train
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    losses = detector.forward(
        imgs,
        img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        return_loss=True)
    assert isinstance(losses, dict)

    # Test forward test
    detector.eval()
    with torch.no_grad():
        img_list = [g[None, :] for g in imgs]
        batch_results = []
        for one_img, one_meta in zip(img_list, img_metas):
            result = detector.forward([one_img], [[one_meta]],
                                      return_loss=False)
            batch_results.append(result)


def _demo_mm_inputs(input_shape=(1, 3, 300, 300),
                    num_items=None, num_classes=10,
                    with_semantic=False):  # yapf: disable
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): Input batch dimensions.
        num_items (None | List[int]): Specifies the number of boxes in
            each batch item.
        num_classes (int): Number of different labels a box might have.
        with_semantic (bool): Whether to also return a semantic
            segmentation ground truth map.
    """
    from mmdet.core import BitmapMasks
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': np.array([1.1, 1.2, 1.1, 1.2]),
        'flip': False,
        'flip_direction': None,
    } for _ in range(N)]
    gt_bboxes = []
    gt_labels = []
    gt_masks = []
    for batch_idx in range(N):
        if num_items is None:
            num_boxes = rng.randint(1, 10)
        else:
            num_boxes = num_items[batch_idx]
        cx, cy, bw, bh = rng.rand(num_boxes, 4).T
        tl_x = ((cx * W) - (W * bw / 2)).clip(0, W)
        tl_y = ((cy * H) - (H * bh / 2)).clip(0, H)
        br_x = ((cx * W) + (W * bw / 2)).clip(0, W)
        br_y = ((cy * H) + (H * bh / 2)).clip(0, H)
        boxes = np.vstack([tl_x, tl_y, br_x, br_y]).T
        class_idxs = rng.randint(1, num_classes, size=num_boxes)
        gt_bboxes.append(torch.FloatTensor(boxes))
        gt_labels.append(torch.LongTensor(class_idxs))
        mask = np.random.randint(0, 2, (len(boxes), H, W), dtype=np.uint8)
        gt_masks.append(BitmapMasks(mask, H, W))
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'img_metas': img_metas,
        'gt_bboxes': gt_bboxes,
        'gt_labels': gt_labels,
        'gt_bboxes_ignore': None,
        'gt_masks': gt_masks,
    }
    if with_semantic:
        # assume gt_semantic_seg uses a 1/8 scale of the img
        gt_semantic_seg = np.random.randint(
            0, num_classes, (1, 1, H // 8, W // 8), dtype=np.uint8)
        mm_inputs.update(
            {'gt_semantic_seg': torch.ByteTensor(gt_semantic_seg)})
    return mm_inputs
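

# A minimal sketch of the returned structure (the shapes follow from the
# code above):
#
#   inputs = _demo_mm_inputs((1, 3, 64, 64), num_items=[3])
#   assert inputs['imgs'].shape == (1, 3, 64, 64)
#   assert inputs['gt_bboxes'][0].shape == (3, 4)
#   assert inputs['gt_labels'][0].shape == (3, )
#   assert len(inputs['img_metas']) == 1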


def test_yolact_forward():
    model = _get_detector_cfg('yolact/yolact_r50_1x8_coco.py')
    model = _replace_r50_with_r18(model)
    model.backbone.init_cfg = None
    from mmdet.models import build_detector
    detector = build_detector(model)
    input_shape = (1, 3, 100, 100)
    mm_inputs = _demo_mm_inputs(input_shape)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    # Test forward train
    detector.train()
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    gt_masks = mm_inputs['gt_masks']
    losses = detector.forward(
        imgs,
        img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        gt_masks=gt_masks,
        return_loss=True)
    assert isinstance(losses, dict)

    # Test forward dummy for get_flops
    detector.forward_dummy(imgs)

    # Test forward test
    detector.eval()
    with torch.no_grad():
        img_list = [g[None, :] for g in imgs]
        batch_results = []
        for one_img, one_meta in zip(img_list, img_metas):
            result = detector.forward([one_img], [[one_meta]],
                                      rescale=True,
                                      return_loss=False)
            batch_results.append(result)


def test_detr_forward():
    model = _get_detector_cfg('detr/detr_r50_8x2_150e_coco.py')
    model.backbone.depth = 18
    model.bbox_head.in_channels = 512
    model.backbone.init_cfg = None
    from mmdet.models import build_detector
    detector = build_detector(model)
    input_shape = (1, 3, 100, 100)
    mm_inputs = _demo_mm_inputs(input_shape)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    # Test forward train with a non-empty truth batch
    detector.train()
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    losses = detector.forward(
        imgs,
        img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = detector._parse_losses(losses)
    assert float(loss.item()) > 0

    # Test forward train with an empty truth batch
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    losses = detector.forward(
        imgs,
        img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = detector._parse_losses(losses)
    assert float(loss.item()) > 0

    # Test forward test
    detector.eval()
    with torch.no_grad():
        img_list = [g[None, :] for g in imgs]
        batch_results = []
        for one_img, one_meta in zip(img_list, img_metas):
            result = detector.forward([one_img], [[one_meta]],
                                      rescale=True,
                                      return_loss=False)
            batch_results.append(result)


def test_inference_detector():
    from mmdet.apis import inference_detector
    from mmdet.models import build_detector
    from mmcv import ConfigDict

    # small RetinaNet
    num_class = 3
    model_dict = dict(
        type='RetinaNet',
        backbone=dict(
            type='ResNet',
            depth=18,
            num_stages=4,
            out_indices=(3, ),
            norm_cfg=dict(type='BN', requires_grad=False),
            norm_eval=True,
            style='pytorch'),
        neck=None,
        bbox_head=dict(
            type='RetinaHead',
            num_classes=num_class,
            in_channels=512,
            stacked_convs=1,
            feat_channels=256,
            anchor_generator=dict(
                type='AnchorGenerator',
                octave_base_scale=4,
                scales_per_octave=3,
                ratios=[0.5],
                strides=[32]),
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[.0, .0, .0, .0],
                target_stds=[1.0, 1.0, 1.0, 1.0]),
        ),
        test_cfg=dict(
            nms_pre=1000,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100))
    rng = np.random.RandomState(0)
    img1 = rng.rand(100, 100, 3)
    img2 = rng.rand(100, 100, 3)
    model = build_detector(ConfigDict(model_dict))
    config = _get_config_module('retinanet/retinanet_r50_fpn_1x_coco.py')
    model.cfg = config

    # test a single image
    result = inference_detector(model, img1)
    assert len(result) == num_class

    # test multiple images
    result = inference_detector(model, [img1, img2])
    assert len(result) == 2 and len(result[0]) == num_class
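

# The suite can be run as described in the module docstring (assuming a
# working mmdet + mmcv installation), e.g.:
#
#   pytest tests/test_forward.py
#   pytest tests/test_forward.py -k detr   # select a single test by name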
