You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

d3m_eval.py 8.3 kB

first commit Former-commit-id: 08bc23ba02cffbce3cf63962390a65459a132e48 [formerly 0795edd4834b9b7dc66db8d10d4cbaf42bbf82cb] [formerly b5010b42541add7e2ea2578bf2da537efc457757 [formerly a7ca09c2c34c4fc8b3d8e01fcfa08eeeb2cae99d]] [formerly 615058473a2177ca5b89e9edbb797f4c2a59c7e5 [formerly 743d8dfc6843c4c205051a8ab309fbb2116c895e] [formerly bb0ea98b1e14154ef464e2f7a16738705894e54b [formerly 960a69da74b81ef8093820e003f2d6c59a34974c]]] [formerly 2fa3be52c1b44665bc81a7cc7d4cea4bbf0d91d5 [formerly 2054589f0898627e0a17132fd9d4cc78efc91867] [formerly 3b53730e8a895e803dfdd6ca72bc05e17a4164c1 [formerly 8a2fa8ab7baf6686d21af1f322df46fd58c60e69]] [formerly 87d1e3a07a19d03c7d7c94d93ab4fa9f58dada7c [formerly f331916385a5afac1234854ee8d7f160f34b668f] [formerly 69fb3c78a483343f5071da4f7e2891b83a49dd18 [formerly 386086f05aa9487f65bce2ee54438acbdce57650]]]] Former-commit-id: a00aed8c934a6460c4d9ac902b9a74a3d6864697 [formerly 26fdeca29c2f07916d837883983ca2982056c78e] [formerly 0e3170d41a2f99ecf5c918183d361d4399d793bf [formerly 3c12ad4c88ac5192e0f5606ac0d88dd5bf8602dc]] [formerly d5894f84f2fd2e77a6913efdc5ae388cf1be0495 [formerly ad3e7bc670ff92c992730d29c9d3aa1598d844e8] [formerly 69fb3c78a483343f5071da4f7e2891b83a49dd18]] Former-commit-id: 3c19c9fae64f6106415fbc948a4dc613b9ee12f8 [formerly 467ddc0549c74bb007e8f01773bb6dc9103b417d] [formerly 5fa518345d958e2760e443b366883295de6d991c [formerly 3530e130b9fdb7280f638dbc2e785d2165ba82aa]] Former-commit-id: 9f5d473d42a435ec0d60149939d09be1acc25d92 [formerly be0b25c4ec2cde052a041baf0e11f774a158105d] Former-commit-id: 9eca71cb73ba9edccd70ac06a3b636b8d4093b04
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. import numpy as np
  2. def group_gt_boxes_by_image_name(gt_boxes):
  3. gt_dict = {}
  4. for box in gt_boxes:
  5. #x = box.split()
  6. #image_name = x[0]
  7. #bbox = [float(z) for z in x[1:]]
  8. image_name = box[0]
  9. bbox = box[1:]
  10. #print(image_name, bbox)
  11. if image_name not in gt_dict.keys():
  12. gt_dict[image_name] = []
  13. gt_dict[image_name].append({'bbox': bbox})
  14. return gt_dict
  15. def voc_ap(rec, prec, use_07_metric=False):
  16. """ ap = voc_ap(rec, prec, [use_07_metric])
  17. Compute VOC AP given precision and recall.
  18. If use_07_metric is true, uses the
  19. VOC 07 11 point method (default:False).
  20. """
  21. if use_07_metric:
  22. # 11 point metric
  23. ap = 0.
  24. for t in np.arange(0., 1.1, 0.1):
  25. if np.sum(rec >= t) == 0:
  26. p = 0
  27. else:
  28. p = np.max(prec[rec >= t])
  29. ap = ap + p / 11.
  30. else:
  31. # correct AP calculation
  32. # first append sentinel values at the end
  33. mrec = np.concatenate(([0.], rec, [1.]))
  34. mpre = np.concatenate(([0.], prec, [0.]))
  35. # compute the precision envelope
  36. for i in range(mpre.size - 1, 0, -1):
  37. mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
  38. # to calculate area under PR curve, look for points
  39. # where X axis (recall) changes value
  40. i = np.where(mrec[1:] != mrec[:-1])[0]
  41. # and sum (\Delta recall) * prec
  42. ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
  43. return ap
  44. def objectDetectionAP(dets,
  45. gts,
  46. ovthresh=0.5,
  47. use_07_metric=False):
  48. """
  49. This function takes a list of ground truth boxes and a list of detected bounding boxes
  50. for a given class and computes the average precision of the detections with respect to
  51. the ground truth boxes.
  52. Parameters:
  53. -----------
  54. dets: list
  55. List of bounding box detections. Each box is represented as a list
  56. with format:
  57. Case 1 (confidence provided):
  58. ['image_name', 'x_min', 'y_min', 'x_max', 'y_max', 'confidence']
  59. Case 2 (confidence not provided):
  60. ['image_name', 'x_min', 'y_min', 'x_max', 'y_max']
  61. gts: list
  62. List of ground truth boxes. Each box is represented as a list with the
  63. following format: [image_name, x_min, y_min, x_max, y_max].
  64. [ovthresh]: float
  65. Overlap threshold (default = 0.5)
  66. [use_07_metric]: boolean
  67. Whether to use VOC07's 11 point AP computation (default False)
  68. Returns:
  69. --------
  70. rec: 1d array-like
  71. Array where each element (rec[i]) is the recall when considering i+1 detections
  72. prec: 1d array-like
  73. Array where each element (rec[i]) is the precision when considering i+1 detections
  74. ap: float
  75. Average precision between detected boxes and the ground truth boxes.
  76. (it is also the area under the precision-recall curve).
  77. Example:
  78. With confidence scores:
  79. >> predictions_list = [['img_00285.png',330,463,387,505,0.0739],
  80. ['img_00285.png',420,433,451,498,0.0910],
  81. ['img_00285.png',328,465,403,540,0.1008],
  82. ['img_00285.png',480,477,508,522,0.1012],
  83. ['img_00285.png',357,460,417,537,0.1058],
  84. ['img_00285.png',356,456,391,521,0.0843],
  85. ['img_00225.png',345,460,415,547,0.0539],
  86. ['img_00225.png',381,362,455,513,0.0542],
  87. ['img_00225.png',382,366,416,422,0.0559],
  88. ['img_00225.png',730,463,763,583,0.0588]]
  89. >> ground_truth_list = [['img_00285.png',480,457,515,529],
  90. ['img_00285.png',480,457,515,529],
  91. ['img_00225.png',522,540,576,660],
  92. ['img_00225.png',739,460,768,545]]
  93. >> rec, prec, ap = objectDetectionAP(predictions_list, ground_truth_list)
  94. >> print(ap)
  95. 0.125
  96. Without confidence scores:
  97. >> predictions_list = [['img_00285.png',330,463,387,505],
  98. ['img_00285.png',420,433,451,498],
  99. ['img_00285.png',328,465,403,540],
  100. ['img_00285.png',480,477,508,522],
  101. ['img_00285.png',357,460,417,537],
  102. ['img_00285.png',356,456,391,521],
  103. ['img_00225.png',345,460,415,547],
  104. ['img_00225.png',381,362,455,513],
  105. ['img_00225.png',382,366,416,422],
  106. ['img_00225.png',730,463,763,583]]
  107. >> ground_truth_list = [['img_00285.png',480,457,515,529],
  108. ['img_00285.png',480,457,515,529],
  109. ['img_00225.png',522,540,576,660],
  110. ['img_00225.png',739,460,768,545]]
  111. >> rec, prec, ap = objectDetectionAP(predictions_list, ground_truth_list)
  112. >> print(ap)
  113. 0.0625
  114. """
  115. # Load ground truth
  116. gt_dict = group_gt_boxes_by_image_name(gts)
  117. # extract gt objects for this class
  118. recs = {}
  119. npos = 0
  120. imagenames = sorted(gt_dict.keys())
  121. for imagename in imagenames:
  122. R = [obj for obj in gt_dict[imagename]]
  123. bbox = np.array([x['bbox'] for x in R])
  124. det = [False] * len(R)
  125. npos = npos + len(R)
  126. recs[imagename] = {'bbox': bbox,
  127. 'det': det}
  128. # Load detections
  129. det_length = len(dets[0])
  130. # Check that all boxes are the same size
  131. for det in dets:
  132. assert len(det) == det_length, 'Not all boxes have the same dimensions.'
  133. image_ids = [x[0] for x in dets]
  134. BB = np.array([[float(z) for z in x[1:5]] for x in dets])
  135. if det_length == 6:
  136. print('confidence scores are present')
  137. confidence = np.array([float(x[-1]) for x in dets])
  138. # sort by confidence
  139. sorted_ind = np.argsort(-confidence)
  140. sorted_scores = np.sort(-confidence)
  141. else:
  142. print('confidence scores are not present')
  143. num_dets = len(dets)
  144. sorted_ind = np.arange(num_dets)
  145. sorted_scores = np.ones(num_dets)
  146. BB = BB[sorted_ind, :]
  147. image_ids = [image_ids[x] for x in sorted_ind]
  148. # print('sorted_ind: ', sorted_ind)
  149. # print('sorted_scores: ', sorted_scores)
  150. # print('BB: ', BB)
  151. # print('image_ids: ', image_ids)
  152. # go down dets and mark TPs and FPs
  153. nd = len(image_ids)
  154. tp = np.zeros(nd)
  155. fp = np.zeros(nd)
  156. for d in range(nd):
  157. R = recs[image_ids[d]]
  158. bb = BB[d, :].astype(float)
  159. ovmax = -np.inf
  160. BBGT = R['bbox'].astype(float)
  161. # print('det %d: ' % d)
  162. # print('bb: ', bb)
  163. if BBGT.size > 0:
  164. # compute overlaps
  165. # intersection
  166. ixmin = np.maximum(BBGT[:, 0], bb[0])
  167. iymin = np.maximum(BBGT[:, 1], bb[1])
  168. ixmax = np.minimum(BBGT[:, 2], bb[2])
  169. iymax = np.minimum(BBGT[:, 3], bb[3])
  170. iw = np.maximum(ixmax - ixmin + 1., 0.)
  171. ih = np.maximum(iymax - iymin + 1., 0.)
  172. inters = iw * ih
  173. # union
  174. uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
  175. (BBGT[:, 2] - BBGT[:, 0] + 1.) *
  176. (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)
  177. overlaps = inters / uni
  178. ovmax = np.max(overlaps)
  179. jmax = np.argmax(overlaps)
  180. # print('overlaps: ', overlaps)
  181. if ovmax > ovthresh:
  182. if not R['det'][jmax]:
  183. # print('Box matched!')
  184. tp[d] = 1.
  185. R['det'][jmax] = 1
  186. else:
  187. # print('Box was already taken!')
  188. fp[d] = 1.
  189. else:
  190. # print('No match with sufficient overlap!')
  191. fp[d] = 1.
  192. # print('tp: ', tp)
  193. # print('fp: ', fp)
  194. # compute precision recall
  195. fp = np.cumsum(fp)
  196. tp = np.cumsum(tp)
  197. rec = tp / float(npos)
  198. # avoid divide by zero in case the first detection matches a difficult
  199. # ground truth
  200. prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
  201. ap = voc_ap(rec, prec, use_07_metric)
  202. return rec, prec, ap

全栈的自动化机器学习系统,主要针对多变量时间序列数据的异常检测。TODS提供了详尽的用于构建基于机器学习的异常检测系统的模块,它们包括:数据处理(data processing),时间序列处理( time series processing),特征分析(feature analysis),检测算法(detection algorithms),和强化模块( reinforcement module)。这些模块所提供的功能包括常见的数据预处理、时间序列数据的平滑或变换,从时域或频域中抽取特征、多种多样的检测算