You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_config.py 15 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from os.path import dirname, exists, join
  3. from unittest.mock import Mock
  4. import pytest
  5. from mmdet.core import BitmapMasks, PolygonMasks
  6. from mmdet.datasets.builder import DATASETS
  7. from mmdet.datasets.utils import NumClassCheckHook
  8. def _get_config_directory():
  9. """Find the predefined detector config directory."""
  10. try:
  11. # Assume we are running in the source mmdetection repo
  12. repo_dpath = dirname(dirname(__file__))
  13. repo_dpath = join(repo_dpath, '..')
  14. except NameError:
  15. # For IPython development when this __file__ is not defined
  16. import mmdet
  17. repo_dpath = dirname(dirname(mmdet.__file__))
  18. config_dpath = join(repo_dpath, 'configs')
  19. if not exists(config_dpath):
  20. raise Exception('Cannot find config path')
  21. return config_dpath
  22. def _check_numclasscheckhook(detector, config_mod):
  23. dummy_runner = Mock()
  24. dummy_runner.model = detector
  25. def get_dataset_name_classes(dataset):
  26. # deal with `RepeatDataset`,`ConcatDataset`,`ClassBalancedDataset`..
  27. if isinstance(dataset, (list, tuple)):
  28. dataset = dataset[0]
  29. while ('dataset' in dataset):
  30. dataset = dataset['dataset']
  31. # ConcatDataset
  32. if isinstance(dataset, (list, tuple)):
  33. dataset = dataset[0]
  34. return dataset['type'], dataset.get('classes', None)
  35. compatible_check = NumClassCheckHook()
  36. dataset_name, CLASSES = get_dataset_name_classes(
  37. config_mod['data']['train'])
  38. if CLASSES is None:
  39. CLASSES = DATASETS.get(dataset_name).CLASSES
  40. dummy_runner.data_loader.dataset.CLASSES = CLASSES
  41. compatible_check.before_train_epoch(dummy_runner)
  42. dummy_runner.data_loader.dataset.CLASSES = None
  43. compatible_check.before_train_epoch(dummy_runner)
  44. dataset_name, CLASSES = get_dataset_name_classes(config_mod['data']['val'])
  45. if CLASSES is None:
  46. CLASSES = DATASETS.get(dataset_name).CLASSES
  47. dummy_runner.data_loader.dataset.CLASSES = CLASSES
  48. compatible_check.before_val_epoch(dummy_runner)
  49. dummy_runner.data_loader.dataset.CLASSES = None
  50. compatible_check.before_val_epoch(dummy_runner)
  51. def _check_roi_head(config, head):
  52. # check consistency between head_config and roi_head
  53. assert config['type'] == head.__class__.__name__
  54. # check roi_align
  55. bbox_roi_cfg = config.bbox_roi_extractor
  56. bbox_roi_extractor = head.bbox_roi_extractor
  57. _check_roi_extractor(bbox_roi_cfg, bbox_roi_extractor)
  58. # check bbox head infos
  59. bbox_cfg = config.bbox_head
  60. bbox_head = head.bbox_head
  61. _check_bbox_head(bbox_cfg, bbox_head)
  62. if head.with_mask:
  63. # check roi_align
  64. if config.mask_roi_extractor:
  65. mask_roi_cfg = config.mask_roi_extractor
  66. mask_roi_extractor = head.mask_roi_extractor
  67. _check_roi_extractor(mask_roi_cfg, mask_roi_extractor,
  68. bbox_roi_extractor)
  69. # check mask head infos
  70. mask_head = head.mask_head
  71. mask_cfg = config.mask_head
  72. _check_mask_head(mask_cfg, mask_head)
  73. # check arch specific settings, e.g., cascade/htc
  74. if config['type'] in ['CascadeRoIHead', 'HybridTaskCascadeRoIHead']:
  75. assert config.num_stages == len(head.bbox_head)
  76. assert config.num_stages == len(head.bbox_roi_extractor)
  77. if head.with_mask:
  78. assert config.num_stages == len(head.mask_head)
  79. assert config.num_stages == len(head.mask_roi_extractor)
  80. elif config['type'] in ['MaskScoringRoIHead']:
  81. assert (hasattr(head, 'mask_iou_head')
  82. and head.mask_iou_head is not None)
  83. mask_iou_cfg = config.mask_iou_head
  84. mask_iou_head = head.mask_iou_head
  85. assert (mask_iou_cfg.fc_out_channels ==
  86. mask_iou_head.fc_mask_iou.in_features)
  87. elif config['type'] in ['GridRoIHead']:
  88. grid_roi_cfg = config.grid_roi_extractor
  89. grid_roi_extractor = head.grid_roi_extractor
  90. _check_roi_extractor(grid_roi_cfg, grid_roi_extractor,
  91. bbox_roi_extractor)
  92. config.grid_head.grid_points = head.grid_head.grid_points
  93. def _check_roi_extractor(config, roi_extractor, prev_roi_extractor=None):
  94. import torch.nn as nn
  95. # Separate roi_extractor and prev_roi_extractor checks for flexibility
  96. if isinstance(roi_extractor, nn.ModuleList):
  97. roi_extractor = roi_extractor[0]
  98. if prev_roi_extractor and isinstance(prev_roi_extractor, nn.ModuleList):
  99. prev_roi_extractor = prev_roi_extractor[0]
  100. assert (len(config.featmap_strides) == len(roi_extractor.roi_layers))
  101. assert (config.out_channels == roi_extractor.out_channels)
  102. from torch.nn.modules.utils import _pair
  103. assert (_pair(config.roi_layer.output_size) ==
  104. roi_extractor.roi_layers[0].output_size)
  105. if 'use_torchvision' in config.roi_layer:
  106. assert (config.roi_layer.use_torchvision ==
  107. roi_extractor.roi_layers[0].use_torchvision)
  108. elif 'aligned' in config.roi_layer:
  109. assert (
  110. config.roi_layer.aligned == roi_extractor.roi_layers[0].aligned)
  111. if prev_roi_extractor:
  112. assert (roi_extractor.roi_layers[0].aligned ==
  113. prev_roi_extractor.roi_layers[0].aligned)
  114. assert (roi_extractor.roi_layers[0].use_torchvision ==
  115. prev_roi_extractor.roi_layers[0].use_torchvision)
  116. def _check_mask_head(mask_cfg, mask_head):
  117. import torch.nn as nn
  118. if isinstance(mask_cfg, list):
  119. for single_mask_cfg, single_mask_head in zip(mask_cfg, mask_head):
  120. _check_mask_head(single_mask_cfg, single_mask_head)
  121. elif isinstance(mask_head, nn.ModuleList):
  122. for single_mask_head in mask_head:
  123. _check_mask_head(mask_cfg, single_mask_head)
  124. else:
  125. assert mask_cfg['type'] == mask_head.__class__.__name__
  126. assert mask_cfg.in_channels == mask_head.in_channels
  127. class_agnostic = mask_cfg.get('class_agnostic', False)
  128. out_dim = (1 if class_agnostic else mask_cfg.num_classes)
  129. if hasattr(mask_head, 'conv_logits'):
  130. assert (mask_cfg.conv_out_channels ==
  131. mask_head.conv_logits.in_channels)
  132. assert mask_head.conv_logits.out_channels == out_dim
  133. else:
  134. assert mask_cfg.fc_out_channels == mask_head.fc_logits.in_features
  135. assert (mask_head.fc_logits.out_features == out_dim *
  136. mask_head.output_area)
  137. def _check_bbox_head(bbox_cfg, bbox_head):
  138. import torch.nn as nn
  139. if isinstance(bbox_cfg, list):
  140. for single_bbox_cfg, single_bbox_head in zip(bbox_cfg, bbox_head):
  141. _check_bbox_head(single_bbox_cfg, single_bbox_head)
  142. elif isinstance(bbox_head, nn.ModuleList):
  143. for single_bbox_head in bbox_head:
  144. _check_bbox_head(bbox_cfg, single_bbox_head)
  145. else:
  146. assert bbox_cfg['type'] == bbox_head.__class__.__name__
  147. if bbox_cfg['type'] == 'SABLHead':
  148. assert bbox_cfg.cls_in_channels == bbox_head.cls_in_channels
  149. assert bbox_cfg.reg_in_channels == bbox_head.reg_in_channels
  150. cls_out_channels = bbox_cfg.get('cls_out_channels', 1024)
  151. assert (cls_out_channels == bbox_head.fc_cls.in_features)
  152. assert (bbox_cfg.num_classes + 1 == bbox_head.fc_cls.out_features)
  153. elif bbox_cfg['type'] == 'DIIHead':
  154. assert bbox_cfg['num_ffn_fcs'] == bbox_head.ffn.num_fcs
  155. # 3 means FC and LN and Relu
  156. assert bbox_cfg['num_cls_fcs'] == len(bbox_head.cls_fcs) // 3
  157. assert bbox_cfg['num_reg_fcs'] == len(bbox_head.reg_fcs) // 3
  158. assert bbox_cfg['in_channels'] == bbox_head.in_channels
  159. assert bbox_cfg['in_channels'] == bbox_head.fc_cls.in_features
  160. assert bbox_cfg['in_channels'] == bbox_head.fc_reg.in_features
  161. assert bbox_cfg['in_channels'] == bbox_head.attention.embed_dims
  162. assert bbox_cfg[
  163. 'feedforward_channels'] == bbox_head.ffn.feedforward_channels
  164. else:
  165. assert bbox_cfg.in_channels == bbox_head.in_channels
  166. with_cls = bbox_cfg.get('with_cls', True)
  167. if with_cls:
  168. fc_out_channels = bbox_cfg.get('fc_out_channels', 2048)
  169. assert (fc_out_channels == bbox_head.fc_cls.in_features)
  170. if bbox_head.custom_cls_channels:
  171. assert (bbox_head.loss_cls.get_cls_channels(
  172. bbox_head.num_classes) == bbox_head.fc_cls.out_features
  173. )
  174. else:
  175. assert (bbox_cfg.num_classes +
  176. 1 == bbox_head.fc_cls.out_features)
  177. with_reg = bbox_cfg.get('with_reg', True)
  178. if with_reg:
  179. out_dim = (4 if bbox_cfg.reg_class_agnostic else 4 *
  180. bbox_cfg.num_classes)
  181. assert bbox_head.fc_reg.out_features == out_dim
  182. def _check_anchorhead(config, head):
  183. # check consistency between head_config and roi_head
  184. assert config['type'] == head.__class__.__name__
  185. assert config.in_channels == head.in_channels
  186. num_classes = (
  187. config.num_classes -
  188. 1 if config.loss_cls.get('use_sigmoid', False) else config.num_classes)
  189. if config['type'] == 'ATSSHead':
  190. assert (config.feat_channels == head.atss_cls.in_channels)
  191. assert (config.feat_channels == head.atss_reg.in_channels)
  192. assert (config.feat_channels == head.atss_centerness.in_channels)
  193. elif config['type'] == 'SABLRetinaHead':
  194. assert (config.feat_channels == head.retina_cls.in_channels)
  195. assert (config.feat_channels == head.retina_bbox_reg.in_channels)
  196. assert (config.feat_channels == head.retina_bbox_cls.in_channels)
  197. else:
  198. assert (config.in_channels == head.conv_cls.in_channels)
  199. assert (config.in_channels == head.conv_reg.in_channels)
  200. assert (head.conv_cls.out_channels == num_classes * head.num_anchors)
  201. assert head.fc_reg.out_channels == 4 * head.num_anchors
  202. # Only tests a representative subset of configurations
  203. # TODO: test pipelines using Albu, current Albu throw None given empty GT
  204. @pytest.mark.parametrize(
  205. 'config_rpath',
  206. [
  207. 'wider_face/ssd300_wider_face.py',
  208. 'pascal_voc/ssd300_voc0712.py',
  209. 'pascal_voc/ssd512_voc0712.py',
  210. # 'albu_example/mask_rcnn_r50_fpn_1x.py',
  211. 'foveabox/fovea_align_r50_fpn_gn-head_mstrain_640-800_4x4_2x_coco.py',
  212. 'mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_coco.py',
  213. 'mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain_1x_coco.py',
  214. 'fp16/mask_rcnn_r50_fpn_fp16_1x_coco.py'
  215. ])
  216. def test_config_data_pipeline(config_rpath):
  217. """Test whether the data pipeline is valid and can process corner cases.
  218. CommandLine:
  219. xdoctest -m tests/test_runtime/
  220. test_config.py test_config_build_data_pipeline
  221. """
  222. from mmcv import Config
  223. from mmdet.datasets.pipelines import Compose
  224. import numpy as np
  225. config_dpath = _get_config_directory()
  226. print(f'Found config_dpath = {config_dpath}')
  227. def dummy_masks(h, w, num_obj=3, mode='bitmap'):
  228. assert mode in ('polygon', 'bitmap')
  229. if mode == 'bitmap':
  230. masks = np.random.randint(0, 2, (num_obj, h, w), dtype=np.uint8)
  231. masks = BitmapMasks(masks, h, w)
  232. else:
  233. masks = []
  234. for i in range(num_obj):
  235. masks.append([])
  236. masks[-1].append(
  237. np.random.uniform(0, min(h - 1, w - 1), (8 + 4 * i, )))
  238. masks[-1].append(
  239. np.random.uniform(0, min(h - 1, w - 1), (10 + 4 * i, )))
  240. masks = PolygonMasks(masks, h, w)
  241. return masks
  242. config_fpath = join(config_dpath, config_rpath)
  243. cfg = Config.fromfile(config_fpath)
  244. # remove loading pipeline
  245. loading_pipeline = cfg.train_pipeline.pop(0)
  246. loading_ann_pipeline = cfg.train_pipeline.pop(0)
  247. cfg.test_pipeline.pop(0)
  248. train_pipeline = Compose(cfg.train_pipeline)
  249. test_pipeline = Compose(cfg.test_pipeline)
  250. print(f'Building data pipeline, config_fpath = {config_fpath}')
  251. print(f'Test training data pipeline: \n{train_pipeline!r}')
  252. img = np.random.randint(0, 255, size=(888, 666, 3), dtype=np.uint8)
  253. if loading_pipeline.get('to_float32', False):
  254. img = img.astype(np.float32)
  255. mode = 'bitmap' if loading_ann_pipeline.get('poly2mask',
  256. True) else 'polygon'
  257. results = dict(
  258. filename='test_img.png',
  259. ori_filename='test_img.png',
  260. img=img,
  261. img_shape=img.shape,
  262. ori_shape=img.shape,
  263. gt_bboxes=np.array([[35.2, 11.7, 39.7, 15.7]], dtype=np.float32),
  264. gt_labels=np.array([1], dtype=np.int64),
  265. gt_masks=dummy_masks(img.shape[0], img.shape[1], mode=mode),
  266. )
  267. results['img_fields'] = ['img']
  268. results['bbox_fields'] = ['gt_bboxes']
  269. results['mask_fields'] = ['gt_masks']
  270. output_results = train_pipeline(results)
  271. assert output_results is not None
  272. print(f'Test testing data pipeline: \n{test_pipeline!r}')
  273. results = dict(
  274. filename='test_img.png',
  275. ori_filename='test_img.png',
  276. img=img,
  277. img_shape=img.shape,
  278. ori_shape=img.shape,
  279. gt_bboxes=np.array([[35.2, 11.7, 39.7, 15.7]], dtype=np.float32),
  280. gt_labels=np.array([1], dtype=np.int64),
  281. gt_masks=dummy_masks(img.shape[0], img.shape[1], mode=mode),
  282. )
  283. results['img_fields'] = ['img']
  284. results['bbox_fields'] = ['gt_bboxes']
  285. results['mask_fields'] = ['gt_masks']
  286. output_results = test_pipeline(results)
  287. assert output_results is not None
  288. # test empty GT
  289. print('Test empty GT with training data pipeline: '
  290. f'\n{train_pipeline!r}')
  291. results = dict(
  292. filename='test_img.png',
  293. ori_filename='test_img.png',
  294. img=img,
  295. img_shape=img.shape,
  296. ori_shape=img.shape,
  297. gt_bboxes=np.zeros((0, 4), dtype=np.float32),
  298. gt_labels=np.array([], dtype=np.int64),
  299. gt_masks=dummy_masks(img.shape[0], img.shape[1], num_obj=0, mode=mode),
  300. )
  301. results['img_fields'] = ['img']
  302. results['bbox_fields'] = ['gt_bboxes']
  303. results['mask_fields'] = ['gt_masks']
  304. output_results = train_pipeline(results)
  305. assert output_results is not None
  306. print(f'Test empty GT with testing data pipeline: \n{test_pipeline!r}')
  307. results = dict(
  308. filename='test_img.png',
  309. ori_filename='test_img.png',
  310. img=img,
  311. img_shape=img.shape,
  312. ori_shape=img.shape,
  313. gt_bboxes=np.zeros((0, 4), dtype=np.float32),
  314. gt_labels=np.array([], dtype=np.int64),
  315. gt_masks=dummy_masks(img.shape[0], img.shape[1], num_obj=0, mode=mode),
  316. )
  317. results['img_fields'] = ['img']
  318. results['bbox_fields'] = ['gt_bboxes']
  319. results['mask_fields'] = ['gt_masks']
  320. output_results = test_pipeline(results)
  321. assert output_results is not None

No Description

Contributors (3)