custom.py
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from collections import OrderedDict

import mmcv
import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable
from torch.utils.data import Dataset

from mmdet.core import eval_map, eval_recalls
from .builder import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class CustomDataset(Dataset):
  15. """Custom dataset for detection.
  16. The annotation format is shown as follows. The `ann` field is optional for
  17. testing.
  18. .. code-block:: none
  19. [
  20. {
  21. 'filename': 'a.jpg',
  22. 'width': 1280,
  23. 'height': 720,
  24. 'ann': {
  25. 'bboxes': <np.ndarray> (n, 4) in (x1, y1, x2, y2) order.
  26. 'labels': <np.ndarray> (n, ),
  27. 'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
  28. 'labels_ignore': <np.ndarray> (k, 4) (optional field)
  29. }
  30. },
  31. ...
  32. ]

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict]): Processing pipeline.
        classes (str | Sequence[str], optional): Specify classes to load.
            If None, ``cls.CLASSES`` will be used. Default: None.
        data_root (str, optional): Data root for ``ann_file``,
            ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified.
        test_mode (bool, optional): If set True, annotations will not be
            loaded.
        filter_empty_gt (bool, optional): If set True, images without bounding
            boxes of the dataset's classes will be filtered out. This option
            only works when ``test_mode=False``, i.e., we never filter images
            during tests.
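
    Example:
        A minimal sketch (hypothetical file name, class name and box values,
        not from the original module) of writing an annotation file in the
        format above and loading it back::

            import mmcv
            import numpy as np
            anns = [dict(filename='a.jpg', width=1280, height=720,
                         ann=dict(bboxes=np.array([[10., 10., 100., 120.]],
                                                  dtype=np.float32),
                                  labels=np.array([0], dtype=np.int64)))]
            mmcv.dump(anns, 'anns.pkl')
            dataset = CustomDataset(ann_file='anns.pkl', pipeline=[],
                                    classes=('person', ))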
  45. """

    CLASSES = None

    def __init__(self,
                 ann_file,
                 pipeline,
                 classes=None,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True):
        self.ann_file = ann_file
        self.data_root = data_root
        self.img_prefix = img_prefix
        self.seg_prefix = seg_prefix
        self.proposal_file = proposal_file
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = self.get_classes(classes)

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.ann_file):
                self.ann_file = osp.join(self.data_root, self.ann_file)
            if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
                self.img_prefix = osp.join(self.data_root, self.img_prefix)
            if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
                self.seg_prefix = osp.join(self.data_root, self.seg_prefix)
            if not (self.proposal_file is None
                    or osp.isabs(self.proposal_file)):
                self.proposal_file = osp.join(self.data_root,
                                              self.proposal_file)
        # load annotations (and proposals)
        self.data_infos = self.load_annotations(self.ann_file)
        if self.proposal_file is not None:
            self.proposals = self.load_proposals(self.proposal_file)
        else:
            self.proposals = None
        # filter images too small and containing no annotations
        # NOTE: filtering is disabled in this customized copy; the original
        # logic is kept below, commented out, for reference.
        if not test_mode:
            # valid_inds = self._filter_imgs()
            # self.data_infos = [self.data_infos[i] for i in valid_inds]
            # if self.proposals is not None:
            #     self.proposals = [self.proposals[i] for i in valid_inds]
            # set group flag for the sampler
            self._set_group_flag()
        # processing pipeline
        self.pipeline = Compose(pipeline)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.data_infos)

    def load_annotations(self, ann_file):
        """Load annotation from annotation file."""
        return mmcv.load(ann_file)

    def load_proposals(self, proposal_file):
        """Load proposal from proposal file."""
        return mmcv.load(proposal_file)

    def get_ann_info(self, idx):
        """Get annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """
        return self.data_infos[idx]['ann']

    def get_cat_ids(self, idx):
        """Get category ids by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """
        # `np.int` is deprecated in recent NumPy; use a concrete dtype
        return self.data_infos[idx]['ann']['labels'].astype(np.int64).tolist()

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline."""
        results['img_prefix'] = self.img_prefix
        results['seg_prefix'] = self.seg_prefix
        results['proposal_file'] = self.proposal_file
        results['bbox_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []

    def _filter_imgs(self, min_size=32):
        """Filter images too small."""
        if self.filter_empty_gt:
            warnings.warn(
                'CustomDataset does not support filtering empty gt images.')
        valid_inds = []
        for i, img_info in enumerate(self.data_infos):
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)
        for i in range(len(self)):
            img_info = self.data_infos[i]
            if img_info['width'] / img_info['height'] > 1:
                self.flag[i] = 1

    def _rand_another(self, idx):
        """Get another random index from the same group as the given index."""
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __getitem__(self, idx):
        """Get training/test data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training/test data (with annotation if ``test_mode`` is
            False).
        """
        if self.test_mode:
            return self.prepare_test_img(idx)
        while True:
            data = self.prepare_train_img(idx)
            if data is None:
                # the pipeline rejected this sample; retry with a random
                # index from the same aspect-ratio group
                idx = self._rand_another(idx)
                continue
            return data

    def prepare_train_img(self, idx):
        """Get training data and annotations after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
            introduced by pipeline.
        """
        img_info = self.data_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, idx):
        """Get testing data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Testing data after pipeline with new keys introduced by
            pipeline.
        """
        img_info = self.data_infos[idx]
        results = dict(img_info=img_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.

        Returns:
            tuple[str] or list[str]: Names of categories of the dataset.
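
        Example:
            A brief sketch of the three accepted forms (the file path is
            hypothetical)::

                CustomDataset.get_classes(None)               # cls.CLASSES
                CustomDataset.get_classes('classes.txt')      # one name/line
                CustomDataset.get_classes(('person', 'car'))  # used as-is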
  206. """
  207. if classes is None:
  208. return cls.CLASSES
  209. if isinstance(classes, str):
  210. # take it as a file path
  211. try:
  212. class_names = mmcv.list_from_file(classes)
  213. except:
  214. class_names = [classes]
  215. elif isinstance(classes, (tuple, list)):
  216. class_names = classes
  217. else:
  218. raise ValueError(f'Unsupported type {type(classes)} of classes.')
  219. return class_names

    def format_results(self, results, **kwargs):
        """Placeholder to format results to dataset-specific output."""

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thr (float | list[float]): IoU threshold. Default: 0.5.
            scale_ranges (list[tuple] | None): Scale ranges for evaluating
                mAP. Default: None.
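
        Example:
            A hedged sketch with dummy detections (all values illustrative):
            ``results`` holds, for each image, one ``(m, 5)`` array of
            ``(x1, y1, x2, y2, score)`` rows per class::

                results = [[np.array([[10., 10., 100., 120., 0.9]])]
                           for _ in range(len(dataset))]
                dataset.evaluate(results, metric='mAP', iou_thr=0.5)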
  241. """
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP', 'recall']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = OrderedDict()
        iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr
        if metric == 'mAP':
            assert isinstance(iou_thrs, list)
            mean_aps = []
            for iou_thr in iou_thrs:
                print_log(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}',
                          logger=logger)
                mean_ap, _ = eval_map(
                    results,
                    annotations,
                    scale_ranges=scale_ranges,
                    iou_thr=iou_thr,
                    dataset=self.CLASSES,
                    logger=logger)
                mean_aps.append(mean_ap)
                eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3)
            eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
        elif metric == 'recall':
            gt_bboxes = [ann['bboxes'] for ann in annotations]
            recalls = eval_recalls(
                gt_bboxes, results, proposal_nums, iou_thr, logger=logger)
            for i, num in enumerate(proposal_nums):
                for j, iou in enumerate(iou_thrs):
                    eval_results[f'recall@{num}@{iou}'] = recalls[i, j]
            if recalls.shape[1] > 1:
                ar = recalls.mean(axis=1)
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
        return eval_results

    def __repr__(self):
        """Print the number of instances per category."""
        dataset_type = 'Test' if self.test_mode else 'Train'
        result = (f'\n{self.__class__.__name__} {dataset_type} dataset '
                  f'with number of images {len(self)}, '
                  f'and instance counts: \n')
        if self.CLASSES is None:
            result += 'Category names are not provided. \n'
            return result
        instance_count = np.zeros(len(self.CLASSES) + 1).astype(int)
        # count the instance number in each image
        for idx in range(len(self)):
            label = self.get_ann_info(idx)['labels']
            unique, counts = np.unique(label, return_counts=True)
            if len(unique) > 0:
                # add the occurrence number to each class
                instance_count[unique] += counts
            else:
                # background is the last index
                instance_count[-1] += 1
        # create a table with category counts, five (category, count)
        # column pairs per row
        table_data = [['category', 'count'] * 5]
        row_data = []
        for cls, count in enumerate(instance_count):
            if cls < len(self.CLASSES):
                row_data += [f'{cls} [{self.CLASSES[cls]}]', f'{count}']
            else:
                # add the background number
                row_data += ['-1 background', f'{count}']
            if len(row_data) == 10:
                table_data.append(row_data)
                row_data = []
        if len(row_data) >= 2:
            if row_data[-1] == '0':
                # drop a trailing empty background entry
                row_data = row_data[:-2]
            if len(row_data) >= 2:
                table_data.append([])
                table_data.append(row_data)
        table = AsciiTable(table_data)
        result += table.table
        return result
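

if __name__ == '__main__':
    # Minimal smoke test (hypothetical file names and values, not part of the
    # original module). Because of the relative imports above, run it as a
    # module, e.g. `python -m mmdet.datasets.custom`.
    import tempfile

    anns = [
        dict(
            filename='a.jpg',
            width=1280,
            height=720,
            ann=dict(
                bboxes=np.array([[10., 10., 100., 120.]], dtype=np.float32),
                labels=np.array([0], dtype=np.int64))),
    ]
    with tempfile.TemporaryDirectory() as tmpdir:
        ann_file = osp.join(tmpdir, 'anns.pkl')
        mmcv.dump(anns, ann_file)  # pickle, so mmcv.load round-trips it
        dataset = CustomDataset(
            ann_file=ann_file, pipeline=[], classes=('person', ))
        assert len(dataset) == 1
        assert dataset.get_cat_ids(0) == [0]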
