You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

metric.py 7.2 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid)
  2. reid/evaluation_metrics/ranking.py. Modifications:
  3. 1) Only accepts numpy data input; no torch is involved.
  4. 2) Results of each query can be returned.
  5. 3) In the single-gallery-shot evaluation case, the number of repeats is changed
  6. from 10 to 100.
  7. """
  8. from __future__ import absolute_import
  9. from collections import defaultdict
  10. import numpy as np
  11. from sklearn.metrics import average_precision_score
  12. def _unique_sample(ids_dict, num):
  13. mask = np.zeros(num, dtype=np.bool)
  14. for _, indices in ids_dict.items():
  15. i = np.random.choice(indices)
  16. mask[i] = True
  17. return mask
  18. def cmc(
  19. distmat,
  20. query_ids=None,
  21. gallery_ids=None,
  22. query_cams=None,
  23. gallery_cams=None,
  24. topk=100,
  25. separate_camera_set=False,
  26. single_gallery_shot=False,
  27. first_match_break=False,
  28. average=True):
  29. """
  30. Args:
  31. distmat: numpy array with shape [num_query, num_gallery], the
  32. pairwise distance between query and gallery samples
  33. query_ids: numpy array with shape [num_query]
  34. gallery_ids: numpy array with shape [num_gallery]
  35. query_cams: numpy array with shape [num_query]
  36. gallery_cams: numpy array with shape [num_gallery]
  37. average: whether to average the results across queries
  38. Returns:
  39. If `average` is `False`:
  40. ret: numpy array with shape [num_query, topk]
  41. is_valid_query: numpy array with shape [num_query], containing 0's and
  42. 1's, whether each query is valid or not
  43. If `average` is `True`:
  44. numpy array with shape [topk]
  45. """
  46. # Ensure numpy array
  47. assert isinstance(distmat, np.ndarray)
  48. assert isinstance(query_ids, np.ndarray)
  49. assert isinstance(gallery_ids, np.ndarray)
  50. # assert isinstance(query_cams, np.ndarray)
  51. # assert isinstance(gallery_cams, np.ndarray)
  52. # separate_camera_set=False
  53. first_match_break = True
  54. m, _ = distmat.shape
  55. # Sort and find correct matches
  56. indices = np.argsort(distmat, axis=1)
  57. #print(indices)
  58. matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
  59. # Compute CMC for each query
  60. ret = np.zeros([m, topk])
  61. is_valid_query = np.zeros(m)
  62. num_valid_queries = 0
  63. for i in range(m):
  64. valid = (gallery_ids[indices[i]] != query_ids[i]) | (gallery_ids[indices[i]] == query_ids[i])
  65. if separate_camera_set:
  66. # Filter out samples from same camera
  67. valid = (gallery_cams[indices[i]] != query_cams[i])
  68. if not np.any(matches[i, valid]): continue
  69. is_valid_query[i] = 1
  70. if single_gallery_shot:
  71. repeat = 100
  72. gids = gallery_ids[indices[i][valid]]
  73. inds = np.where(valid)[0]
  74. ids_dict = defaultdict(list)
  75. for j, x in zip(inds, gids):
  76. ids_dict[x].append(j)
  77. else:
  78. repeat = 1
  79. for _ in range(repeat):
  80. if single_gallery_shot:
  81. # Randomly choose one instance for each id
  82. sampled = (valid & _unique_sample(ids_dict, len(valid)))
  83. index = np.nonzero(matches[i, sampled])[0]
  84. else:
  85. index = np.nonzero(matches[i, valid])[0]
  86. delta = 1. / (len(index) * repeat)
  87. for j, k in enumerate(index):
  88. if k - j >= topk: break
  89. if first_match_break:
  90. ret[i, k - j] += 1
  91. break
  92. ret[i, k - j] += delta
  93. num_valid_queries += 1
  94. if num_valid_queries == 0:
  95. raise RuntimeError("No valid query")
  96. ret = ret.cumsum(axis=1)
  97. if average:
  98. return np.sum(ret, axis=0) / num_valid_queries, indices
  99. return ret, is_valid_query, indices
  100. def mean_ap(
  101. distmat,
  102. query_ids=None,
  103. gallery_ids=None,
  104. query_cams=None,
  105. gallery_cams=None,
  106. average=True):
  107. """
  108. Args:
  109. distmat: numpy array with shape [num_query, num_gallery], the
  110. pairwise distance between query and gallery samples
  111. query_ids: numpy array with shape [num_query]
  112. gallery_ids: numpy array with shape [num_gallery]
  113. query_cams: numpy array with shape [num_query]
  114. gallery_cams: numpy array with shape [num_gallery]
  115. average: whether to average the results across queries
  116. Returns:
  117. If `average` is `False`:
  118. ret: numpy array with shape [num_query]
  119. is_valid_query: numpy array with shape [num_query], containing 0's and
  120. 1's, whether each query is valid or not
  121. If `average` is `True`:
  122. a scalar
  123. """
  124. # -------------------------------------------------------------------------
  125. # The behavior of method `sklearn.average_precision` has changed since version
  126. # 0.19.
  127. # Version 0.18.1 has same results as Matlab evaluation code by Zhun Zhong
  128. # (https://github.com/zhunzhong07/person-re-ranking/
  129. # blob/master/evaluation/utils/evaluation.m) and by Liang Zheng
  130. # (http://www.liangzheng.org/Project/project_reid.html).
  131. # My current awkward solution is sticking to this older version.
  132. # if cur_version != required_version:
  133. # print('User Warning: Version {} is required for package scikit-learn, '
  134. # 'your current version is {}. '
  135. # 'As a result, the mAP score may not be totally correct. '
  136. # 'You can try `pip uninstall scikit-learn` '
  137. # 'and then `pip install scikit-learn=={}`'.format(
  138. # required_version, cur_version, required_version))
  139. # -------------------------------------------------------------------------
  140. # Ensure numpy array
  141. assert isinstance(distmat, np.ndarray)
  142. assert isinstance(query_ids, np.ndarray)
  143. assert isinstance(gallery_ids, np.ndarray)
  144. # assert isinstance(query_cams, np.ndarray)
  145. # assert isinstance(gallery_cams, np.ndarray)
  146. m, _ = distmat.shape
  147. # Sort and find correct matches
  148. indices = np.argsort(distmat, axis=1)
  149. # print("indices:", indices)
  150. matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
  151. # Compute AP for each query
  152. aps = np.zeros(m)
  153. is_valid_query = np.zeros(m)
  154. for i in range(m):
  155. # Filter out the same id and same camera
  156. # valid = ((gallery_ids[indices[i]] != query_ids[i]) |
  157. # (gallery_cams[indices[i]] != query_cams[i]))
  158. valid = (gallery_ids[indices[i]] != query_ids[i]) | (gallery_ids[indices[i]] == query_ids[i])
  159. # valid = indices[i] != i
  160. # valid = (gallery_cams[indices[i]] != query_cams[i])
  161. y_true = matches[i, valid]
  162. y_score = -distmat[i][indices[i]][valid]
  163. # y_true=y_true[0:100]
  164. # y_score=y_score[0:100]
  165. if not np.any(y_true): continue
  166. is_valid_query[i] = 1
  167. aps[i] = average_precision_score(y_true, y_score)
  168. # if not aps:
  169. # raise RuntimeError("No valid query")
  170. if average:
  171. return float(np.sum(aps)) / np.sum(is_valid_query)
  172. return aps, is_valid_query