dataset_wrappers.py

# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import collections
import copy
import math
from collections import defaultdict

import numpy as np
from mmcv.utils import build_from_cfg, print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from .builder import DATASETS, PIPELINES
from .coco import CocoDataset


@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but it also
    concatenates the group flag used for image aspect ratio grouping.

    Args:
        datasets (list[:obj:`Dataset`]): A list of datasets.
        separate_eval (bool): Whether to evaluate the results
            separately if it is used as validation dataset.
            Defaults to True.
    """

    def __init__(self, datasets, separate_eval=True):
        super(ConcatDataset, self).__init__(datasets)
        self.CLASSES = datasets[0].CLASSES
        self.separate_eval = separate_eval
        if not separate_eval:
            if any([isinstance(ds, CocoDataset) for ds in datasets]):
                raise NotImplementedError(
                    'Evaluating concatenated CocoDataset as a whole is not'
                    ' supported! Please set "separate_eval=True"')
            elif len(set([type(ds) for ds in datasets])) != 1:
                raise NotImplementedError(
                    'All the datasets should have same types')

        if hasattr(datasets[0], 'flag'):
            flags = []
            for i in range(0, len(datasets)):
                flags.append(datasets[i].flag)
            self.flag = np.concatenate(flags)

    def get_cat_ids(self, idx):
        """Get category ids of concatenated dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """

        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    'absolute value of index should not exceed dataset length')
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx].get_cat_ids(sample_idx)

    def evaluate(self, results, logger=None, **kwargs):
        """Evaluate the results.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict[str: float]: AP results of the total dataset or each separate
                dataset if `self.separate_eval=True`.
        """
        assert len(results) == self.cumulative_sizes[-1], \
            ('Dataset and results have different sizes: '
             f'{self.cumulative_sizes[-1]} v.s. {len(results)}')

        # Check whether all the datasets support evaluation
        for dataset in self.datasets:
            assert hasattr(dataset, 'evaluate'), \
                f'{type(dataset)} does not implement evaluate function'

        if self.separate_eval:
            dataset_idx = -1
            total_eval_results = dict()
            for size, dataset in zip(self.cumulative_sizes, self.datasets):
                start_idx = 0 if dataset_idx == -1 else \
                    self.cumulative_sizes[dataset_idx]
                end_idx = self.cumulative_sizes[dataset_idx + 1]
                results_per_dataset = results[start_idx:end_idx]
                print_log(
                    f'\nEvaluating {dataset.ann_file} with '
                    f'{len(results_per_dataset)} images now',
                    logger=logger)

                eval_results_per_dataset = dataset.evaluate(
                    results_per_dataset, logger=logger, **kwargs)
                dataset_idx += 1
                for k, v in eval_results_per_dataset.items():
                    total_eval_results.update({f'{dataset_idx}_{k}': v})

            return total_eval_results
        elif any([isinstance(ds, CocoDataset) for ds in self.datasets]):
            raise NotImplementedError(
                'Evaluating concatenated CocoDataset as a whole is not'
                ' supported! Please set "separate_eval=True"')
        elif len(set([type(ds) for ds in self.datasets])) != 1:
            raise NotImplementedError(
                'All the datasets should have same types')
        else:
            original_data_infos = self.datasets[0].data_infos
            self.datasets[0].data_infos = sum(
                [dataset.data_infos for dataset in self.datasets], [])
            eval_results = self.datasets[0].evaluate(
                results, logger=logger, **kwargs)
            self.datasets[0].data_infos = original_data_infos
            return eval_results
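
# Hedged usage sketch (not part of the original file; dataset types, paths and
# the builder behaviour are illustrative assumptions): `build_dataset` in
# mmdet is expected to turn a 'ConcatDataset' config into this wrapper,
# concatenating two datasets of the same class while keeping per-dataset
# evaluation via separate_eval=True.
#
#     from mmdet.datasets import build_dataset
#
#     concat_cfg = dict(
#         type='ConcatDataset',
#         separate_eval=True,
#         datasets=[
#             dict(
#                 type='VOCDataset',
#                 ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
#                 img_prefix='data/VOCdevkit/VOC2007/',
#                 pipeline=train_pipeline),  # train_pipeline is a placeholder
#             dict(
#                 type='VOCDataset',
#                 ann_file='data/VOCdevkit/VOC2012/ImageSets/Main/trainval.txt',
#                 img_prefix='data/VOCdevkit/VOC2012/',
#                 pipeline=train_pipeline),
#         ])
#     dataset = build_dataset(concat_cfg)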


@DATASETS.register_module()
class RepeatDataset:
    """A wrapper of repeated dataset.

    The length of the repeated dataset will be `times` times the size of the
    original dataset. This is useful when the data loading time is long but
    the dataset is small. Using RepeatDataset can reduce the data loading
    time between epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        self.CLASSES = dataset.CLASSES
        if hasattr(self.dataset, 'flag'):
            self.flag = np.tile(self.dataset.flag, times)

        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx % self._ori_len]

    def get_cat_ids(self, idx):
        """Get category ids of repeat dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """

        return self.dataset.get_cat_ids(idx % self._ori_len)

    def __len__(self):
        """Length after repetition."""
        return self.times * self._ori_len
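
# Hedged usage sketch (illustrative config, not from the original file;
# `train_pipeline` is a placeholder): a small dataset can be repeated so each
# epoch iterates it several times, amortizing per-epoch startup cost.
#
#     repeat_cfg = dict(
#         type='RepeatDataset',
#         times=3,
#         dataset=dict(
#             type='VOCDataset',
#             ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
#             img_prefix='data/VOCdevkit/VOC2007/',
#             pipeline=train_pipeline))
#     # The resulting wrapper reports len() == 3 * len(wrapped dataset).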


# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57  # noqa
@DATASETS.register_module()
class ClassBalancedDataset:
    """A wrapper of repeated dataset with repeat factor.

    Suitable for training on class imbalanced datasets like LVIS. Following
    the sampling strategy in the `paper <https://arxiv.org/abs/1908.03195>`_,
    in each epoch, an image may appear multiple times based on its
    "repeat factor". The repeat factor for an image is a function of the
    frequency of the rarest category labeled in that image. The
    "frequency of category c" in [0, 1] is defined as the fraction of images
    in the training set (without repeats) in which category c appears.
    The dataset needs to implement :func:`self.get_cat_ids` to support
    ClassBalancedDataset.

    The repeat factor is computed as follows.

    1. For each category c, compute the fraction of images
       that contain it: :math:`f(c)`
    2. For each category c, compute the category-level repeat factor:
       :math:`r(c) = max(1, sqrt(t/f(c)))`
    3. For each image I, compute the image-level repeat factor:
       :math:`r(I) = max_{c in I} r(c)`

    Args:
        dataset (:obj:`CustomDataset`): The dataset to be repeated.
        oversample_thr (float): Frequency threshold below which data is
            repeated. For categories with ``f_c >= oversample_thr``, there is
            no oversampling. For categories with ``f_c < oversample_thr``, the
            degree of oversampling follows the square-root inverse frequency
            heuristic above.
        filter_empty_gt (bool, optional): If set true, images without bounding
            boxes will not be oversampled. Otherwise, they will be categorized
            as the pure background class and included in the oversampling.
            Default: True.
    """

    def __init__(self, dataset, oversample_thr, filter_empty_gt=True):
        self.dataset = dataset
        self.oversample_thr = oversample_thr
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = dataset.CLASSES

        repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
        repeat_indices = []
        for dataset_idx, repeat_factor in enumerate(repeat_factors):
            repeat_indices.extend([dataset_idx] * math.ceil(repeat_factor))
        self.repeat_indices = repeat_indices

        flags = []
        if hasattr(self.dataset, 'flag'):
            for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
                flags.extend([flag] * int(math.ceil(repeat_factor)))
            assert len(flags) == len(repeat_indices)
        self.flag = np.asarray(flags, dtype=np.uint8)

    def _get_repeat_factors(self, dataset, repeat_thr):
        """Get the repeat factor for each image in the dataset.

        Args:
            dataset (:obj:`CustomDataset`): The dataset.
            repeat_thr (float): The threshold of frequency. If an image
                contains categories whose frequency is below the threshold,
                it will be repeated.

        Returns:
            list[float]: The repeat factor for each image in the dataset.
        """

        # 1. For each category c, compute the fraction of images
        #    that contain it: f(c)
        category_freq = defaultdict(int)
        num_images = len(dataset)
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            if len(cat_ids) == 0 and not self.filter_empty_gt:
                cat_ids = set([len(self.CLASSES)])
            for cat_id in cat_ids:
                category_freq[cat_id] += 1
        for k, v in category_freq.items():
            category_freq[k] = v / num_images

        # 2. For each category c, compute the category-level repeat factor:
        #    r(c) = max(1, sqrt(t/f(c)))
        category_repeat = {
            cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
            for cat_id, cat_freq in category_freq.items()
        }

        # 3. For each image I, compute the image-level repeat factor:
        #    r(I) = max_{c in I} r(c)
        repeat_factors = []
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            if len(cat_ids) == 0 and not self.filter_empty_gt:
                cat_ids = set([len(self.CLASSES)])
            repeat_factor = 1
            if len(cat_ids) > 0:
                repeat_factor = max(
                    {category_repeat[cat_id]
                     for cat_id in cat_ids})
            repeat_factors.append(repeat_factor)

        return repeat_factors

    def __getitem__(self, idx):
        ori_index = self.repeat_indices[idx]
        return self.dataset[ori_index]

    def __len__(self):
        """Length after repetition."""
        return len(self.repeat_indices)
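
# Worked example (assumed numbers, a sketch rather than part of the file):
# with oversample_thr=1e-3, a category seen in 0.04% of training images has
# f(c) = 4e-4, so r(c) = max(1, sqrt(1e-3 / 4e-4)) = sqrt(2.5) ~= 1.58.
# Every image whose rarest category is c is listed ceil(1.58) = 2 times in
# `repeat_indices`, while images whose rarest category has f(c) >= 1e-3 keep
# a repeat factor of 1 and appear once per epoch.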


@DATASETS.register_module()
class AD_ClassBalancedDataset:
    """A wrapper that oversamples annotated images.

    Images whose ``get_cat_ids`` returns no categories are counted as
    annotation-free ("OK") images; images with at least one category are
    counted as annotated ("NG") images. Every annotated image is repeated by
    ``max(1, num_ok / num_ng * oversample_thr)`` so that the two groups are
    roughly balanced during training.

    Args:
        dataset (:obj:`CustomDataset`): The dataset to be repeated.
        oversample_thr (float): Scale applied to the OK/NG ratio when
            computing the repeat factor. Default: 1.0.
        filter_empty_gt (bool, optional): Kept for interface compatibility
            with :obj:`ClassBalancedDataset`; it is not used when computing
            the repeat factors. Default: False.
    """

    def __init__(self, dataset, oversample_thr=1.0, filter_empty_gt=False):
        self.dataset = dataset
        self.oversample_thr = oversample_thr
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = dataset.CLASSES

        repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
        repeat_indices = []
        for dataset_idx, repeat_factor in enumerate(repeat_factors):
            repeat_indices.extend([dataset_idx] * math.ceil(repeat_factor))
        self.repeat_indices = repeat_indices

        flags = []
        if hasattr(self.dataset, 'flag'):
            for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
                flags.extend([flag] * int(math.ceil(repeat_factor)))
            assert len(flags) == len(repeat_indices)
        self.flag = np.asarray(flags, dtype=np.uint8)

    def _get_repeat_factors(self, dataset, repeat_thr):
        """Compute a repeat factor for every image.

        Annotation-free images keep a factor of 1; annotated images share the
        factor ``max(1, num_ok / num_ng * repeat_thr)``.
        """
        num_images = len(dataset)
        # Count annotation-free (OK) and annotated (NG) images.
        num_ok = 0
        num_ng = 0
        repeat_factors = []
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            if len(cat_ids) == 0:
                num_ok += 1
            else:
                num_ng += 1
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            if len(cat_ids) == 0:
                repeat_factor = 1
            else:
                # max(num_ng, 1) avoids a division by zero when the dataset
                # contains no annotated images at all.
                repeat_factor = max(1.0, num_ok / max(num_ng, 1) * repeat_thr)
            repeat_factors.append(repeat_factor)
        return repeat_factors

    def __getitem__(self, idx):
        ori_index = self.repeat_indices[idx]
        return self.dataset[ori_index]

    def __len__(self):
        """Length after repetition."""
        return len(self.repeat_indices)
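
# Worked example (assumed counts, a sketch rather than part of the file):
# with 900 images that return no category ids (num_ok) and 100 annotated
# images (num_ng), and oversample_thr=1.0, every annotated image gets
# repeat_factor = max(1.0, 900 / 100 * 1.0) = 9.0, i.e. ceil(9.0) = 9 entries
# in `repeat_indices`. Annotated and annotation-free images then contribute
# about the same number of samples per epoch (900 vs. 900).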


@DATASETS.register_module()
class MultiImageMixDataset:
    """A wrapper of multiple images mixed dataset.

    Suitable for training with multi-image mixed data augmentation such as
    mosaic and mixup. Transforms in the augmentation pipeline that mix image
    data need to provide a `get_indexes` method to obtain the indexes of the
    images to be mixed, and `skip_type_keys` can be set to skip certain
    transforms while running the pipeline. The `dynamic_scale` parameter is
    provided to dynamically change the output image size.

    Args:
        dataset (:obj:`CustomDataset`): The dataset to be mixed.
        pipeline (Sequence[dict]): Sequence of transform objects or
            config dicts to be composed.
        dynamic_scale (tuple[int], optional): The image scale can be changed
            dynamically. Default to None.
        skip_type_keys (list[str], optional): Sequence of type strings to be
            skipped in the pipeline. Default to None.
    """

    def __init__(self,
                 dataset,
                 pipeline,
                 dynamic_scale=None,
                 skip_type_keys=None):
        assert isinstance(pipeline, collections.abc.Sequence)
        if skip_type_keys is not None:
            assert all([
                isinstance(skip_type_key, str)
                for skip_type_key in skip_type_keys
            ])
        self._skip_type_keys = skip_type_keys

        self.pipeline = []
        self.pipeline_types = []
        for transform in pipeline:
            if isinstance(transform, dict):
                self.pipeline_types.append(transform['type'])
                transform = build_from_cfg(transform, PIPELINES)
                self.pipeline.append(transform)
            else:
                raise TypeError('pipeline must be a dict')

        self.dataset = dataset
        self.CLASSES = dataset.CLASSES
        if hasattr(self.dataset, 'flag'):
            self.flag = dataset.flag
        self.num_samples = len(dataset)

        if dynamic_scale is not None:
            assert isinstance(dynamic_scale, tuple)
        self._dynamic_scale = dynamic_scale

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        results = copy.deepcopy(self.dataset[idx])
        for (transform, transform_type) in zip(self.pipeline,
                                               self.pipeline_types):
            if self._skip_type_keys is not None and \
                    transform_type in self._skip_type_keys:
                continue

            if hasattr(transform, 'get_indexes'):
                indexes = transform.get_indexes(self.dataset)
                if not isinstance(indexes, collections.abc.Sequence):
                    indexes = [indexes]
                mix_results = [
                    copy.deepcopy(self.dataset[index]) for index in indexes
                ]
                results['mix_results'] = mix_results

            if self._dynamic_scale is not None:
                # Used by subsequent transforms to automatically change
                # the output image size, e.g. MixUp, Resize.
                results['scale'] = self._dynamic_scale

            results = transform(results)

            if 'mix_results' in results:
                results.pop('mix_results')

        return results

    def update_skip_type_keys(self, skip_type_keys):
        """Update skip_type_keys. It is called by an external hook.

        Args:
            skip_type_keys (list[str], optional): Sequence of type
                strings to be skipped in the pipeline.
        """
        assert all([
            isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
        ])
        self._skip_type_keys = skip_type_keys

    def update_dynamic_scale(self, dynamic_scale):
        """Update dynamic_scale. It is called by an external hook.

        Args:
            dynamic_scale (tuple[int]): The image scale can be
                changed dynamically.
        """
        assert isinstance(dynamic_scale, tuple)
        self._dynamic_scale = dynamic_scale
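
# Hedged usage sketch (illustrative YOLOX-style config, not from the original
# file; transform names and paths are assumptions): the wrapped dataset's own
# pipeline only loads images and annotations, while the multi-image transforms
# (which provide `get_indexes`) live in this wrapper's pipeline.
#
#     train_dataset = dict(
#         type='MultiImageMixDataset',
#         dataset=dict(
#             type='CocoDataset',
#             ann_file='data/coco/annotations/instances_train2017.json',
#             img_prefix='data/coco/train2017/',
#             pipeline=[
#                 dict(type='LoadImageFromFile'),
#                 dict(type='LoadAnnotations', with_bbox=True)
#             ]),
#         pipeline=[
#             dict(type='Mosaic', img_scale=(640, 640)),
#             dict(type='MixUp', img_scale=(640, 640)),
#             dict(type='Resize', img_scale=(640, 640), keep_ratio=True),
#             dict(type='Pad', size_divisor=32),
#             dict(type='DefaultFormatBundle'),
#             dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
#         ],
#         dynamic_scale=(640, 640))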
