test_robustness.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import os.path as osp

import mmcv
import torch
from mmcv import DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from tools.analysis_tools.robustness_eval import get_results

from mmdet import datasets
from mmdet.apis import multi_gpu_test, set_random_seed, single_gpu_test
from mmdet.core import eval_map
from mmdet.datasets import build_dataloader, build_dataset
from mmdet.models import build_detector


def coco_eval_with_return(result_files,
                          result_types,
                          coco,
                          max_dets=(100, 300, 1000)):
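    """Run COCO evaluation and return the computed metrics.

    Args:
        result_files (dict): Maps each result type to a detection
            result file in COCO json format.
        result_types (list[str]): Types to evaluate; each must be one of
            'proposal', 'bbox', 'segm' or 'keypoints'.
        coco (str | COCO): Annotation file path or a COCO api object.
        max_dets (tuple): Proposal numbers used when evaluating recalls.

    Returns:
        dict: COCOeval summary statistics keyed by result type; for
        'bbox' and 'segm' the statistics are additionally keyed by
        metric name (AP, AP50, ...).
    """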
    for res_type in result_types:
        assert res_type in ['proposal', 'bbox', 'segm', 'keypoints']

    if mmcv.is_str(coco):
        coco = COCO(coco)
    assert isinstance(coco, COCO)

    eval_results = {}
    for res_type in result_types:
        result_file = result_files[res_type]
        assert result_file.endswith('.json')

        coco_dets = coco.loadRes(result_file)
        img_ids = coco.getImgIds()
        iou_type = 'bbox' if res_type == 'proposal' else res_type
        cocoEval = COCOeval(coco, coco_dets, iou_type)
        cocoEval.params.imgIds = img_ids
        if res_type == 'proposal':
            cocoEval.params.useCats = 0
            cocoEval.params.maxDets = list(max_dets)
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()
        if res_type == 'segm' or res_type == 'bbox':
            metric_names = [
                'AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'AR1', 'AR10',
                'AR100', 'ARs', 'ARm', 'ARl'
            ]
            eval_results[res_type] = {
                metric_names[i]: cocoEval.stats[i]
                for i in range(len(metric_names))
            }
        else:
            eval_results[res_type] = cocoEval.stats

    return eval_results


def voc_eval_with_return(result_file,
                         dataset,
                         iou_thr=0.5,
                         logger='print',
                         only_ap=True):
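    """Evaluate PASCAL VOC detection results and return them.

    Args:
        result_file (str): File holding the pickled detection results.
        dataset (Dataset): Dataset providing annotations via
            `get_ann_info`.
        iou_thr (float): IoU threshold for a detection to count as a
            true positive.
        logger ('print' | None): Logger passed through to `eval_map`.
        only_ap (bool): If True, keep only the per-class 'ap' values.

    Returns:
        tuple: (mean_ap, eval_results).
    """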
    det_results = mmcv.load(result_file)
    annotations = [dataset.get_ann_info(i) for i in range(len(dataset))]
    if hasattr(dataset, 'year') and dataset.year == 2007:
        dataset_name = 'voc07'
    else:
        dataset_name = dataset.CLASSES
    mean_ap, eval_results = eval_map(
        det_results,
        annotations,
        scale_ranges=None,
        iou_thr=iou_thr,
        dataset=dataset_name,
        logger=logger)

    if only_ap:
        eval_results = [{
            'ap': eval_results[i]['ap']
        } for i in range(len(eval_results))]

    return mean_ap, eval_results


def parse_args():
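    """Parse the command line arguments for the robustness test."""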
    parser = argparse.ArgumentParser(description='MMDet test detector')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--out', help='output result file')
    parser.add_argument(
        '--corruptions',
        type=str,
        nargs='+',
        default='benchmark',
        choices=[
            'all', 'benchmark', 'noise', 'blur', 'weather', 'digital',
            'holdout', 'None', 'gaussian_noise', 'shot_noise', 'impulse_noise',
            'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur', 'snow',
            'frost', 'fog', 'brightness', 'contrast', 'elastic_transform',
            'pixelate', 'jpeg_compression', 'speckle_noise', 'gaussian_blur',
            'spatter', 'saturate'
        ],
        help='corruptions')
    parser.add_argument(
        '--severities',
        type=int,
        nargs='+',
        default=[0, 1, 2, 3, 4, 5],
        help='corruption severity levels')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
        help='eval types')
    parser.add_argument(
        '--iou-thr',
        type=float,
        default=0.5,
        help='IoU threshold for pascal voc evaluation')
    parser.add_argument(
        '--summaries',
        type=bool,
        default=False,
        help='Print summaries for every corruption and severity')
    parser.add_argument(
        '--workers', type=int, default=32, help='workers per gpu')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--show-score-thr',
        type=float,
        default=0.3,
        help='score threshold (default: 0.3)')
    parser.add_argument('--tmpdir', help='tmp dir for writing some results')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--final-prints',
        type=str,
        nargs='+',
        choices=['P', 'mPC', 'rPC'],
        default='mPC',
        help='corruption benchmark metric to print at the end')
    parser.add_argument(
        '--final-prints-aggregate',
        type=str,
        choices=['all', 'benchmark'],
        default='benchmark',
        help='aggregate all results or only those for benchmark corruptions')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config; the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main():
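    """Test a detector on corrupted copies of its test set.

    For every requested corruption type and severity level a 'Corrupt'
    transform is inserted into the test pipeline (for severities above
    0), the model is evaluated on the resulting data, and the metrics
    are aggregated per corruption and severity before being printed and
    dumped to disk.
    """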
    args = parse_args()

    assert args.out or args.show or args.show_dir, \
        ('Please specify at least one operation (save or show the results) '
         'with the argument "--out", "--show" or "--show-dir"')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    if args.workers == 0:
        args.workers = cfg.data.workers_per_gpu

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed)
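    # expand shorthand group names ('all', 'benchmark', 'noise', 'blur',
    # 'weather', 'digital', 'holdout') into explicit corruption lists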
    if 'all' in args.corruptions:
        corruptions = [
            'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
            'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
            'brightness', 'contrast', 'elastic_transform', 'pixelate',
            'jpeg_compression', 'speckle_noise', 'gaussian_blur', 'spatter',
            'saturate'
        ]
    elif 'benchmark' in args.corruptions:
        corruptions = [
            'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
            'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
            'brightness', 'contrast', 'elastic_transform', 'pixelate',
            'jpeg_compression'
        ]
    elif 'noise' in args.corruptions:
        corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise']
    elif 'blur' in args.corruptions:
        corruptions = [
            'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur'
        ]
    elif 'weather' in args.corruptions:
        corruptions = ['snow', 'frost', 'fog', 'brightness']
    elif 'digital' in args.corruptions:
        corruptions = [
            'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression'
        ]
    elif 'holdout' in args.corruptions:
        corruptions = ['speckle_noise', 'gaussian_blur', 'spatter', 'saturate']
    elif 'None' in args.corruptions:
        corruptions = ['None']
        args.severities = [0]
    else:
        corruptions = args.corruptions

    rank, _ = get_dist_info()
    aggregated_results = {}
    for corr_i, corruption in enumerate(corruptions):
        aggregated_results[corruption] = {}
        for sev_i, corruption_severity in enumerate(args.severities):
            # evaluate severity 0 (= no corruption) only once
            if corr_i > 0 and corruption_severity == 0:
                aggregated_results[corruption][0] = \
                    aggregated_results[corruptions[0]][0]
                continue

            test_data_cfg = copy.deepcopy(cfg.data.test)
            # assign corruption and severity
            if corruption_severity > 0:
                corruption_trans = dict(
                    type='Corrupt',
                    corruption=corruption,
                    severity=corruption_severity)
                # TODO: hard coded "1", we assume that the first step is
                # loading images, which needs to be fixed in the future
                test_data_cfg['pipeline'].insert(1, corruption_trans)

            # print info
            print(f'\nTesting {corruption} at severity {corruption_severity}')

            # build the dataloader
            # TODO: support multiple images per gpu
            #       (only minor changes are needed)
            dataset = build_dataset(test_data_cfg)
            data_loader = build_dataloader(
                dataset,
                samples_per_gpu=1,
                workers_per_gpu=args.workers,
                dist=distributed,
                shuffle=False)

            # build the model and load checkpoint
            cfg.model.train_cfg = None
            model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
            fp16_cfg = cfg.get('fp16', None)
            if fp16_cfg is not None:
                wrap_fp16_model(model)
            checkpoint = load_checkpoint(
                model, args.checkpoint, map_location='cpu')
            # old versions did not save class info in checkpoints,
            # this workaround is for backward compatibility
            if 'CLASSES' in checkpoint.get('meta', {}):
                model.CLASSES = checkpoint['meta']['CLASSES']
            else:
                model.CLASSES = dataset.CLASSES

            if not distributed:
                model = MMDataParallel(model, device_ids=[0])
                show_dir = args.show_dir
                if show_dir is not None:
                    show_dir = osp.join(show_dir, corruption)
                    show_dir = osp.join(show_dir, str(corruption_severity))
                    if not osp.exists(show_dir):
                        os.makedirs(show_dir)
                outputs = single_gpu_test(model, data_loader, args.show,
                                          show_dir, args.show_score_thr)
            else:
                model = MMDistributedDataParallel(
                    model.cuda(),
                    device_ids=[torch.cuda.current_device()],
                    broadcast_buffers=False)
                outputs = multi_gpu_test(model, data_loader, args.tmpdir)

            if args.out and rank == 0:
                eval_results_filename = (
                    osp.splitext(args.out)[0] + '_results' +
                    osp.splitext(args.out)[1])
                mmcv.dump(outputs, args.out)
                eval_types = args.eval
                if cfg.dataset_type == 'VOCDataset':
                    if eval_types:
                        for eval_type in eval_types:
                            if eval_type == 'bbox':
                                test_dataset = mmcv.runner.obj_from_dict(
                                    cfg.data.test, datasets)
                                logger = 'print' if args.summaries else None
                                mean_ap, eval_results = \
                                    voc_eval_with_return(
                                        args.out, test_dataset,
                                        args.iou_thr, logger)
                                aggregated_results[corruption][
                                    corruption_severity] = eval_results
                            else:
                                print('\nOnly "bbox" evaluation is '
                                      'supported for pascal voc')
                else:
                    if eval_types:
                        print(f'Starting evaluation of '
                              f'{" and ".join(eval_types)}')
                        if eval_types == ['proposal_fast']:
                            result_file = args.out
                        else:
                            if not isinstance(outputs[0], dict):
                                result_files = dataset.results2json(
                                    outputs, args.out)
                            else:
                                for name in outputs[0]:
                                    print(f'\nEvaluating {name}')
                                    outputs_ = [out[name] for out in outputs]
                                    result_file = args.out + f'.{name}'
                                    result_files = dataset.results2json(
                                        outputs_, result_file)
                        eval_results = coco_eval_with_return(
                            result_files, eval_types, dataset.coco)
                        aggregated_results[corruption][
                            corruption_severity] = eval_results
                    else:
                        print('\nNo task was selected for evaluation;'
                              '\nUse --eval to select a task')

                # save results after each evaluation
                mmcv.dump(aggregated_results, eval_results_filename)

    # final results only exist when they were dumped via --out
    if rank == 0 and args.out:
        # print final results
        print('\nAggregated results:')
        prints = args.final_prints
        aggregate = args.final_prints_aggregate

        if cfg.dataset_type == 'VOCDataset':
            get_results(
                eval_results_filename,
                dataset='voc',
                prints=prints,
                aggregate=aggregate)
        else:
            get_results(
                eval_results_filename,
                dataset='coco',
                prints=prints,
                aggregate=aggregate)


if __name__ == '__main__':
    main()
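
A usage sketch: the config and checkpoint paths below are placeholders, the script path follows from the `tools.analysis_tools` import above, and the flags are those defined in `parse_args` (note that `--out` must end in `.pkl` or `.pickle`):

    python tools/analysis_tools/test_robustness.py \
        configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
        checkpoints/faster_rcnn_r50_fpn_1x_coco.pth \
        --out results/robustness.pkl --eval bbox --corruptions benchmark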
