You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

util.py 7.2 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """metrics utils"""
  16. import numpy as np
  17. from config import ConfigSSD
  18. from dataset import ssd_bboxes_decode
  19. def calc_iou(bbox_pred, bbox_ground):
  20. """Calculate iou of predicted bbox and ground truth."""
  21. bbox_pred = np.expand_dims(bbox_pred, axis=0)
  22. pred_w = bbox_pred[:, 2] - bbox_pred[:, 0]
  23. pred_h = bbox_pred[:, 3] - bbox_pred[:, 1]
  24. pred_area = pred_w * pred_h
  25. gt_w = bbox_ground[:, 2] - bbox_ground[:, 0]
  26. gt_h = bbox_ground[:, 3] - bbox_ground[:, 1]
  27. gt_area = gt_w * gt_h
  28. iw = np.minimum(bbox_pred[:, 2], bbox_ground[:, 2]) - np.maximum(bbox_pred[:, 0], bbox_ground[:, 0])
  29. ih = np.minimum(bbox_pred[:, 3], bbox_ground[:, 3]) - np.maximum(bbox_pred[:, 1], bbox_ground[:, 1])
  30. iw = np.maximum(iw, 0)
  31. ih = np.maximum(ih, 0)
  32. intersection_area = iw * ih
  33. union_area = pred_area + gt_area - intersection_area
  34. union_area = np.maximum(union_area, np.finfo(float).eps)
  35. iou = intersection_area * 1. / union_area
  36. return iou
  37. def apply_nms(all_boxes, all_scores, thres, max_boxes):
  38. """Apply NMS to bboxes."""
  39. x1 = all_boxes[:, 0]
  40. y1 = all_boxes[:, 1]
  41. x2 = all_boxes[:, 2]
  42. y2 = all_boxes[:, 3]
  43. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  44. order = all_scores.argsort()[::-1]
  45. keep = []
  46. while order.size > 0:
  47. i = order[0]
  48. keep.append(i)
  49. if len(keep) >= max_boxes:
  50. break
  51. xx1 = np.maximum(x1[i], x1[order[1:]])
  52. yy1 = np.maximum(y1[i], y1[order[1:]])
  53. xx2 = np.minimum(x2[i], x2[order[1:]])
  54. yy2 = np.minimum(y2[i], y2[order[1:]])
  55. w = np.maximum(0.0, xx2 - xx1 + 1)
  56. h = np.maximum(0.0, yy2 - yy1 + 1)
  57. inter = w * h
  58. ovr = inter / (areas[i] + areas[order[1:]] - inter)
  59. inds = np.where(ovr <= thres)[0]
  60. order = order[inds + 1]
  61. return keep
  62. def calc_ap(recall, precision):
  63. """Calculate AP."""
  64. correct_recall = np.concatenate(([0.], recall, [1.]))
  65. correct_precision = np.concatenate(([0.], precision, [0.]))
  66. for i in range(correct_recall.size - 1, 0, -1):
  67. correct_precision[i - 1] = np.maximum(correct_precision[i - 1], correct_precision[i])
  68. i = np.where(correct_recall[1:] != correct_recall[:-1])[0]
  69. ap = np.sum((correct_recall[i + 1] - correct_recall[i]) * correct_precision[i + 1])
  70. return ap
  71. def metrics(pred_data):
  72. """Calculate mAP of predicted bboxes."""
  73. config = ConfigSSD()
  74. num_classes = config.NUM_CLASSES
  75. all_detections = [None for i in range(num_classes)]
  76. all_pred_scores = [None for i in range(num_classes)]
  77. all_annotations = [None for i in range(num_classes)]
  78. average_precisions = {}
  79. num = [0 for i in range(num_classes)]
  80. accurate_num = [0 for i in range(num_classes)]
  81. for sample in pred_data:
  82. pred_boxes = sample['boxes']
  83. boxes_scores = sample['box_scores']
  84. annotation = sample['annotation']
  85. annotation = np.squeeze(annotation, axis=0)
  86. pred_labels = np.argmax(boxes_scores, axis=-1)
  87. index = np.nonzero(pred_labels)
  88. pred_boxes = ssd_bboxes_decode(pred_boxes, index)
  89. pred_boxes = pred_boxes.clip(0, 1)
  90. boxes_scores = np.max(boxes_scores, axis=-1)
  91. boxes_scores = boxes_scores[index]
  92. pred_labels = pred_labels[index]
  93. top_k = 50
  94. for c in range(1, num_classes):
  95. if len(pred_labels) >= 1:
  96. class_box_scores = boxes_scores[pred_labels == c]
  97. class_boxes = pred_boxes[pred_labels == c]
  98. nms_index = apply_nms(class_boxes, class_box_scores, config.MATCH_THRESHOLD, top_k)
  99. class_boxes = class_boxes[nms_index]
  100. class_box_scores = class_box_scores[nms_index]
  101. cmask = class_box_scores > 0.5
  102. class_boxes = class_boxes[cmask]
  103. class_box_scores = class_box_scores[cmask]
  104. all_detections[c] = class_boxes
  105. all_pred_scores[c] = class_box_scores
  106. for c in range(1, num_classes):
  107. if len(annotation) >= 1:
  108. all_annotations[c] = annotation[annotation[:, 4] == c, :4]
  109. for c in range(1, num_classes):
  110. false_positives = np.zeros((0,))
  111. true_positives = np.zeros((0,))
  112. scores = np.zeros((0,))
  113. num_annotations = 0.0
  114. annotations = all_annotations[c]
  115. num_annotations += annotations.shape[0]
  116. detections = all_detections[c]
  117. pred_scores = all_pred_scores[c]
  118. for index, detection in enumerate(detections):
  119. scores = np.append(scores, pred_scores[index])
  120. if len(annotations) >= 1:
  121. IoUs = calc_iou(detection, annotations)
  122. assigned_anno = np.argmax(IoUs)
  123. max_overlap = IoUs[assigned_anno]
  124. if max_overlap >= 0.5:
  125. false_positives = np.append(false_positives, 0)
  126. true_positives = np.append(true_positives, 1)
  127. else:
  128. false_positives = np.append(false_positives, 1)
  129. true_positives = np.append(true_positives, 0)
  130. else:
  131. false_positives = np.append(false_positives, 1)
  132. true_positives = np.append(true_positives, 0)
  133. if num_annotations == 0:
  134. if c not in average_precisions.keys():
  135. average_precisions[c] = 0
  136. continue
  137. accurate_num[c] = 1
  138. indices = np.argsort(-scores)
  139. false_positives = false_positives[indices]
  140. true_positives = true_positives[indices]
  141. false_positives = np.cumsum(false_positives)
  142. true_positives = np.cumsum(true_positives)
  143. recall = true_positives * 1. / num_annotations
  144. precision = true_positives * 1. / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)
  145. average_precision = calc_ap(recall, precision)
  146. if c not in average_precisions.keys():
  147. average_precisions[c] = average_precision
  148. else:
  149. average_precisions[c] += average_precision
  150. num[c] += 1
  151. count = 0
  152. for key in average_precisions:
  153. if num[key] != 0:
  154. count += (average_precisions[key] / num[key])
  155. mAP = count * 1. / accurate_num.count(1)
  156. return mAP