You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

recall.py 6.8 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from collections.abc import Sequence
  3. import numpy as np
  4. from mmcv.utils import print_log
  5. from terminaltables import AsciiTable
  6. from .bbox_overlaps import bbox_overlaps
  7. def _recalls(all_ious, proposal_nums, thrs):
  8. img_num = all_ious.shape[0]
  9. total_gt_num = sum([ious.shape[0] for ious in all_ious])
  10. _ious = np.zeros((proposal_nums.size, total_gt_num), dtype=np.float32)
  11. for k, proposal_num in enumerate(proposal_nums):
  12. tmp_ious = np.zeros(0)
  13. for i in range(img_num):
  14. ious = all_ious[i][:, :proposal_num].copy()
  15. gt_ious = np.zeros((ious.shape[0]))
  16. if ious.size == 0:
  17. tmp_ious = np.hstack((tmp_ious, gt_ious))
  18. continue
  19. for j in range(ious.shape[0]):
  20. gt_max_overlaps = ious.argmax(axis=1)
  21. max_ious = ious[np.arange(0, ious.shape[0]), gt_max_overlaps]
  22. gt_idx = max_ious.argmax()
  23. gt_ious[j] = max_ious[gt_idx]
  24. box_idx = gt_max_overlaps[gt_idx]
  25. ious[gt_idx, :] = -1
  26. ious[:, box_idx] = -1
  27. tmp_ious = np.hstack((tmp_ious, gt_ious))
  28. _ious[k, :] = tmp_ious
  29. _ious = np.fliplr(np.sort(_ious, axis=1))
  30. recalls = np.zeros((proposal_nums.size, thrs.size))
  31. for i, thr in enumerate(thrs):
  32. recalls[:, i] = (_ious >= thr).sum(axis=1) / float(total_gt_num)
  33. return recalls
  34. def set_recall_param(proposal_nums, iou_thrs):
  35. """Check proposal_nums and iou_thrs and set correct format."""
  36. if isinstance(proposal_nums, Sequence):
  37. _proposal_nums = np.array(proposal_nums)
  38. elif isinstance(proposal_nums, int):
  39. _proposal_nums = np.array([proposal_nums])
  40. else:
  41. _proposal_nums = proposal_nums
  42. if iou_thrs is None:
  43. _iou_thrs = np.array([0.5])
  44. elif isinstance(iou_thrs, Sequence):
  45. _iou_thrs = np.array(iou_thrs)
  46. elif isinstance(iou_thrs, float):
  47. _iou_thrs = np.array([iou_thrs])
  48. else:
  49. _iou_thrs = iou_thrs
  50. return _proposal_nums, _iou_thrs
  51. def eval_recalls(gts,
  52. proposals,
  53. proposal_nums=None,
  54. iou_thrs=0.5,
  55. logger=None,
  56. use_legacy_coordinate=False):
  57. """Calculate recalls.
  58. Args:
  59. gts (list[ndarray]): a list of arrays of shape (n, 4)
  60. proposals (list[ndarray]): a list of arrays of shape (k, 4) or (k, 5)
  61. proposal_nums (int | Sequence[int]): Top N proposals to be evaluated.
  62. iou_thrs (float | Sequence[float]): IoU thresholds. Default: 0.5.
  63. logger (logging.Logger | str | None): The way to print the recall
  64. summary. See `mmcv.utils.print_log()` for details. Default: None.
  65. use_legacy_coordinate (bool): Whether use coordinate system
  66. in mmdet v1.x. "1" was added to both height and width
  67. which means w, h should be
  68. computed as 'x2 - x1 + 1` and 'y2 - y1 + 1'. Default: False.
  69. Returns:
  70. ndarray: recalls of different ious and proposal nums
  71. """
  72. img_num = len(gts)
  73. assert img_num == len(proposals)
  74. proposal_nums, iou_thrs = set_recall_param(proposal_nums, iou_thrs)
  75. all_ious = []
  76. for i in range(img_num):
  77. if proposals[i].ndim == 2 and proposals[i].shape[1] == 5:
  78. scores = proposals[i][:, 4]
  79. sort_idx = np.argsort(scores)[::-1]
  80. img_proposal = proposals[i][sort_idx, :]
  81. else:
  82. img_proposal = proposals[i]
  83. prop_num = min(img_proposal.shape[0], proposal_nums[-1])
  84. if gts[i] is None or gts[i].shape[0] == 0:
  85. ious = np.zeros((0, img_proposal.shape[0]), dtype=np.float32)
  86. else:
  87. ious = bbox_overlaps(
  88. gts[i],
  89. img_proposal[:prop_num, :4],
  90. use_legacy_coordinate=use_legacy_coordinate)
  91. all_ious.append(ious)
  92. all_ious = np.array(all_ious)
  93. recalls = _recalls(all_ious, proposal_nums, iou_thrs)
  94. print_recall_summary(recalls, proposal_nums, iou_thrs, logger=logger)
  95. return recalls
  96. def print_recall_summary(recalls,
  97. proposal_nums,
  98. iou_thrs,
  99. row_idxs=None,
  100. col_idxs=None,
  101. logger=None):
  102. """Print recalls in a table.
  103. Args:
  104. recalls (ndarray): calculated from `bbox_recalls`
  105. proposal_nums (ndarray or list): top N proposals
  106. iou_thrs (ndarray or list): iou thresholds
  107. row_idxs (ndarray): which rows(proposal nums) to print
  108. col_idxs (ndarray): which cols(iou thresholds) to print
  109. logger (logging.Logger | str | None): The way to print the recall
  110. summary. See `mmcv.utils.print_log()` for details. Default: None.
  111. """
  112. proposal_nums = np.array(proposal_nums, dtype=np.int32)
  113. iou_thrs = np.array(iou_thrs)
  114. if row_idxs is None:
  115. row_idxs = np.arange(proposal_nums.size)
  116. if col_idxs is None:
  117. col_idxs = np.arange(iou_thrs.size)
  118. row_header = [''] + iou_thrs[col_idxs].tolist()
  119. table_data = [row_header]
  120. for i, num in enumerate(proposal_nums[row_idxs]):
  121. row = [f'{val:.3f}' for val in recalls[row_idxs[i], col_idxs].tolist()]
  122. row.insert(0, num)
  123. table_data.append(row)
  124. table = AsciiTable(table_data)
  125. print_log('\n' + table.table, logger=logger)
  126. def plot_num_recall(recalls, proposal_nums):
  127. """Plot Proposal_num-Recalls curve.
  128. Args:
  129. recalls(ndarray or list): shape (k,)
  130. proposal_nums(ndarray or list): same shape as `recalls`
  131. """
  132. if isinstance(proposal_nums, np.ndarray):
  133. _proposal_nums = proposal_nums.tolist()
  134. else:
  135. _proposal_nums = proposal_nums
  136. if isinstance(recalls, np.ndarray):
  137. _recalls = recalls.tolist()
  138. else:
  139. _recalls = recalls
  140. import matplotlib.pyplot as plt
  141. f = plt.figure()
  142. plt.plot([0] + _proposal_nums, [0] + _recalls)
  143. plt.xlabel('Proposal num')
  144. plt.ylabel('Recall')
  145. plt.axis([0, proposal_nums.max(), 0, 1])
  146. f.show()
  147. def plot_iou_recall(recalls, iou_thrs):
  148. """Plot IoU-Recalls curve.
  149. Args:
  150. recalls(ndarray or list): shape (k,)
  151. iou_thrs(ndarray or list): same shape as `recalls`
  152. """
  153. if isinstance(iou_thrs, np.ndarray):
  154. _iou_thrs = iou_thrs.tolist()
  155. else:
  156. _iou_thrs = iou_thrs
  157. if isinstance(recalls, np.ndarray):
  158. _recalls = recalls.tolist()
  159. else:
  160. _recalls = recalls
  161. import matplotlib.pyplot as plt
  162. f = plt.figure()
  163. plt.plot(_iou_thrs + [1.0], _recalls + [0.])
  164. plt.xlabel('IoU')
  165. plt.ylabel('Recall')
  166. plt.axis([iou_thrs.min(), 1, 0, 1])
  167. f.show()

No Description

Contributors (1)