
coco.py
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import logging
import os.path as osp
import tempfile
import warnings
from collections import OrderedDict

import mmcv
import numpy as np
from mmcv.utils import print_log
from sklearn.metrics import (accuracy_score, classification_report,
                             confusion_matrix, precision_recall_curve,
                             roc_auc_score, roc_curve)
from terminaltables import AsciiTable

from mmdet.core import eval_recalls
from .api_wrappers import COCO, COCOeval
from .builder import DATASETS
from .custom import CustomDataset
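# NOTE: the sklearn imports above support the image-level anomaly-detection
# (AD) metrics added to ``CocoDataset.evaluate`` below; roc_curve,
# precision_recall_curve, classification_report and accuracy_score are
# imported but not currently used.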


@DATASETS.register_module()
class CocoDataset(CustomDataset):

    CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
               'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
               'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
               'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
               'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
               'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
               'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
               'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
               'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
               'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
               'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
               'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
               'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
               'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')

    def load_annotations(self, ann_file):
        """Load annotation from COCO style annotation file.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation info from COCO api.
        """

        self.coco = COCO(ann_file)
        # The order of returned `cat_ids` will not
        # change with the order of the CLASSES
        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []
        total_ann_ids = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
            ann_ids = self.coco.get_ann_ids(img_ids=[i])
            total_ann_ids.extend(ann_ids)
        assert len(set(total_ann_ids)) == len(
            total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
        return data_infos

    def get_ann_info(self, idx):
        """Get COCO annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """

        img_id = self.data_infos[idx]['id']
        ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
        ann_info = self.coco.load_anns(ann_ids)
        return self._parse_ann_info(self.data_infos[idx], ann_info)

    def get_cat_ids(self, idx):
        """Get COCO category ids by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """

        img_id = self.data_infos[idx]['id']
        ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
        ann_info = self.coco.load_anns(ann_ids)
        return [ann['category_id'] for ann in ann_info]

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths."""
        valid_inds = []
        # obtain images that contain annotation
        ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
        # obtain images that contain annotations of the required categories
        ids_in_cat = set()
        for i, class_id in enumerate(self.cat_ids):
            ids_in_cat |= set(self.coco.cat_img_map[class_id])
        # merge the image id sets of the two conditions and use the merged set
        # to filter out images if self.filter_empty_gt=True
        ids_in_cat &= ids_with_ann

        valid_img_ids = []
        for i, img_info in enumerate(self.data_infos):
            img_id = self.img_ids[i]
            if self.filter_empty_gt and img_id not in ids_in_cat:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
                valid_img_ids.append(img_id)
        self.img_ids = valid_img_ids
        return valid_inds

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotation.

        Args:
            img_info (dict): Information of the image, e.g., filename,
                width and height.
            ann_info (list[dict]): Annotation info of an image.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,\
                labels, masks, seg_map. "masks" are raw annotations and not \
                decoded into binary masks.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ann = []
        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann.get('segmentation', None))

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        seg_map = img_info['filename'].replace('jpg', 'png')

        ann = dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            bboxes_ignore=gt_bboxes_ignore,
            masks=gt_masks_ann,
            seg_map=seg_map)

        return ann

    def xyxy2xywh(self, bbox):
        """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
        evaluation.

        Args:
            bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
                ``xyxy`` order.

        Returns:
            list[float]: The converted bounding boxes, in ``xywh`` order.
        """

        _bbox = bbox.tolist()
        return [
            _bbox[0],
            _bbox[1],
            _bbox[2] - _bbox[0],
            _bbox[3] - _bbox[1],
        ]
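
    # For example (values hypothetical): an ``xyxy`` box [10., 20., 30., 60.]
    # becomes ``xywh`` [10., 20., 20., 40.], since COCO json files store boxes
    # as [x, y, width, height] rather than corner coordinates.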

    def _proposal2json(self, results):
        """Convert proposal results to COCO json style."""
        json_results = []
        for idx in range(len(self)):
            img_id = self.img_ids[idx]
            bboxes = results[idx]
            for i in range(bboxes.shape[0]):
                data = dict()
                data['image_id'] = img_id
                data['bbox'] = self.xyxy2xywh(bboxes[i])
                data['score'] = float(bboxes[i][4])
                data['category_id'] = 1
                json_results.append(data)
        return json_results

    def _det2json(self, results):
        """Convert detection results to COCO json style."""
        json_results = []
        for idx in range(len(self)):
            img_id = self.img_ids[idx]
            result = results[idx]
            for label in range(len(result)):
                bboxes = result[label]
                for i in range(bboxes.shape[0]):
                    data = dict()
                    data['image_id'] = img_id
                    data['bbox'] = self.xyxy2xywh(bboxes[i])
                    data['score'] = float(bboxes[i][4])
                    data['category_id'] = self.cat_ids[label]
                    json_results.append(data)
        return json_results

    def _segm2json(self, results):
        """Convert instance segmentation results to COCO json style."""
        bbox_json_results = []
        segm_json_results = []
        for idx in range(len(self)):
            img_id = self.img_ids[idx]
            det, seg = results[idx]
            for label in range(len(det)):
                # bbox results
                bboxes = det[label]
                for i in range(bboxes.shape[0]):
                    data = dict()
                    data['image_id'] = img_id
                    data['bbox'] = self.xyxy2xywh(bboxes[i])
                    data['score'] = float(bboxes[i][4])
                    data['category_id'] = self.cat_ids[label]
                    bbox_json_results.append(data)
                # segm results
                # some detectors use different scores for bbox and mask
                if isinstance(seg, tuple):
                    segms = seg[0][label]
                    mask_score = seg[1][label]
                else:
                    segms = seg[label]
                    mask_score = [bbox[4] for bbox in bboxes]
                for i in range(bboxes.shape[0]):
                    data = dict()
                    data['image_id'] = img_id
                    data['bbox'] = self.xyxy2xywh(bboxes[i])
                    data['score'] = float(mask_score[i])
                    data['category_id'] = self.cat_ids[label]
                    if isinstance(segms[i]['counts'], bytes):
                        segms[i]['counts'] = segms[i]['counts'].decode()
                    data['segmentation'] = segms[i]
                    segm_json_results.append(data)
        return bbox_json_results, segm_json_results

    def results2json(self, results, outfile_prefix):
        """Dump the detection results to a COCO style json file.

        There are 3 types of results: proposals, bbox predictions, mask
        predictions, and they have different data types. This method will
        automatically recognize the type, and dump them to json files.

        Args:
            results (list[list | tuple | ndarray]): Testing results of the
                dataset.
            outfile_prefix (str): The filename prefix of the json files. If the
                prefix is "somepath/xxx", the json files will be named
                "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
                "somepath/xxx.proposal.json".

        Returns:
            dict[str: str]: Possible keys are "bbox", "segm", "proposal", and \
                values are corresponding filenames.
        """
        result_files = dict()
        if isinstance(results[0], list):
            json_results = self._det2json(results)
            result_files['bbox'] = f'{outfile_prefix}.bbox.json'
            result_files['proposal'] = f'{outfile_prefix}.bbox.json'
            mmcv.dump(json_results, result_files['bbox'])
        elif isinstance(results[0], tuple):
            json_results = self._segm2json(results)
            result_files['bbox'] = f'{outfile_prefix}.bbox.json'
            result_files['proposal'] = f'{outfile_prefix}.bbox.json'
            result_files['segm'] = f'{outfile_prefix}.segm.json'
            mmcv.dump(json_results[0], result_files['bbox'])
            mmcv.dump(json_results[1], result_files['segm'])
        elif isinstance(results[0], np.ndarray):
            json_results = self._proposal2json(results)
            result_files['proposal'] = f'{outfile_prefix}.proposal.json'
            mmcv.dump(json_results, result_files['proposal'])
        else:
            raise TypeError('invalid type of results')
        return result_files
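
    # Each dumped bbox record follows the standard COCO results layout, e.g.
    # (values hypothetical):
    #   {"image_id": 42, "bbox": [x, y, w, h], "score": 0.9, "category_id": 1}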

    def fast_eval_recall(self, results, proposal_nums, iou_thrs, logger=None):
        gt_bboxes = []
        for i in range(len(self.img_ids)):
            ann_ids = self.coco.get_ann_ids(img_ids=self.img_ids[i])
            ann_info = self.coco.load_anns(ann_ids)
            if len(ann_info) == 0:
                gt_bboxes.append(np.zeros((0, 4)))
                continue
            bboxes = []
            for ann in ann_info:
                if ann.get('ignore', False) or ann['iscrowd']:
                    continue
                x1, y1, w, h = ann['bbox']
                bboxes.append([x1, y1, x1 + w, y1 + h])
            bboxes = np.array(bboxes, dtype=np.float32)
            if bboxes.shape[0] == 0:
                bboxes = np.zeros((0, 4))
            gt_bboxes.append(bboxes)

        recalls = eval_recalls(
            gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
        # average over the IoU axis, yielding one AR per proposal budget
        ar = recalls.mean(axis=1)
        return ar

    def format_results(self, results, jsonfile_prefix=None, **kwargs):
        """Format the results to json (standard format for COCO evaluation).

        Args:
            results (list[tuple | numpy.ndarray]): Testing results of the
                dataset.
            jsonfile_prefix (str | None): The prefix of json files. It includes
                the file path and the prefix of filename, e.g., "a/b/prefix".
                If not specified, a temp file will be created. Default: None.

        Returns:
            tuple: (result_files, tmp_dir), result_files is a dict containing \
                the json filepaths, tmp_dir is the temporary directory created \
                for saving json files when jsonfile_prefix is not specified.
        """
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            'The length of results is not equal to the dataset len: {} != {}'.
            format(len(results), len(self)))

        if jsonfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            jsonfile_prefix = osp.join(tmp_dir.name, 'results')
        else:
            tmp_dir = None
        result_files = self.results2json(results, jsonfile_prefix)
        return result_files, tmp_dir

    def evaluate(self,
                 results,
                 metric='bbox',
                 logger=None,
                 jsonfile_prefix=None,
                 classwise=False,
                 proposal_nums=(100, 300, 1000),
                 iou_thrs=None,
                 metric_items=None):
        """Evaluation in COCO protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'bbox', 'segm', 'proposal', 'proposal_fast'.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            jsonfile_prefix (str | None): The prefix of json files. It includes
                the file path and the prefix of filename, e.g., "a/b/prefix".
                If not specified, a temp file will be created. Default: None.
            classwise (bool): Whether to evaluate the AP for each class.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thrs (Sequence[float], optional): IoU threshold used for
                evaluating recalls/mAPs. If set to a list, the average of all
                IoUs will also be computed. If not specified, [0.50, 0.55,
                0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
                Default: None.
            metric_items (list[str] | str, optional): Metric items that will
                be returned. If not specified, ``['AR@100', 'AR@300',
                'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000']`` will be
                used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
                'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
                ``metric=='bbox' or metric=='segm'``.

        Returns:
            dict[str, float]: COCO style evaluation metric.
        """
        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError(f'metric {metric} is not supported')
        if iou_thrs is None:
            iou_thrs = np.linspace(
                .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
        if metric_items is not None:
            if not isinstance(metric_items, list):
                metric_items = [metric_items]

        result_files, tmp_dir = self.format_results(results, jsonfile_prefix)

        eval_results = OrderedDict()
        cocoGt = self.coco
        for metric in metrics:
            msg = f'Evaluating {metric}...'
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            if metric == 'proposal_fast':
                ar = self.fast_eval_recall(
                    results, proposal_nums, iou_thrs, logger='silent')
                log_msg = []
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
                    log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue

            iou_type = 'bbox' if metric == 'proposal' else metric
            if metric not in result_files:
                raise KeyError(f'{metric} is not in results')
            try:
                predictions = mmcv.load(result_files[metric])
                if iou_type == 'segm':
                    # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331  # noqa
                    # When evaluating mask AP, if the results contain bbox,
                    # cocoapi will use the box area instead of the mask area
                    # for calculating the instance area. Though the overall AP
                    # is not affected, this leads to different
                    # small/medium/large mask AP results.
                    for x in predictions:
                        x.pop('bbox')
                    warnings.simplefilter('once')
                    warnings.warn(
                        'The key "bbox" is deleted for more accurate mask AP '
                        'of small/medium/large instances since v2.12.0. This '
                        'does not change the overall mAP calculation.',
                        UserWarning)
                cocoDt = cocoGt.loadRes(predictions)
            except IndexError:
                print_log(
                    'The testing results of the whole dataset is empty.',
                    logger=logger,
                    level=logging.ERROR)
                break

            # AD_eval: image-level anomaly-detection metrics. An image counts
            # as defective ("ng") if it has at least one ground-truth
            # annotation (for cocoGt) or at least one detection (for cocoDt).
            num_dataset = len(cocoGt.imgs)
            AD_label = get_AD_classification(cocoGt, num_dataset)
            AD_predict = get_AD_classification(cocoDt, num_dataset)
            # cr = classification_report(AD_label, AD_predict)
            cm = confusion_matrix(AD_label, AD_predict)
            AD_predict_score = get_AD_score(cocoDt, num_dataset)
            eval_AUC = roc_auc_score(AD_label, AD_predict_score)
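            # ``cm`` layout with labels {0: ok, 1: ng}: cm[0][0] ok images
            # kept ok, cm[0][1] ok images flagged ng, cm[1][0] ng images
            # missed, cm[1][1] ng images caught. A small epsilon guards
            # against division by zero when a class is empty.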
            ok_recall = cm[0][0] / (cm[0][0] + cm[0][1] + 1e-8)
            ng_recall = cm[1][1] / (cm[1][0] + cm[1][1] + 1e-8)
            eval_results['ok_recall'] = ok_recall
            eval_results['ng_recall'] = ng_recall
            eval_results['AUC'] = eval_AUC
            # print(classification_report(AD_label, AD_predict))
            # print(confusion_matrix(AD_label, AD_predict))

            cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
            cocoEval.params.catIds = self.cat_ids
            cocoEval.params.imgIds = self.img_ids
            cocoEval.params.maxDets = list(proposal_nums)
            cocoEval.params.iouThrs = iou_thrs
            # mapping of cocoEval.stats
            coco_metric_names = {
                'mAP': 0,
                'mAP_50': 1,
                'mAP_75': 2,
                'mAP_s': 3,
                'mAP_m': 4,
                'mAP_l': 5,
                'AR@100': 6,
                'AR@300': 7,
                'AR@1000': 8,
                'AR_s@1000': 9,
                'AR_m@1000': 10,
                'AR_l@1000': 11
            }
            if metric_items is not None:
                for metric_item in metric_items:
                    if metric_item not in coco_metric_names:
                        raise KeyError(
                            f'metric item {metric_item} is not supported')

            if metric == 'proposal':
                cocoEval.params.useCats = 0
                cocoEval.evaluate()
                cocoEval.accumulate()
                cocoEval.summarize()
                if metric_items is None:
                    metric_items = [
                        'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
                        'AR_m@1000', 'AR_l@1000'
                    ]

                for item in metric_items:
                    val = float(
                        f'{cocoEval.stats[coco_metric_names[item]]:.3f}')
                    eval_results[item] = val
            else:
                cocoEval.evaluate()
                cocoEval.accumulate()
                cocoEval.summarize()
                if classwise:  # Compute per-category AP
                    # from https://github.com/facebookresearch/detectron2/
                    precisions = cocoEval.eval['precision']
                    # precision: (iou, recall, cls, area range, max dets)
                    assert len(self.cat_ids) == precisions.shape[2]

                    results_per_category = []
                    for idx, catId in enumerate(self.cat_ids):
                        # area range index 0: all area ranges
                        # max dets index -1: typically 100 per image
                        nm = self.coco.loadCats(catId)[0]
                        precision = precisions[:, :, idx, 0, -1]
                        precision = precision[precision > -1]
                        if precision.size:
                            ap = np.mean(precision)
                        else:
                            ap = float('nan')
                        results_per_category.append(
                            (f'{nm["name"]}', f'{float(ap):0.3f}'))

                    num_columns = min(6, len(results_per_category) * 2)
                    results_flatten = list(
                        itertools.chain(*results_per_category))
                    headers = ['category', 'AP'] * (num_columns // 2)
                    results_2d = itertools.zip_longest(*[
                        results_flatten[i::num_columns]
                        for i in range(num_columns)
                    ])
                    table_data = [headers]
                    table_data += [result for result in results_2d]
                    table = AsciiTable(table_data)
                    print_log('\n' + table.table, logger=logger)

                if classwise:  # Compute per-category AR as well
                    # adapted from the per-category AP table above
                    recalls = cocoEval.eval['recall']
                    # recall: (iou, cls, area range, max dets)
                    assert len(self.cat_ids) == recalls.shape[1]

                    results_per_category = []
                    for idx, catId in enumerate(self.cat_ids):
                        # area range index 0: all area ranges
                        # max dets index -1: typically 100 per image
                        nm = self.coco.loadCats(catId)[0]
                        recall = recalls[:, idx, 0, -1]
                        recall = recall[recall > -1]
                        if recall.size:
                            ar = np.mean(recall)
                        else:
                            ar = float('nan')
                        results_per_category.append(
                            (f'{nm["name"]}', f'{float(ar):0.3f}'))

                    num_columns = min(6, len(results_per_category) * 2)
                    results_flatten = list(
                        itertools.chain(*results_per_category))
                    headers = ['category', 'AR'] * (num_columns // 2)
                    results_2d = itertools.zip_longest(*[
                        results_flatten[i::num_columns]
                        for i in range(num_columns)
                    ])
                    table_data = [headers]
                    table_data += [result for result in results_2d]
                    table = AsciiTable(table_data)
                    print_log('\n' + table.table, logger=logger)

                if metric_items is None:
                    metric_items = [
                        'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
                    ]

                for metric_item in metric_items:
                    key = f'{metric}_{metric_item}'
                    val = float(
                        f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'
                    )
                    eval_results[key] = val
                ap = cocoEval.stats[:6]
                eval_results[f'{metric}_mAP_copypaste'] = (
                    f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
                    f'{ap[4]:.3f} {ap[5]:.3f}')
        if tmp_dir is not None:
            tmp_dir.cleanup()
        return eval_results


def get_AD_classification(coco, num_dataset):
    """Return a binary per-image label: 1 if the image has any annotation.

    Note: assumes integer image ids in ``[1, num_dataset]``; index 0 of the
    returned array is unused padding shared by labels and predictions.
    """
    res = np.zeros(num_dataset + 1)
    for ann in coco.imgToAnns.values():
        if len(ann) > 0:
            key = ann[0]['image_id']
            res[int(key)] = 1
    return res


def get_AD_score(coco, num_dataset):
    """Return the maximum detection score per image (0 if no detections)."""
    res = np.zeros(num_dataset + 1)
    for ann in coco.imgToAnns.values():
        for ann_box in ann:
            key = ann_box['image_id']
            if res[int(key)] < ann_box['score']:
                res[int(key)] = ann_box['score']
    return res
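

# A minimal usage sketch (paths and arguments are hypothetical), kept as a
# comment so the module stays importable:
#
#   dataset = CocoDataset(ann_file='data/coco/annotations/val.json',
#                         pipeline=[])
#   results = ...  # per-image detections, e.g. from mmdet's single_gpu_test
#   metrics = dataset.evaluate(results, metric='bbox', classwise=True)
#   print(metrics['bbox_mAP'], metrics['AUC'], metrics['ng_recall'])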
