You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

analyze_results.py 7.1 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. import mmcv
  5. import numpy as np
  6. from mmcv import Config, DictAction
  7. from mmdet.core.evaluation import eval_map
  8. from mmdet.core.visualization import imshow_gt_det_bboxes
  9. from mmdet.datasets import build_dataset, get_loading_pipeline
  10. def bbox_map_eval(det_result, annotation):
  11. """Evaluate mAP of single image det result.
  12. Args:
  13. det_result (list[list]): [[cls1_det, cls2_det, ...], ...].
  14. The outer list indicates images, and the inner list indicates
  15. per-class detected bboxes.
  16. annotation (dict): Ground truth annotations where keys of
  17. annotations are:
  18. - bboxes: numpy array of shape (n, 4)
  19. - labels: numpy array of shape (n, )
  20. - bboxes_ignore (optional): numpy array of shape (k, 4)
  21. - labels_ignore (optional): numpy array of shape (k, )
  22. Returns:
  23. float: mAP
  24. """
  25. # use only bbox det result
  26. if isinstance(det_result, tuple):
  27. bbox_det_result = [det_result[0]]
  28. else:
  29. bbox_det_result = [det_result]
  30. # mAP
  31. iou_thrs = np.linspace(
  32. .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
  33. mean_aps = []
  34. for thr in iou_thrs:
  35. mean_ap, _ = eval_map(
  36. bbox_det_result, [annotation], iou_thr=thr, logger='silent')
  37. mean_aps.append(mean_ap)
  38. return sum(mean_aps) / len(mean_aps)
  39. class ResultVisualizer:
  40. """Display and save evaluation results.
  41. Args:
  42. show (bool): Whether to show the image. Default: True
  43. wait_time (float): Value of waitKey param. Default: 0.
  44. score_thr (float): Minimum score of bboxes to be shown.
  45. Default: 0
  46. """
  47. def __init__(self, show=False, wait_time=0, score_thr=0):
  48. self.show = show
  49. self.wait_time = wait_time
  50. self.score_thr = score_thr
  51. def _save_image_gts_results(self, dataset, results, mAPs, out_dir=None):
  52. mmcv.mkdir_or_exist(out_dir)
  53. for mAP_info in mAPs:
  54. index, mAP = mAP_info
  55. data_info = dataset.prepare_train_img(index)
  56. # calc save file path
  57. filename = data_info['filename']
  58. if data_info['img_prefix'] is not None:
  59. filename = osp.join(data_info['img_prefix'], filename)
  60. else:
  61. filename = data_info['filename']
  62. fname, name = osp.splitext(osp.basename(filename))
  63. save_filename = fname + '_' + str(round(mAP, 3)) + name
  64. out_file = osp.join(out_dir, save_filename)
  65. imshow_gt_det_bboxes(
  66. data_info['img'],
  67. data_info,
  68. results[index],
  69. dataset.CLASSES,
  70. show=self.show,
  71. score_thr=self.score_thr,
  72. wait_time=self.wait_time,
  73. out_file=out_file)
  74. def evaluate_and_show(self,
  75. dataset,
  76. results,
  77. topk=20,
  78. show_dir='work_dir',
  79. eval_fn=None):
  80. """Evaluate and show results.
  81. Args:
  82. dataset (Dataset): A PyTorch dataset.
  83. results (list): Det results from test results pkl file
  84. topk (int): Number of the highest topk and
  85. lowest topk after evaluation index sorting. Default: 20
  86. show_dir (str, optional): The filename to write the image.
  87. Default: 'work_dir'
  88. eval_fn (callable, optional): Eval function, Default: None
  89. """
  90. assert topk > 0
  91. if (topk * 2) > len(dataset):
  92. topk = len(dataset) // 2
  93. if eval_fn is None:
  94. eval_fn = bbox_map_eval
  95. else:
  96. assert callable(eval_fn)
  97. prog_bar = mmcv.ProgressBar(len(results))
  98. _mAPs = {}
  99. for i, (result, ) in enumerate(zip(results)):
  100. # self.dataset[i] should not call directly
  101. # because there is a risk of mismatch
  102. data_info = dataset.prepare_train_img(i)
  103. mAP = eval_fn(result, data_info['ann_info'])
  104. _mAPs[i] = mAP
  105. prog_bar.update()
  106. # descending select topk image
  107. _mAPs = list(sorted(_mAPs.items(), key=lambda kv: kv[1]))
  108. good_mAPs = _mAPs[-topk:]
  109. bad_mAPs = _mAPs[:topk]
  110. good_dir = osp.abspath(osp.join(show_dir, 'good'))
  111. bad_dir = osp.abspath(osp.join(show_dir, 'bad'))
  112. self._save_image_gts_results(dataset, results, good_mAPs, good_dir)
  113. self._save_image_gts_results(dataset, results, bad_mAPs, bad_dir)
  114. def parse_args():
  115. parser = argparse.ArgumentParser(
  116. description='MMDet eval image prediction result for each')
  117. parser.add_argument('config', help='test config file path')
  118. parser.add_argument(
  119. 'prediction_path', help='prediction path where test pkl result')
  120. parser.add_argument(
  121. 'show_dir', help='directory where painted images will be saved')
  122. parser.add_argument('--show', action='store_true', help='show results')
  123. parser.add_argument(
  124. '--wait-time',
  125. type=float,
  126. default=0,
  127. help='the interval of show (s), 0 is block')
  128. parser.add_argument(
  129. '--topk',
  130. default=20,
  131. type=int,
  132. help='saved Number of the highest topk '
  133. 'and lowest topk after index sorting')
  134. parser.add_argument(
  135. '--show-score-thr',
  136. type=float,
  137. default=0,
  138. help='score threshold (default: 0.)')
  139. parser.add_argument(
  140. '--cfg-options',
  141. nargs='+',
  142. action=DictAction,
  143. help='override some settings in the used config, the key-value pair '
  144. 'in xxx=yyy format will be merged into config file. If the value to '
  145. 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
  146. 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
  147. 'Note that the quotation marks are necessary and that no white space '
  148. 'is allowed.')
  149. args = parser.parse_args()
  150. return args
  151. def main():
  152. args = parse_args()
  153. mmcv.check_file_exist(args.prediction_path)
  154. cfg = Config.fromfile(args.config)
  155. if args.cfg_options is not None:
  156. cfg.merge_from_dict(args.cfg_options)
  157. cfg.data.test.test_mode = True
  158. # import modules from string list.
  159. if cfg.get('custom_imports', None):
  160. from mmcv.utils import import_modules_from_strings
  161. import_modules_from_strings(**cfg['custom_imports'])
  162. cfg.data.test.pop('samples_per_gpu', 0)
  163. cfg.data.test.pipeline = get_loading_pipeline(cfg.data.train.pipeline)
  164. dataset = build_dataset(cfg.data.test)
  165. outputs = mmcv.load(args.prediction_path)
  166. result_visualizer = ResultVisualizer(args.show, args.wait_time,
  167. args.show_score_thr)
  168. result_visualizer.evaluate_and_show(
  169. dataset, outputs, topk=args.topk, show_dir=args.show_dir)
  170. if __name__ == '__main__':
  171. main()

No Description

Contributors (1)