You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

metric.py 6.8 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid)
reid/evaluation_metrics/ranking.py. Modifications:
1) Only accepts numpy data input, no torch is involved.
2) Here results of each query can be returned.
3) In the single-gallery-shot evaluation case, the number of repeats is
changed from 10 to 100.
"""
  8. from __future__ import absolute_import
  9. from collections import defaultdict
  10. import numpy as np
  11. from sklearn.metrics import average_precision_score
  12. def _unique_sample(ids_dict, num):
  13. mask = np.zeros(num, dtype=np.bool)
  14. for _, indices in ids_dict.items():
  15. i = np.random.choice(indices)
  16. mask[i] = True
  17. return mask
  18. def cmc(
  19. distmat,
  20. query_ids=None,
  21. gallery_ids=None,
  22. query_cams=None,
  23. gallery_cams=None,
  24. topk=100,
  25. separate_camera_set=False,
  26. single_gallery_shot=False,
  27. first_match_break=False,
  28. average=True):
  29. """
  30. Args:
  31. distmat: numpy array with shape [num_query, num_gallery], the
  32. pairwise distance between query and gallery samples
  33. query_ids: numpy array with shape [num_query]
  34. gallery_ids: numpy array with shape [num_gallery]
  35. query_cams: numpy array with shape [num_query]
  36. gallery_cams: numpy array with shape [num_gallery]
  37. average: whether to average the results across queries
  38. Returns:
  39. If `average` is `False`:
  40. ret: numpy array with shape [num_query, topk]
  41. is_valid_query: numpy array with shape [num_query], containing 0's and
  42. 1's, whether each query is valid or not
  43. If `average` is `True`:
  44. numpy array with shape [topk]
  45. """
  46. # Ensure numpy array
  47. assert isinstance(distmat, np.ndarray)
  48. assert isinstance(query_ids, np.ndarray)
  49. assert isinstance(gallery_ids, np.ndarray)
  50. # assert isinstance(query_cams, np.ndarray)
  51. # assert isinstance(gallery_cams, np.ndarray)
  52. first_match_break = True
  53. m, _ = distmat.shape
  54. # Sort and find correct matches
  55. indices = np.argsort(distmat, axis=1)
  56. matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
  57. # Compute CMC for each query
  58. ret = np.zeros([m, topk])
  59. is_valid_query = np.zeros(m)
  60. num_valid_queries = 0
  61. for i in range(m):
  62. valid = (gallery_ids[indices[i]] != query_ids[i]) | (gallery_ids[indices[i]] == query_ids[i])
  63. if separate_camera_set:
  64. # Filter out samples from same camera
  65. valid = (gallery_cams[indices[i]] != query_cams[i])
  66. if not np.any(matches[i, valid]): continue
  67. is_valid_query[i] = 1
  68. if single_gallery_shot:
  69. repeat = 100
  70. gids = gallery_ids[indices[i][valid]]
  71. inds = np.where(valid)[0]
  72. ids_dict = defaultdict(list)
  73. for j, x in zip(inds, gids):
  74. ids_dict[x].append(j)
  75. else:
  76. repeat = 1
  77. for _ in range(repeat):
  78. if single_gallery_shot:
  79. # Randomly choose one instance for each id
  80. sampled = (valid & _unique_sample(ids_dict, len(valid)))
  81. index = np.nonzero(matches[i, sampled])[0]
  82. else:
  83. index = np.nonzero(matches[i, valid])[0]
  84. delta = 1. / (len(index) * repeat)
  85. for j, k in enumerate(index):
  86. if k - j >= topk: break
  87. if first_match_break:
  88. ret[i, k - j] += 1
  89. break
  90. ret[i, k - j] += delta
  91. num_valid_queries += 1
  92. if num_valid_queries == 0:
  93. raise RuntimeError("No valid query")
  94. ret = ret.cumsum(axis=1)
  95. if average:
  96. return np.sum(ret, axis=0) / num_valid_queries, indices
  97. return ret, is_valid_query, indices
  98. def mean_ap(
  99. distmat,
  100. query_ids=None,
  101. gallery_ids=None,
  102. query_cams=None,
  103. gallery_cams=None,
  104. average=True):
  105. """
  106. Args:
  107. distmat: numpy array with shape [num_query, num_gallery], the
  108. pairwise distance between query and gallery samples
  109. query_ids: numpy array with shape [num_query]
  110. gallery_ids: numpy array with shape [num_gallery]
  111. query_cams: numpy array with shape [num_query]
  112. gallery_cams: numpy array with shape [num_gallery]
  113. average: whether to average the results across queries
  114. Returns:
  115. If `average` is `False`:
  116. ret: numpy array with shape [num_query]
  117. is_valid_query: numpy array with shape [num_query], containing 0's and
  118. 1's, whether each query is valid or not
  119. If `average` is `True`:
  120. a scalar
  121. """
  122. # -------------------------------------------------------------------------
  123. # The behavior of method `sklearn.average_precision` has changed since version
  124. # 0.19.
  125. # Version 0.18.1 has same results as Matlab evaluation code by Zhun Zhong
  126. # (https://github.com/zhunzhong07/person-re-ranking/
  127. # blob/master/evaluation/utils/evaluation.m) and by Liang Zheng
  128. # (http://www.liangzheng.org/Project/project_reid.html).
  129. # My current awkward solution is sticking to this older version.
  130. # if cur_version != required_version:
  131. # print('User Warning: Version {} is required for package scikit-learn, '
  132. # 'your current version is {}. '
  133. # 'As a result, the mAP score may not be totally correct. '
  134. # 'You can try `pip uninstall scikit-learn` '
  135. # 'and then `pip install scikit-learn=={}`'.format(
  136. # required_version, cur_version, required_version))
  137. # -------------------------------------------------------------------------
  138. # Ensure numpy array
  139. assert isinstance(distmat, np.ndarray)
  140. assert isinstance(query_ids, np.ndarray)
  141. assert isinstance(gallery_ids, np.ndarray)
  142. # assert isinstance(query_cams, np.ndarray)
  143. # assert isinstance(gallery_cams, np.ndarray)
  144. m, _ = distmat.shape
  145. # Sort and find correct matches
  146. indices = np.argsort(distmat, axis=1)
  147. # print("indices:", indices)
  148. matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
  149. # Compute AP for each query
  150. aps = np.zeros(m)
  151. is_valid_query = np.zeros(m)
  152. for i in range(m):
  153. # Filter out the same id and same camera
  154. valid = (gallery_ids[indices[i]] != query_ids[i]) | (gallery_ids[indices[i]] == query_ids[i])
  155. y_true = matches[i, valid]
  156. y_score = -distmat[i][indices[i]][valid]
  157. if not np.any(y_true): continue
  158. is_valid_query[i] = 1
  159. aps[i] = average_precision_score(y_true, y_score)
  160. if average:
  161. return float(np.sum(aps)) / np.sum(is_valid_query)
  162. return aps, is_valid_query