You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

util.py 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """metrics utils"""
  16. import numpy as np
  17. from config import ConfigYOLOV3ResNet18
  def calc_iou(bbox_pred, bbox_ground):
      """Return the intersection-over-union of two axis-aligned boxes.

      Both boxes are indexed as ``[x_min, y_min, x_max, y_max, ...]``. Only
      items 0-3 are read, so a trailing class id (as in ground-truth
      annotation rows) is ignored.

      Returns the integer 0 when the boxes do not overlap or only touch
      along an edge. Otherwise returns intersection area / union area.
      """
      p_left, p_top = bbox_pred[0], bbox_pred[1]
      p_w = bbox_pred[2] - bbox_pred[0]
      p_h = bbox_pred[3] - bbox_pred[1]
      g_left, g_top = bbox_ground[0], bbox_ground[1]
      g_w = bbox_ground[2] - bbox_ground[0]
      g_h = bbox_ground[3] - bbox_ground[1]

      # Overlap on an axis = both extents summed minus the span they cover together.
      span_w = max(p_left + p_w, g_left + g_w) - min(p_left, g_left)
      span_h = max(p_top + p_h, g_top + g_h) - min(p_top, g_top)
      inter_w = p_w + g_w - span_w
      inter_h = p_h + g_h - span_h

      if inter_w <= 0 or inter_h <= 0:
          return 0

      inter = inter_w * inter_h
      union = p_w * p_h + g_w * g_h - inter
      # "* 1." forces float division while keeping numpy scalar dtypes intact.
      return inter * 1. / union
  def apply_nms(all_boxes, all_scores, thres, max_boxes):
      """Greedy non-maximum suppression over one set of boxes.

      Boxes are visited from highest to lowest score. Each visited box is
      kept, and every not-yet-visited box whose overlap ratio with it is
      greater than ``thres`` is discarded.

      Args:
          all_boxes: 2-D array; columns 0-3 are read as x1, y1, x2, y2.
          all_scores: 1-D numpy array with one score per row of ``all_boxes``.
          thres: overlap ratio above which a lower-scored box is suppressed.
          max_boxes: stop once this many boxes are kept. At least one box is
              always kept when the input is non-empty.

      Returns:
          List of row indices into ``all_boxes`` (numpy integers), in the
          order they were kept (descending score).
      """
      left = all_boxes[:, 0]
      top = all_boxes[:, 1]
      right = all_boxes[:, 2]
      bottom = all_boxes[:, 3]
      # Widths and heights use "+1", so a box with x1 == x2 counts as width 1.
      areas = (right - left + 1) * (bottom - top + 1)

      # Candidate indices, best score first.
      remaining = all_scores.argsort()[::-1]
      picked = []
      while remaining.size:
          best = remaining[0]
          rest = remaining[1:]
          picked.append(best)
          if len(picked) >= max_boxes:
              break
          inter_w = np.maximum(0.0, np.minimum(right[best], right[rest])
                               - np.maximum(left[best], left[rest]) + 1)
          inter_h = np.maximum(0.0, np.minimum(bottom[best], bottom[rest])
                               - np.maximum(top[best], top[rest]) + 1)
          inter = inter_w * inter_h
          overlap = inter / (areas[best] + areas[rest] - inter)
          # Survivors keep their relative (score) order.
          remaining = rest[overlap <= thres]
      return picked
  def _select_boxes(boxes, box_scores, config):
      """Apply score thresholding and per-class NMS to one sample.

      Args:
          boxes: predicted boxes, reshaped to ``(-1, 4)``.
          box_scores: 2-D array; column ``c`` holds the scores for class ``c``.
          config: object providing ``num_classes``, ``obj_threshold``,
              ``nms_threshold`` and ``nms_max_num``.

      Returns:
          ``(kept_boxes, kept_classes)``. ``kept_boxes`` holds the surviving
          boxes concatenated in class order; within a class they keep the
          order returned by ``apply_nms``. ``kept_classes`` is an int32
          array giving the class id of each kept box.
      """
      flat_boxes = np.reshape(boxes, [-1, 4])
      mask = box_scores >= config.obj_threshold
      kept_boxes = []
      kept_classes = []
      for c in range(config.num_classes):
          class_mask = np.reshape(mask[:, c], [-1])
          class_boxes = flat_boxes[class_mask]
          class_box_scores = np.reshape(box_scores[:, c], [-1])[class_mask]
          nms_index = apply_nms(class_boxes, class_box_scores,
                                config.nms_threshold, config.nms_max_num)
          kept_boxes.append(class_boxes[nms_index])
          kept_classes.append(np.full(len(nms_index), c, dtype='int32'))
      return (np.concatenate(kept_boxes, axis=0),
              np.concatenate(kept_classes, axis=0))


  def _count_sample(boxes, classes, gt_anno, num_classes, iou_threshold=0.5):
      """Count correct, ground-truth and predicted boxes per class for one sample.

      A prediction is correct when two conditions hold: its IoU with a
      ground-truth box of the same class is at least ``iou_threshold``, and
      that ground-truth box has not already been matched by an earlier
      prediction. Because each ground-truth box is matched at most once,
      ``correct[c] <= ground[c]`` for every class, so recall cannot exceed 1.

      Returns:
          ``(correct, ground, pred)``: three lists of ints of length ``num_classes``.
      """
      correct = [0] * num_classes
      ground = [0] * num_classes
      pred = [0] * num_classes
      for anno in gt_anno:
          ground[anno[4]] += 1

      matched = set()  # positions in gt_anno already claimed by a prediction
      for box, cls in zip(boxes, classes):
          # Coordinates are reordered [1, 0, 3, 2] before comparison with the
          # annotation. Presumably predictions are (y, x)-ordered while
          # annotations are (x, y) -- TODO confirm against the data pipeline.
          bbox_pred = [box[1], box[0], box[3], box[2]]
          pred[cls] += 1
          for gt_index, anno in enumerate(gt_anno):
              if gt_index in matched or anno[4] != cls:
                  continue
              if calc_iou(bbox_pred, anno) >= iou_threshold:
                  matched.add(gt_index)
                  correct[cls] += 1
                  break
      return correct, ground, pred


  def metrics(pred_data):
      """Calculate per-class precision and recall of predicted bboxes.

      Args:
          pred_data: iterable of samples. Each sample is a mapping with keys:
              ``"annotation"``: ground-truth rows; items 0-3 are the box given
              to ``calc_iou`` and item 4 is the integer class id.
              ``"box_scores"``: 2-D array with one score column per class.
              ``"boxes"``: predicted boxes, reshaped to ``(-1, 4)``.

      Returns:
          ``(precision, recall)``: float arrays of length
          ``config.num_classes``. Every total is seeded once with 1e-6 so the
          divisions are always defined. As a result, a class with no
          predictions reports precision 1.0, and a class with no ground
          truth reports recall 1.0.
      """
      config = ConfigYOLOV3ResNet18()
      num_classes = config.num_classes
      # The epsilon is applied exactly once. Per-sample counts start at 0,
      # so the bias does not grow with the number of samples.
      count_corrects = np.full(num_classes, 1e-6)
      count_grounds = np.full(num_classes, 1e-6)
      count_preds = np.full(num_classes, 1e-6)
      for sample in pred_data:
          boxes, classes = _select_boxes(sample['boxes'], sample['box_scores'],
                                         config)
          correct, ground, pred = _count_sample(boxes, classes,
                                                sample["annotation"],
                                                num_classes)
          count_corrects += correct
          count_grounds += ground
          count_preds += pred
      precision = count_corrects / count_preds
      recall = count_corrects / count_grounds
      return precision, recall