You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mean_ap.py 21 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from multiprocessing import Pool
  3. import mmcv
  4. import numpy as np
  5. from mmcv.utils import print_log
  6. from terminaltables import AsciiTable
  7. from .bbox_overlaps import bbox_overlaps
  8. from .class_names import get_classes
  9. def average_precision(recalls, precisions, mode='area'):
  10. """Calculate average precision (for single or multiple scales).
  11. Args:
  12. recalls (ndarray): shape (num_scales, num_dets) or (num_dets, )
  13. precisions (ndarray): shape (num_scales, num_dets) or (num_dets, )
  14. mode (str): 'area' or '11points', 'area' means calculating the area
  15. under precision-recall curve, '11points' means calculating
  16. the average precision of recalls at [0, 0.1, ..., 1]
  17. Returns:
  18. float or ndarray: calculated average precision
  19. """
  20. no_scale = False
  21. if recalls.ndim == 1:
  22. no_scale = True
  23. recalls = recalls[np.newaxis, :]
  24. precisions = precisions[np.newaxis, :]
  25. assert recalls.shape == precisions.shape and recalls.ndim == 2
  26. num_scales = recalls.shape[0]
  27. ap = np.zeros(num_scales, dtype=np.float32)
  28. if mode == 'area':
  29. zeros = np.zeros((num_scales, 1), dtype=recalls.dtype)
  30. ones = np.ones((num_scales, 1), dtype=recalls.dtype)
  31. mrec = np.hstack((zeros, recalls, ones))
  32. mpre = np.hstack((zeros, precisions, zeros))
  33. for i in range(mpre.shape[1] - 1, 0, -1):
  34. mpre[:, i - 1] = np.maximum(mpre[:, i - 1], mpre[:, i])
  35. for i in range(num_scales):
  36. ind = np.where(mrec[i, 1:] != mrec[i, :-1])[0]
  37. ap[i] = np.sum(
  38. (mrec[i, ind + 1] - mrec[i, ind]) * mpre[i, ind + 1])
  39. elif mode == '11points':
  40. for i in range(num_scales):
  41. for thr in np.arange(0, 1 + 1e-3, 0.1):
  42. precs = precisions[i, recalls[i, :] >= thr]
  43. prec = precs.max() if precs.size > 0 else 0
  44. ap[i] += prec
  45. ap /= 11
  46. else:
  47. raise ValueError(
  48. 'Unrecognized mode, only "area" and "11points" are supported')
  49. if no_scale:
  50. ap = ap[0]
  51. return ap
  52. def tpfp_imagenet(det_bboxes,
  53. gt_bboxes,
  54. gt_bboxes_ignore=None,
  55. default_iou_thr=0.5,
  56. area_ranges=None,
  57. use_legacy_coordinate=False):
  58. """Check if detected bboxes are true positive or false positive.
  59. Args:
  60. det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5).
  61. gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
  62. gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
  63. of shape (k, 4). Default: None
  64. default_iou_thr (float): IoU threshold to be considered as matched for
  65. medium and large bboxes (small ones have special rules).
  66. Default: 0.5.
  67. area_ranges (list[tuple] | None): Range of bbox areas to be evaluated,
  68. in the format [(min1, max1), (min2, max2), ...]. Default: None.
  69. use_legacy_coordinate (bool): Whether to use coordinate system in
  70. mmdet v1.x. which means width, height should be
  71. calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively.
  72. Default: False.
  73. Returns:
  74. tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
  75. each array is (num_scales, m).
  76. """
  77. if not use_legacy_coordinate:
  78. extra_length = 0.
  79. else:
  80. extra_length = 1.
  81. # an indicator of ignored gts
  82. gt_ignore_inds = np.concatenate(
  83. (np.zeros(gt_bboxes.shape[0], dtype=np.bool),
  84. np.ones(gt_bboxes_ignore.shape[0], dtype=np.bool)))
  85. # stack gt_bboxes and gt_bboxes_ignore for convenience
  86. gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
  87. num_dets = det_bboxes.shape[0]
  88. num_gts = gt_bboxes.shape[0]
  89. if area_ranges is None:
  90. area_ranges = [(None, None)]
  91. num_scales = len(area_ranges)
  92. # tp and fp are of shape (num_scales, num_gts), each row is tp or fp
  93. # of a certain scale.
  94. tp = np.zeros((num_scales, num_dets), dtype=np.float32)
  95. fp = np.zeros((num_scales, num_dets), dtype=np.float32)
  96. if gt_bboxes.shape[0] == 0:
  97. if area_ranges == [(None, None)]:
  98. fp[...] = 1
  99. else:
  100. det_areas = (
  101. det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
  102. det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
  103. for i, (min_area, max_area) in enumerate(area_ranges):
  104. fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
  105. return tp, fp
  106. ious = bbox_overlaps(
  107. det_bboxes, gt_bboxes - 1, use_legacy_coordinate=use_legacy_coordinate)
  108. gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length
  109. gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length
  110. iou_thrs = np.minimum((gt_w * gt_h) / ((gt_w + 10.0) * (gt_h + 10.0)),
  111. default_iou_thr)
  112. # sort all detections by scores in descending order
  113. sort_inds = np.argsort(-det_bboxes[:, -1])
  114. for k, (min_area, max_area) in enumerate(area_ranges):
  115. gt_covered = np.zeros(num_gts, dtype=bool)
  116. # if no area range is specified, gt_area_ignore is all False
  117. if min_area is None:
  118. gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
  119. else:
  120. gt_areas = gt_w * gt_h
  121. gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
  122. for i in sort_inds:
  123. max_iou = -1
  124. matched_gt = -1
  125. # find best overlapped available gt
  126. for j in range(num_gts):
  127. # different from PASCAL VOC: allow finding other gts if the
  128. # best overlapped ones are already matched by other det bboxes
  129. if gt_covered[j]:
  130. continue
  131. elif ious[i, j] >= iou_thrs[j] and ious[i, j] > max_iou:
  132. max_iou = ious[i, j]
  133. matched_gt = j
  134. # there are 4 cases for a det bbox:
  135. # 1. it matches a gt, tp = 1, fp = 0
  136. # 2. it matches an ignored gt, tp = 0, fp = 0
  137. # 3. it matches no gt and within area range, tp = 0, fp = 1
  138. # 4. it matches no gt but is beyond area range, tp = 0, fp = 0
  139. if matched_gt >= 0:
  140. gt_covered[matched_gt] = 1
  141. if not (gt_ignore_inds[matched_gt]
  142. or gt_area_ignore[matched_gt]):
  143. tp[k, i] = 1
  144. elif min_area is None:
  145. fp[k, i] = 1
  146. else:
  147. bbox = det_bboxes[i, :4]
  148. area = (bbox[2] - bbox[0] + extra_length) * (
  149. bbox[3] - bbox[1] + extra_length)
  150. if area >= min_area and area < max_area:
  151. fp[k, i] = 1
  152. return tp, fp
  153. def tpfp_default(det_bboxes,
  154. gt_bboxes,
  155. gt_bboxes_ignore=None,
  156. iou_thr=0.5,
  157. area_ranges=None,
  158. use_legacy_coordinate=False):
  159. """Check if detected bboxes are true positive or false positive.
  160. Args:
  161. det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5).
  162. gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
  163. gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
  164. of shape (k, 4). Default: None
  165. iou_thr (float): IoU threshold to be considered as matched.
  166. Default: 0.5.
  167. area_ranges (list[tuple] | None): Range of bbox areas to be
  168. evaluated, in the format [(min1, max1), (min2, max2), ...].
  169. Default: None.
  170. use_legacy_coordinate (bool): Whether to use coordinate system in
  171. mmdet v1.x. which means width, height should be
  172. calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively.
  173. Default: False.
  174. Returns:
  175. tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
  176. each array is (num_scales, m).
  177. """
  178. if not use_legacy_coordinate:
  179. extra_length = 0.
  180. else:
  181. extra_length = 1.
  182. # an indicator of ignored gts
  183. gt_ignore_inds = np.concatenate(
  184. (np.zeros(gt_bboxes.shape[0], dtype=np.bool),
  185. np.ones(gt_bboxes_ignore.shape[0], dtype=np.bool)))
  186. # stack gt_bboxes and gt_bboxes_ignore for convenience
  187. gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
  188. num_dets = det_bboxes.shape[0]
  189. num_gts = gt_bboxes.shape[0]
  190. if area_ranges is None:
  191. area_ranges = [(None, None)]
  192. num_scales = len(area_ranges)
  193. # tp and fp are of shape (num_scales, num_gts), each row is tp or fp of
  194. # a certain scale
  195. tp = np.zeros((num_scales, num_dets), dtype=np.float32)
  196. fp = np.zeros((num_scales, num_dets), dtype=np.float32)
  197. # if there is no gt bboxes in this image, then all det bboxes
  198. # within area range are false positives
  199. if gt_bboxes.shape[0] == 0:
  200. if area_ranges == [(None, None)]:
  201. fp[...] = 1
  202. else:
  203. det_areas = (
  204. det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
  205. det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
  206. for i, (min_area, max_area) in enumerate(area_ranges):
  207. fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
  208. return tp, fp
  209. ious = bbox_overlaps(
  210. det_bboxes, gt_bboxes, use_legacy_coordinate=use_legacy_coordinate)
  211. # for each det, the max iou with all gts
  212. ious_max = ious.max(axis=1)
  213. # for each det, which gt overlaps most with it
  214. ious_argmax = ious.argmax(axis=1)
  215. # sort all dets in descending order by scores
  216. sort_inds = np.argsort(-det_bboxes[:, -1])
  217. for k, (min_area, max_area) in enumerate(area_ranges):
  218. gt_covered = np.zeros(num_gts, dtype=bool)
  219. # if no area range is specified, gt_area_ignore is all False
  220. if min_area is None:
  221. gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
  222. else:
  223. gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length) * (
  224. gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length)
  225. gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
  226. for i in sort_inds:
  227. if ious_max[i] >= iou_thr:
  228. matched_gt = ious_argmax[i]
  229. if not (gt_ignore_inds[matched_gt]
  230. or gt_area_ignore[matched_gt]):
  231. if not gt_covered[matched_gt]:
  232. gt_covered[matched_gt] = True
  233. tp[k, i] = 1
  234. else:
  235. fp[k, i] = 1
  236. # otherwise ignore this detected bbox, tp = 0, fp = 0
  237. elif min_area is None:
  238. fp[k, i] = 1
  239. else:
  240. bbox = det_bboxes[i, :4]
  241. area = (bbox[2] - bbox[0] + extra_length) * (
  242. bbox[3] - bbox[1] + extra_length)
  243. if area >= min_area and area < max_area:
  244. fp[k, i] = 1
  245. return tp, fp
  246. def get_cls_results(det_results, annotations, class_id):
  247. """Get det results and gt information of a certain class.
  248. Args:
  249. det_results (list[list]): Same as `eval_map()`.
  250. annotations (list[dict]): Same as `eval_map()`.
  251. class_id (int): ID of a specific class.
  252. Returns:
  253. tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt bboxes
  254. """
  255. cls_dets = [img_res[class_id] for img_res in det_results]
  256. cls_gts = []
  257. cls_gts_ignore = []
  258. for ann in annotations:
  259. gt_inds = ann['labels'] == class_id
  260. cls_gts.append(ann['bboxes'][gt_inds, :])
  261. if ann.get('labels_ignore', None) is not None:
  262. ignore_inds = ann['labels_ignore'] == class_id
  263. cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :])
  264. else:
  265. cls_gts_ignore.append(np.empty((0, 4), dtype=np.float32))
  266. return cls_dets, cls_gts, cls_gts_ignore
  267. def eval_map(det_results,
  268. annotations,
  269. scale_ranges=None,
  270. iou_thr=0.5,
  271. dataset=None,
  272. logger=None,
  273. tpfp_fn=None,
  274. nproc=4,
  275. use_legacy_coordinate=False):
  276. """Evaluate mAP of a dataset.
  277. Args:
  278. det_results (list[list]): [[cls1_det, cls2_det, ...], ...].
  279. The outer list indicates images, and the inner list indicates
  280. per-class detected bboxes.
  281. annotations (list[dict]): Ground truth annotations where each item of
  282. the list indicates an image. Keys of annotations are:
  283. - `bboxes`: numpy array of shape (n, 4)
  284. - `labels`: numpy array of shape (n, )
  285. - `bboxes_ignore` (optional): numpy array of shape (k, 4)
  286. - `labels_ignore` (optional): numpy array of shape (k, )
  287. scale_ranges (list[tuple] | None): Range of scales to be evaluated,
  288. in the format [(min1, max1), (min2, max2), ...]. A range of
  289. (32, 64) means the area range between (32**2, 64**2).
  290. Default: None.
  291. iou_thr (float): IoU threshold to be considered as matched.
  292. Default: 0.5.
  293. dataset (list[str] | str | None): Dataset name or dataset classes,
  294. there are minor differences in metrics for different datasets, e.g.
  295. "voc07", "imagenet_det", etc. Default: None.
  296. logger (logging.Logger | str | None): The way to print the mAP
  297. summary. See `mmcv.utils.print_log()` for details. Default: None.
  298. tpfp_fn (callable | None): The function used to determine true/
  299. false positives. If None, :func:`tpfp_default` is used as default
  300. unless dataset is 'det' or 'vid' (:func:`tpfp_imagenet` in this
  301. case). If it is given as a function, then this function is used
  302. to evaluate tp & fp. Default None.
  303. nproc (int): Processes used for computing TP and FP.
  304. Default: 4.
  305. use_legacy_coordinate (bool): Whether to use coordinate system in
  306. mmdet v1.x. which means width, height should be
  307. calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively.
  308. Default: False.
  309. Returns:
  310. tuple: (mAP, [dict, dict, ...])
  311. """
  312. assert len(det_results) == len(annotations)
  313. if not use_legacy_coordinate:
  314. extra_length = 0.
  315. else:
  316. extra_length = 1.
  317. num_imgs = len(det_results)
  318. num_scales = len(scale_ranges) if scale_ranges is not None else 1
  319. num_classes = len(det_results[0]) # positive class num
  320. area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges]
  321. if scale_ranges is not None else None)
  322. pool = Pool(nproc)
  323. eval_results = []
  324. for i in range(num_classes):
  325. # get gt and det bboxes of this class
  326. cls_dets, cls_gts, cls_gts_ignore = get_cls_results(
  327. det_results, annotations, i)
  328. # choose proper function according to datasets to compute tp and fp
  329. if tpfp_fn is None:
  330. if dataset in ['det', 'vid']:
  331. tpfp_fn = tpfp_imagenet
  332. else:
  333. tpfp_fn = tpfp_default
  334. if not callable(tpfp_fn):
  335. raise ValueError(
  336. f'tpfp_fn has to be a function or None, but got {tpfp_fn}')
  337. # compute tp and fp for each image with multiple processes
  338. tpfp = pool.starmap(
  339. tpfp_fn,
  340. zip(cls_dets, cls_gts, cls_gts_ignore,
  341. [iou_thr for _ in range(num_imgs)],
  342. [area_ranges for _ in range(num_imgs)],
  343. [use_legacy_coordinate for _ in range(num_imgs)]))
  344. tp, fp = tuple(zip(*tpfp))
  345. # calculate gt number of each scale
  346. # ignored gts or gts beyond the specific scale are not counted
  347. num_gts = np.zeros(num_scales, dtype=int)
  348. for j, bbox in enumerate(cls_gts):
  349. if area_ranges is None:
  350. num_gts[0] += bbox.shape[0]
  351. else:
  352. gt_areas = (bbox[:, 2] - bbox[:, 0] + extra_length) * (
  353. bbox[:, 3] - bbox[:, 1] + extra_length)
  354. for k, (min_area, max_area) in enumerate(area_ranges):
  355. num_gts[k] += np.sum((gt_areas >= min_area)
  356. & (gt_areas < max_area))
  357. # sort all det bboxes by score, also sort tp and fp
  358. cls_dets = np.vstack(cls_dets)
  359. num_dets = cls_dets.shape[0]
  360. sort_inds = np.argsort(-cls_dets[:, -1])
  361. tp = np.hstack(tp)[:, sort_inds]
  362. fp = np.hstack(fp)[:, sort_inds]
  363. # calculate recall and precision with tp and fp
  364. tp = np.cumsum(tp, axis=1)
  365. fp = np.cumsum(fp, axis=1)
  366. eps = np.finfo(np.float32).eps
  367. recalls = tp / np.maximum(num_gts[:, np.newaxis], eps)
  368. precisions = tp / np.maximum((tp + fp), eps)
  369. # calculate AP
  370. if scale_ranges is None:
  371. recalls = recalls[0, :]
  372. precisions = precisions[0, :]
  373. num_gts = num_gts.item()
  374. mode = 'area' if dataset != 'voc07' else '11points'
  375. ap = average_precision(recalls, precisions, mode)
  376. eval_results.append({
  377. 'num_gts': num_gts,
  378. 'num_dets': num_dets,
  379. 'recall': recalls,
  380. 'precision': precisions,
  381. 'ap': ap
  382. })
  383. pool.close()
  384. if scale_ranges is not None:
  385. # shape (num_classes, num_scales)
  386. all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results])
  387. all_num_gts = np.vstack(
  388. [cls_result['num_gts'] for cls_result in eval_results])
  389. mean_ap = []
  390. for i in range(num_scales):
  391. if np.any(all_num_gts[:, i] > 0):
  392. mean_ap.append(all_ap[all_num_gts[:, i] > 0, i].mean())
  393. else:
  394. mean_ap.append(0.0)
  395. else:
  396. aps = []
  397. for cls_result in eval_results:
  398. if cls_result['num_gts'] > 0:
  399. aps.append(cls_result['ap'])
  400. mean_ap = np.array(aps).mean().item() if aps else 0.0
  401. print_map_summary(
  402. mean_ap, eval_results, dataset, area_ranges, logger=logger)
  403. return mean_ap, eval_results
  404. def print_map_summary(mean_ap,
  405. results,
  406. dataset=None,
  407. scale_ranges=None,
  408. logger=None):
  409. """Print mAP and results of each class.
  410. A table will be printed to show the gts/dets/recall/AP of each class and
  411. the mAP.
  412. Args:
  413. mean_ap (float): Calculated from `eval_map()`.
  414. results (list[dict]): Calculated from `eval_map()`.
  415. dataset (list[str] | str | None): Dataset name or dataset classes.
  416. scale_ranges (list[tuple] | None): Range of scales to be evaluated.
  417. logger (logging.Logger | str | None): The way to print the mAP
  418. summary. See `mmcv.utils.print_log()` for details. Default: None.
  419. """
  420. if logger == 'silent':
  421. return
  422. if isinstance(results[0]['ap'], np.ndarray):
  423. num_scales = len(results[0]['ap'])
  424. else:
  425. num_scales = 1
  426. if scale_ranges is not None:
  427. assert len(scale_ranges) == num_scales
  428. num_classes = len(results)
  429. recalls = np.zeros((num_scales, num_classes), dtype=np.float32)
  430. aps = np.zeros((num_scales, num_classes), dtype=np.float32)
  431. num_gts = np.zeros((num_scales, num_classes), dtype=int)
  432. for i, cls_result in enumerate(results):
  433. if cls_result['recall'].size > 0:
  434. recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1]
  435. aps[:, i] = cls_result['ap']
  436. num_gts[:, i] = cls_result['num_gts']
  437. if dataset is None:
  438. label_names = [str(i) for i in range(num_classes)]
  439. elif mmcv.is_str(dataset):
  440. label_names = get_classes(dataset)
  441. else:
  442. label_names = dataset
  443. if not isinstance(mean_ap, list):
  444. mean_ap = [mean_ap]
  445. header = ['class', 'gts', 'dets', 'recall', 'ap']
  446. for i in range(num_scales):
  447. if scale_ranges is not None:
  448. print_log(f'Scale range {scale_ranges[i]}', logger=logger)
  449. table_data = [header]
  450. for j in range(num_classes):
  451. row_data = [
  452. label_names[j], num_gts[i, j], results[j]['num_dets'],
  453. f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}'
  454. ]
  455. table_data.append(row_data)
  456. table_data.append(['mAP', '', '', '', f'{mean_ap[i]:.3f}'])
  457. table = AsciiTable(table_data)
  458. table.inner_footing_row_border = True
  459. print_log('\n' + table.table, logger=logger)

No Description

Contributors (1)