You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_mmtrack.py 7.6 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import defaultdict

import numpy as np
import pytest
import torch
from mmcv import Config
  8. @pytest.mark.parametrize(
  9. 'cfg_file',
  10. ['./tests/data/configs_mmtrack/selsa_faster_rcnn_r101_dc5_1x.py'])
  11. def test_vid_fgfa_style_forward(cfg_file):
  12. config = Config.fromfile(cfg_file)
  13. model = copy.deepcopy(config.model)
  14. model.pretrains = None
  15. model.detector.pretrained = None
  16. from mmtrack.models import build_model
  17. detector = build_model(model)
  18. # Test forward train with a non-empty truth batch
  19. input_shape = (1, 3, 256, 256)
  20. mm_inputs = _demo_mm_inputs(input_shape, num_items=[10])
  21. imgs = mm_inputs.pop('imgs')
  22. img_metas = mm_inputs.pop('img_metas')
  23. img_metas[0]['is_video_data'] = True
  24. gt_bboxes = mm_inputs['gt_bboxes']
  25. gt_labels = mm_inputs['gt_labels']
  26. gt_masks = mm_inputs['gt_masks']
  27. ref_input_shape = (2, 3, 256, 256)
  28. ref_mm_inputs = _demo_mm_inputs(ref_input_shape, num_items=[9, 11])
  29. ref_img = ref_mm_inputs.pop('imgs')[None]
  30. ref_img_metas = ref_mm_inputs.pop('img_metas')
  31. ref_img_metas[0]['is_video_data'] = True
  32. ref_img_metas[1]['is_video_data'] = True
  33. ref_gt_bboxes = ref_mm_inputs['gt_bboxes']
  34. ref_gt_labels = ref_mm_inputs['gt_labels']
  35. ref_gt_masks = ref_mm_inputs['gt_masks']
  36. losses = detector.forward(
  37. img=imgs,
  38. img_metas=img_metas,
  39. gt_bboxes=gt_bboxes,
  40. gt_labels=gt_labels,
  41. ref_img=ref_img,
  42. ref_img_metas=[ref_img_metas],
  43. ref_gt_bboxes=ref_gt_bboxes,
  44. ref_gt_labels=ref_gt_labels,
  45. gt_masks=gt_masks,
  46. ref_gt_masks=ref_gt_masks,
  47. return_loss=True)
  48. assert isinstance(losses, dict)
  49. loss, _ = detector._parse_losses(losses)
  50. loss.requires_grad_(True)
  51. assert float(loss.item()) > 0
  52. loss.backward()
  53. # Test forward train with an empty truth batch
  54. mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
  55. imgs = mm_inputs.pop('imgs')
  56. img_metas = mm_inputs.pop('img_metas')
  57. img_metas[0]['is_video_data'] = True
  58. gt_bboxes = mm_inputs['gt_bboxes']
  59. gt_labels = mm_inputs['gt_labels']
  60. gt_masks = mm_inputs['gt_masks']
  61. ref_mm_inputs = _demo_mm_inputs(ref_input_shape, num_items=[0, 0])
  62. ref_imgs = ref_mm_inputs.pop('imgs')[None]
  63. ref_img_metas = ref_mm_inputs.pop('img_metas')
  64. ref_img_metas[0]['is_video_data'] = True
  65. ref_img_metas[1]['is_video_data'] = True
  66. ref_gt_bboxes = ref_mm_inputs['gt_bboxes']
  67. ref_gt_labels = ref_mm_inputs['gt_labels']
  68. ref_gt_masks = ref_mm_inputs['gt_masks']
  69. losses = detector.forward(
  70. img=imgs,
  71. img_metas=img_metas,
  72. gt_bboxes=gt_bboxes,
  73. gt_labels=gt_labels,
  74. ref_img=ref_imgs,
  75. ref_img_metas=[ref_img_metas],
  76. ref_gt_bboxes=ref_gt_bboxes,
  77. ref_gt_labels=ref_gt_labels,
  78. gt_masks=gt_masks,
  79. ref_gt_masks=ref_gt_masks,
  80. return_loss=True)
  81. assert isinstance(losses, dict)
  82. loss, _ = detector._parse_losses(losses)
  83. loss.requires_grad_(True)
  84. assert float(loss.item()) > 0
  85. loss.backward()
  86. # Test forward test with frame_stride=1 and frame_range=[-1,0]
  87. with torch.no_grad():
  88. imgs = torch.cat([imgs, imgs.clone()], dim=0)
  89. img_list = [g[None, :] for g in imgs]
  90. img_metas.extend(copy.deepcopy(img_metas))
  91. for i in range(len(img_metas)):
  92. img_metas[i]['frame_id'] = i
  93. img_metas[i]['num_left_ref_imgs'] = 1
  94. img_metas[i]['frame_stride'] = 1
  95. ref_imgs = [ref_imgs.clone(), imgs[[0]][None].clone()]
  96. ref_img_metas = [
  97. copy.deepcopy(ref_img_metas),
  98. copy.deepcopy([img_metas[0]])
  99. ]
  100. results = defaultdict(list)
  101. for one_img, one_meta, ref_img, ref_img_meta in zip(
  102. img_list, img_metas, ref_imgs, ref_img_metas):
  103. result = detector.forward([one_img], [[one_meta]],
  104. ref_img=[ref_img],
  105. ref_img_metas=[[ref_img_meta]],
  106. return_loss=False)
  107. for k, v in result.items():
  108. results[k].append(v)
  109. @pytest.mark.parametrize('cfg_file', [
  110. './tests/data/configs_mmtrack/tracktor_faster-rcnn_r50_fpn_4e.py',
  111. ])
  112. def test_tracktor_forward(cfg_file):
  113. config = Config.fromfile(cfg_file)
  114. model = copy.deepcopy(config.model)
  115. model.pretrains = None
  116. model.detector.pretrained = None
  117. from mmtrack.models import build_model
  118. mot = build_model(model)
  119. mot.eval()
  120. input_shape = (1, 3, 256, 256)
  121. mm_inputs = _demo_mm_inputs(input_shape, num_items=[10], with_track=True)
  122. imgs = mm_inputs.pop('imgs')
  123. img_metas = mm_inputs.pop('img_metas')
  124. with torch.no_grad():
  125. imgs = torch.cat([imgs, imgs.clone()], dim=0)
  126. img_list = [g[None, :] for g in imgs]
  127. img2_metas = copy.deepcopy(img_metas)
  128. img2_metas[0]['frame_id'] = 1
  129. img_metas.extend(img2_metas)
  130. results = defaultdict(list)
  131. for one_img, one_meta in zip(img_list, img_metas):
  132. result = mot.forward([one_img], [[one_meta]], return_loss=False)
  133. for k, v in result.items():
  134. results[k].append(v)
  135. def _demo_mm_inputs(
  136. input_shape=(1, 3, 300, 300),
  137. num_items=None,
  138. num_classes=10,
  139. with_track=False):
  140. """Create a superset of inputs needed to run test or train batches.
  141. Args:
  142. input_shape (tuple):
  143. input batch dimensions
  144. num_items (None | List[int]):
  145. specifies the number of boxes in each batch item
  146. num_classes (int):
  147. number of different labels a box might have
  148. """
  149. from mmdet.core import BitmapMasks
  150. (N, C, H, W) = input_shape
  151. rng = np.random.RandomState(0)
  152. imgs = rng.rand(*input_shape)
  153. img_metas = [{
  154. 'img_shape': (H, W, C),
  155. 'ori_shape': (H, W, C),
  156. 'pad_shape': (H, W, C),
  157. 'filename': '<demo>.png',
  158. 'scale_factor': 1.0,
  159. 'flip': False,
  160. 'frame_id': 0,
  161. 'img_norm_cfg': {
  162. 'mean': (128.0, 128.0, 128.0),
  163. 'std': (10.0, 10.0, 10.0)
  164. }
  165. } for i in range(N)]
  166. gt_bboxes = []
  167. gt_labels = []
  168. gt_masks = []
  169. gt_match_indices = []
  170. for batch_idx in range(N):
  171. if num_items is None:
  172. num_boxes = rng.randint(1, 10)
  173. else:
  174. num_boxes = num_items[batch_idx]
  175. cx, cy, bw, bh = rng.rand(num_boxes, 4).T
  176. tl_x = ((cx * W) - (W * bw / 2)).clip(0, W)
  177. tl_y = ((cy * H) - (H * bh / 2)).clip(0, H)
  178. br_x = ((cx * W) + (W * bw / 2)).clip(0, W)
  179. br_y = ((cy * H) + (H * bh / 2)).clip(0, H)
  180. boxes = np.vstack([tl_x, tl_y, br_x, br_y]).T
  181. class_idxs = rng.randint(1, num_classes, size=num_boxes)
  182. gt_bboxes.append(torch.FloatTensor(boxes))
  183. gt_labels.append(torch.LongTensor(class_idxs))
  184. if with_track:
  185. gt_match_indices.append(torch.arange(boxes.shape[0]))
  186. mask = np.random.randint(0, 2, (len(boxes), H, W), dtype=np.uint8)
  187. gt_masks.append(BitmapMasks(mask, H, W))
  188. mm_inputs = {
  189. 'imgs': torch.FloatTensor(imgs).requires_grad_(True),
  190. 'img_metas': img_metas,
  191. 'gt_bboxes': gt_bboxes,
  192. 'gt_labels': gt_labels,
  193. 'gt_bboxes_ignore': None,
  194. 'gt_masks': gt_masks,
  195. }
  196. if with_track:
  197. mm_inputs['gt_match_indices'] = gt_match_indices
  198. return mm_inputs

No Description

Contributors (3)