
test_common.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import os
import os.path as osp
import tempfile
from unittest.mock import MagicMock, patch

import mmcv
import numpy as np
import pytest
import torch
import torch.nn as nn
from mmcv.runner import EpochBasedRunner
from torch.utils.data import DataLoader

from mmdet.core.evaluation import DistEvalHook, EvalHook
from mmdet.datasets import DATASETS, CocoDataset, CustomDataset, build_dataset


# Write a minimal COCO-format annotation file: one image with four 'car' boxes.
def _create_dummy_coco_json(json_name):
    image = {
        'id': 0,
        'width': 640,
        'height': 640,
        'file_name': 'fake_name.jpg',
    }

    annotation_1 = {
        'id': 1,
        'image_id': 0,
        'category_id': 0,
        'area': 400,
        'bbox': [50, 60, 20, 20],
        'iscrowd': 0,
    }

    annotation_2 = {
        'id': 2,
        'image_id': 0,
        'category_id': 0,
        'area': 900,
        'bbox': [100, 120, 30, 30],
        'iscrowd': 0,
    }

    annotation_3 = {
        'id': 3,
        'image_id': 0,
        'category_id': 0,
        'area': 1600,
        'bbox': [150, 160, 40, 40],
        'iscrowd': 0,
    }

    annotation_4 = {
        'id': 4,
        'image_id': 0,
        'category_id': 0,
        'area': 10000,
        'bbox': [250, 260, 100, 100],
        'iscrowd': 0,
    }

    categories = [{
        'id': 0,
        'name': 'car',
        'supercategory': 'car',
    }]

    fake_json = {
        'images': [image],
        'annotations':
        [annotation_1, annotation_2, annotation_3, annotation_4],
        'categories': categories
    }

    mmcv.dump(fake_json, json_name)


# Write a minimal CustomDataset-style annotation pickle describing the same
# image and boxes (in x1, y1, x2, y2 format).
def _create_dummy_custom_pkl(pkl_name):
    fake_pkl = [{
        'filename': 'fake_name.jpg',
        'width': 640,
        'height': 640,
        'ann': {
            'bboxes':
            np.array([[50, 60, 70, 80], [100, 120, 130, 150],
                      [150, 160, 190, 200], [250, 260, 350, 360]]),
            'labels':
            np.array([0, 0, 0, 0])
        }
    }]
    mmcv.dump(fake_pkl, pkl_name)


def _create_dummy_results():
    # Detection results in mmdet format: a list over images, where each image
    # holds a per-class list of (n, 5) arrays [x1, y1, x2, y2, score].
    boxes = [
        np.array([[50, 60, 70, 80, 1.0], [100, 120, 130, 150, 0.98],
                  [150, 160, 190, 200, 0.96], [250, 260, 350, 360, 0.95]])
    ]
    return [boxes]
@pytest.mark.parametrize('config_path',
                         ['./configs/_base_/datasets/voc0712.py'])
def test_dataset_init(config_path):
    use_symlink = False
    if not os.path.exists('./data'):
        os.symlink('./tests/data', './data')
        use_symlink = True
    data_config = mmcv.Config.fromfile(config_path)
    if 'data' not in data_config:
        return
    stage_names = ['train', 'val', 'test']
    for stage_name in stage_names:
        dataset_config = copy.deepcopy(data_config.data.get(stage_name))
        dataset = build_dataset(dataset_config)
        dataset[0]
    if use_symlink:
        os.unlink('./data')
def test_dataset_evaluation():
    tmp_dir = tempfile.TemporaryDirectory()
    # create dummy data
    fake_json_file = osp.join(tmp_dir.name, 'fake_data.json')
    _create_dummy_coco_json(fake_json_file)

    # test single coco dataset evaluation
    coco_dataset = CocoDataset(
        ann_file=fake_json_file, classes=('car', ), pipeline=[])
    fake_results = _create_dummy_results()
    eval_results = coco_dataset.evaluate(fake_results, classwise=True)
    assert eval_results['bbox_mAP'] == 1
    assert eval_results['bbox_mAP_50'] == 1
    assert eval_results['bbox_mAP_75'] == 1

    # test concat dataset evaluation
    fake_concat_results = _create_dummy_results() + _create_dummy_results()

    # build concat dataset through two config dicts
    coco_cfg = dict(
        type='CocoDataset',
        ann_file=fake_json_file,
        classes=('car', ),
        pipeline=[])
    concat_cfgs = [coco_cfg, coco_cfg]
    concat_dataset = build_dataset(concat_cfgs)
    eval_results = concat_dataset.evaluate(fake_concat_results)
    assert eval_results['0_bbox_mAP'] == 1
    assert eval_results['0_bbox_mAP_50'] == 1
    assert eval_results['0_bbox_mAP_75'] == 1
    assert eval_results['1_bbox_mAP'] == 1
    assert eval_results['1_bbox_mAP_50'] == 1
    assert eval_results['1_bbox_mAP_75'] == 1

    # build concat dataset through a concatenated ann_file
    coco_cfg = dict(
        type='CocoDataset',
        ann_file=[fake_json_file, fake_json_file],
        classes=('car', ),
        pipeline=[])
    concat_dataset = build_dataset(coco_cfg)
    eval_results = concat_dataset.evaluate(fake_concat_results)
    assert eval_results['0_bbox_mAP'] == 1
    assert eval_results['0_bbox_mAP_50'] == 1
    assert eval_results['0_bbox_mAP_75'] == 1
    assert eval_results['1_bbox_mAP'] == 1
    assert eval_results['1_bbox_mAP_50'] == 1
    assert eval_results['1_bbox_mAP_75'] == 1

    # create dummy data
    fake_pkl_file = osp.join(tmp_dir.name, 'fake_data.pkl')
    _create_dummy_custom_pkl(fake_pkl_file)

    # test single custom dataset evaluation
    custom_dataset = CustomDataset(
        ann_file=fake_pkl_file, classes=('car', ), pipeline=[])
    fake_results = _create_dummy_results()
    eval_results = custom_dataset.evaluate(fake_results)
    assert eval_results['mAP'] == 1

    # test concat dataset evaluation
    fake_concat_results = _create_dummy_results() + _create_dummy_results()

    # build concat dataset through two config dicts
    custom_cfg = dict(
        type='CustomDataset',
        ann_file=fake_pkl_file,
        classes=('car', ),
        pipeline=[])
    concat_cfgs = [custom_cfg, custom_cfg]
    concat_dataset = build_dataset(concat_cfgs)
    eval_results = concat_dataset.evaluate(fake_concat_results)
    assert eval_results['0_mAP'] == 1
    assert eval_results['1_mAP'] == 1

    # build concat dataset through a concatenated ann_file
    concat_cfg = dict(
        type='CustomDataset',
        ann_file=[fake_pkl_file, fake_pkl_file],
        classes=('car', ),
        pipeline=[])
    concat_dataset = build_dataset(concat_cfg)
    eval_results = concat_dataset.evaluate(fake_concat_results)
    assert eval_results['0_mAP'] == 1
    assert eval_results['1_mAP'] == 1

    # build concat dataset through an explicit ConcatDataset type
    concat_cfg = dict(
        type='ConcatDataset',
        datasets=[custom_cfg, custom_cfg],
        separate_eval=False)
    concat_dataset = build_dataset(concat_cfg)
    eval_results = concat_dataset.evaluate(fake_concat_results, metric='mAP')
    assert eval_results['mAP'] == 1
    assert len(concat_dataset.datasets[0].data_infos) == \
        len(concat_dataset.datasets[1].data_infos)
    assert len(concat_dataset.datasets[0].data_infos) == 1
    tmp_dir.cleanup()
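
# Illustrative note (not part of the original test file): the three
# concatenation styles exercised above correspond to equivalent ways of
# writing a training dataset in an mmdet config; the names cfg_a, cfg_b and
# file_a, file_b below are placeholders:
#
#   # 1) a list of dataset config dicts
#   train = [cfg_a, cfg_b]
#   # 2) a single config dict whose ann_file is a list
#   train = dict(type='CocoDataset', ann_file=[file_a, file_b], pipeline=[])
#   # 3) an explicit ConcatDataset wrapper
#   train = dict(type='ConcatDataset', datasets=[cfg_a, cfg_b],
#                separate_eval=False)
#
# build_dataset accepts any of these forms, as the assertions above show.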
@patch('mmdet.apis.single_gpu_test', MagicMock)
@patch('mmdet.apis.multi_gpu_test', MagicMock)
@pytest.mark.parametrize('EvalHookParam', (EvalHook, DistEvalHook))
def test_evaluation_hook(EvalHookParam):
    # create dummy data
    dataloader = DataLoader(torch.ones((5, 2)))

    # 0.1. dataloader is not a DataLoader object
    with pytest.raises(TypeError):
        EvalHookParam(dataloader=MagicMock(), interval=-1)

    # 0.2. negative interval
    with pytest.raises(ValueError):
        EvalHookParam(dataloader, interval=-1)

    # 1. start=None, interval=1: perform evaluation after each epoch.
    runner = _build_demo_runner()
    evalhook = EvalHookParam(dataloader, interval=1)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    runner.run([dataloader], [('train', 1)], 2)
    assert evalhook.evaluate.call_count == 2  # after epoch 1 & 2

    # 2. start=1, interval=1: perform evaluation after each epoch.
    runner = _build_demo_runner()
    evalhook = EvalHookParam(dataloader, start=1, interval=1)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    runner.run([dataloader], [('train', 1)], 2)
    assert evalhook.evaluate.call_count == 2  # after epoch 1 & 2

    # 3. start=None, interval=2: perform evaluation after epoch 2, 4, 6, etc.
    runner = _build_demo_runner()
    evalhook = EvalHookParam(dataloader, interval=2)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    runner.run([dataloader], [('train', 1)], 2)
    assert evalhook.evaluate.call_count == 1  # after epoch 2

    # 4. start=1, interval=2: perform evaluation after epoch 1, 3, 5, etc.
    runner = _build_demo_runner()
    evalhook = EvalHookParam(dataloader, start=1, interval=2)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    runner.run([dataloader], [('train', 1)], 3)
    assert evalhook.evaluate.call_count == 2  # after epoch 1 & 3

    # 5. start=0/negative, interval=1: perform evaluation after each epoch and
    #    before epoch 1.
    runner = _build_demo_runner()
    evalhook = EvalHookParam(dataloader, start=0)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    runner.run([dataloader], [('train', 1)], 2)
    assert evalhook.evaluate.call_count == 3  # before epoch 1 and after epoch 1 & 2

    # the evaluation start epoch cannot be less than 0
    runner = _build_demo_runner()
    with pytest.raises(ValueError):
        EvalHookParam(dataloader, start=-2)
    evalhook = EvalHookParam(dataloader, start=0)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    runner.run([dataloader], [('train', 1)], 2)
    assert evalhook.evaluate.call_count == 3  # before epoch 1 and after epoch 1 & 2

    # 6. resuming from epoch i, start = x (x <= i), interval=1: perform
    #    evaluation after each epoch and before the first epoch.
    runner = _build_demo_runner()
    evalhook = EvalHookParam(dataloader, start=1)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    runner._epoch = 2
    runner.run([dataloader], [('train', 1)], 3)
    assert evalhook.evaluate.call_count == 2  # before & after epoch 3

    # 7. resuming from epoch i, start = i+1/None, interval=1: perform
    #    evaluation after each epoch.
    runner = _build_demo_runner()
    evalhook = EvalHookParam(dataloader, start=2)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    runner._epoch = 1
    runner.run([dataloader], [('train', 1)], 3)
    assert evalhook.evaluate.call_count == 2  # after epoch 2 & 3
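
# Illustrative note (not part of the original test file): in an actual
# training run this hook is usually not constructed by hand; it is created
# from the `evaluation` field of the config, e.g.
#   evaluation = dict(interval=1, metric='bbox')
# whose keyword arguments are passed on to EvalHook/DistEvalHook, mirroring
# the explicit constructor calls above. Treat the exact config field and keys
# shown here as an assumption rather than part of this file.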
# Minimal model and EpochBasedRunner used by the EvalHook tests above.
def _build_demo_runner():

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 1)

        def forward(self, x):
            return self.linear(x)

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

    model = Model()
    tmp_dir = tempfile.mkdtemp()

    runner = EpochBasedRunner(
        model=model, work_dir=tmp_dir, logger=logging.getLogger())
    return runner
@pytest.mark.parametrize('classes, expected_length', [(['bus'], 2),
                                                      (['car'], 1),
                                                      (['bus', 'car'], 2)])
def test_allow_empty_images(classes, expected_length):
    dataset_class = DATASETS.get('CocoDataset')
    # Filter empty images
    filtered_dataset = dataset_class(
        ann_file='tests/data/coco_sample.json',
        img_prefix='tests/data',
        pipeline=[],
        classes=classes,
        filter_empty_gt=True)

    # Get all
    full_dataset = dataset_class(
        ann_file='tests/data/coco_sample.json',
        img_prefix='tests/data',
        pipeline=[],
        classes=classes,
        filter_empty_gt=False)

    assert len(filtered_dataset) == expected_length
    assert len(filtered_dataset.img_ids) == expected_length
    assert len(full_dataset) == 3
    assert len(full_dataset.img_ids) == 3
    assert filtered_dataset.CLASSES == classes
    assert full_dataset.CLASSES == classes
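
Usage note (a minimal sketch, assuming pytest plus the mmdet and mmcv versions imported above are installed, and that the command is run from the repository root so relative paths such as ./tests/data and ./configs resolve): the whole file can be run with "pytest test_common.py -v", or a single case with, for example, "pytest test_common.py::test_dataset_evaluation".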
