
test_loss_compatibility.py

# Copyright (c) OpenMMLab. All rights reserved.
"""pytest tests/test_loss_compatibility.py."""
import copy
from os.path import dirname, exists, join

import numpy as np
import pytest
import torch


def _get_config_directory():
    """Find the predefined detector config directory."""
    try:
        # Assume we are running in the source mmdetection repo
        repo_dpath = dirname(dirname(dirname(__file__)))
    except NameError:
        # For IPython development when this __file__ is not defined
        import mmdet
        repo_dpath = dirname(dirname(mmdet.__file__))
    config_dpath = join(repo_dpath, 'configs')
    if not exists(config_dpath):
        raise Exception('Cannot find config path')
    return config_dpath


def _get_config_module(fname):
    """Load a configuration as a python module."""
    from mmcv import Config
    config_dpath = _get_config_directory()
    config_fpath = join(config_dpath, fname)
    config_mod = Config.fromfile(config_fpath)
    return config_mod


def _get_detector_cfg(fname):
    """Grab configs necessary to create a detector.

    These are deep copied to allow for safe modification of parameters without
    influencing other tests.
    """
    config = _get_config_module(fname)
    model = copy.deepcopy(config.model)
    return model


@pytest.mark.parametrize('loss_bbox', [
    dict(type='L1Loss', loss_weight=1.0),
    dict(type='GHMR', mu=0.02, bins=10, momentum=0.7, loss_weight=10.0),
    dict(type='IoULoss', loss_weight=1.0),
    dict(type='BoundedIoULoss', loss_weight=1.0),
    dict(type='GIoULoss', loss_weight=1.0),
    dict(type='DIoULoss', loss_weight=1.0),
    dict(type='CIoULoss', loss_weight=1.0),
    dict(type='MSELoss', loss_weight=1.0),
    dict(type='SmoothL1Loss', loss_weight=1.0),
    dict(type='BalancedL1Loss', loss_weight=1.0)
])
def test_bbox_loss_compatibility(loss_bbox):
    """Test loss_bbox compatibility.

    Using Faster R-CNN as a sample, modify the loss function in the config
    file to verify the compatibility of the loss APIs.
    """
    # Faster R-CNN config dict
    config_path = '_base_/models/faster_rcnn_r50_fpn.py'
    cfg_model = _get_detector_cfg(config_path)

    input_shape = (1, 3, 256, 256)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[10])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    # IoU-based losses regress decoded boxes rather than deltas
    if 'IoULoss' in loss_bbox['type']:
        cfg_model.roi_head.bbox_head.reg_decoded_bbox = True
    cfg_model.roi_head.bbox_head.loss_bbox = loss_bbox

    from mmdet.models import build_detector
    detector = build_detector(cfg_model)

    loss = detector.forward(imgs, img_metas, return_loss=True, **mm_inputs)
    assert isinstance(loss, dict)
    loss, _ = detector._parse_losses(loss)
    assert float(loss.item()) > 0


@pytest.mark.parametrize('loss_cls', [
    dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
    dict(
        type='FocalLoss',
        use_sigmoid=True,
        gamma=2.0,
        alpha=0.25,
        loss_weight=1.0),
    dict(
        type='GHMC', bins=30, momentum=0.75, use_sigmoid=True, loss_weight=1.0)
])
def test_cls_loss_compatibility(loss_cls):
    """Test loss_cls compatibility.

    Using Faster R-CNN as a sample, modify the loss function in the config
    file to verify the compatibility of the loss APIs.
    """
    # Faster R-CNN config dict
    config_path = '_base_/models/faster_rcnn_r50_fpn.py'
    cfg_model = _get_detector_cfg(config_path)

    input_shape = (1, 3, 256, 256)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[10])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    # verify classification loss function compatibility
    cfg_model.roi_head.bbox_head.loss_cls = loss_cls

    from mmdet.models import build_detector
    detector = build_detector(cfg_model)

    loss = detector.forward(imgs, img_metas, return_loss=True, **mm_inputs)
    assert isinstance(loss, dict)
    loss, _ = detector._parse_losses(loss)
    assert float(loss.item()) > 0


def _demo_mm_inputs(input_shape=(1, 3, 300, 300),
                    num_items=None, num_classes=10,
                    with_semantic=False):  # yapf: disable
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): input batch dimensions.
        num_items (None | List[int]): specifies the number of boxes
            in each batch item.
        num_classes (int): number of different labels a box might have.
        with_semantic (bool): whether to also return a dummy semantic
            segmentation map.
    """
    from mmdet.core import BitmapMasks

    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)

    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': np.array([1.1, 1.2, 1.1, 1.2]),
        'flip': False,
        'flip_direction': None,
    } for _ in range(N)]

    gt_bboxes = []
    gt_labels = []
    gt_masks = []

    for batch_idx in range(N):
        if num_items is None:
            num_boxes = rng.randint(1, 10)
        else:
            num_boxes = num_items[batch_idx]

        # sample random centers and sizes, then clip the boxes to the image
        cx, cy, bw, bh = rng.rand(num_boxes, 4).T
        tl_x = ((cx * W) - (W * bw / 2)).clip(0, W)
        tl_y = ((cy * H) - (H * bh / 2)).clip(0, H)
        br_x = ((cx * W) + (W * bw / 2)).clip(0, W)
        br_y = ((cy * H) + (H * bh / 2)).clip(0, H)

        boxes = np.vstack([tl_x, tl_y, br_x, br_y]).T
        class_idxs = rng.randint(1, num_classes, size=num_boxes)

        gt_bboxes.append(torch.FloatTensor(boxes))
        gt_labels.append(torch.LongTensor(class_idxs))

        mask = np.random.randint(0, 2, (len(boxes), H, W), dtype=np.uint8)
        gt_masks.append(BitmapMasks(mask, H, W))

    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'img_metas': img_metas,
        'gt_bboxes': gt_bboxes,
        'gt_labels': gt_labels,
        'gt_bboxes_ignore': None,
        'gt_masks': gt_masks,
    }

    if with_semantic:
        # assume gt_semantic_seg uses a scale of 1/8 of the img
        gt_semantic_seg = np.random.randint(
            0, num_classes, (1, 1, H // 8, W // 8), dtype=np.uint8)
        mm_inputs.update(
            {'gt_semantic_seg': torch.ByteTensor(gt_semantic_seg)})

    return mm_inputs
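
The helpers above can also be reused outside pytest to smoke-test a single loss configuration interactively. The snippet below is a minimal sketch, assuming mmdetection 2.x with mmcv installed and the tests/ directory on the import path; the GIoULoss choice is just an arbitrary example, and the import of this module by its file name is an assumption about how it is invoked. To run only the bbox-loss cases through pytest instead, `pytest tests/test_loss_compatibility.py -k bbox` works as well.

# Minimal standalone sketch (assumes mmdetection 2.x, mmcv installed, and that
# this file is importable as `test_loss_compatibility` from the tests/ directory).
from mmdet.models import build_detector

from test_loss_compatibility import _demo_mm_inputs, _get_detector_cfg

cfg_model = _get_detector_cfg('_base_/models/faster_rcnn_r50_fpn.py')
# GIoULoss is an IoU-based loss, so the head must regress decoded boxes.
cfg_model.roi_head.bbox_head.loss_bbox = dict(type='GIoULoss', loss_weight=1.0)
cfg_model.roi_head.bbox_head.reg_decoded_bbox = True

detector = build_detector(cfg_model)

mm_inputs = _demo_mm_inputs((1, 3, 256, 256), num_items=[10])
imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')

losses = detector.forward(imgs, img_metas, return_loss=True, **mm_inputs)
total_loss, log_vars = detector._parse_losses(losses)
print(f'total loss: {total_loss.item():.4f}')
print(log_vars)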
