
builder.py
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform
import random
from functools import partial

import numpy as np
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg
from torch.utils.data import DataLoader

from .samplers import (DistributedGroupSampler, DistributedSampler,
                       GroupSampler, InfiniteBatchSampler,
                       InfiniteGroupBatchSampler)

if platform.system() != 'Windows':
    # https://github.com/pytorch/pytorch/issues/973
    import resource
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    base_soft_limit = rlimit[0]
    hard_limit = rlimit[1]
    soft_limit = min(max(4096, base_soft_limit), hard_limit)
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))

DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')


def _concat_dataset(cfg, default_args=None):
    from .dataset_wrappers import ConcatDataset
    ann_files = cfg['ann_file']
    img_prefixes = cfg.get('img_prefix', None)
    seg_prefixes = cfg.get('seg_prefix', None)
    proposal_files = cfg.get('proposal_file', None)
    separate_eval = cfg.get('separate_eval', True)

    datasets = []
    num_dset = len(ann_files)
    for i in range(num_dset):
        data_cfg = copy.deepcopy(cfg)
        # pop 'separate_eval' since it is not a valid key for common datasets.
        if 'separate_eval' in data_cfg:
            data_cfg.pop('separate_eval')
        data_cfg['ann_file'] = ann_files[i]
        if isinstance(img_prefixes, (list, tuple)):
            data_cfg['img_prefix'] = img_prefixes[i]
        if isinstance(seg_prefixes, (list, tuple)):
            data_cfg['seg_prefix'] = seg_prefixes[i]
        if isinstance(proposal_files, (list, tuple)):
            data_cfg['proposal_file'] = proposal_files[i]
        datasets.append(build_dataset(data_cfg, default_args))

    return ConcatDataset(datasets, separate_eval)


def build_dataset(cfg, default_args=None):
    from .dataset_wrappers import (ConcatDataset, RepeatDataset,
                                   ClassBalancedDataset, MultiImageMixDataset,
                                   AD_ClassBalancedDataset)
    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
    elif cfg['type'] == 'ConcatDataset':
        dataset = ConcatDataset(
            [build_dataset(c, default_args) for c in cfg['datasets']],
            cfg.get('separate_eval', True))
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    elif cfg['type'] == 'ClassBalancedDataset':
        dataset = ClassBalancedDataset(
            build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
    elif cfg['type'] == 'AD_ClassBalancedDataset':
        dataset = AD_ClassBalancedDataset(
            build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
    elif cfg['type'] == 'MultiImageMixDataset':
        cp_cfg = copy.deepcopy(cfg)
        cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'])
        cp_cfg.pop('type')
        dataset = MultiImageMixDataset(**cp_cfg)
    elif isinstance(cfg.get('ann_file'), (list, tuple)):
        dataset = _concat_dataset(cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

    return dataset


def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     seed=None,
                     runner_type='EpochBasedRunner',
                     **kwargs):
    """Build PyTorch DataLoader.

    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.

    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: True.
        shuffle (bool): Whether to shuffle the data at every epoch.
            Default: True.
        seed (int, optional): Seed used by the samplers and dataloader
            workers. Default: None.
        runner_type (str): Type of runner. Default: `EpochBasedRunner`.
        kwargs: Any keyword argument to be used to initialize DataLoader.

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()
    print(dataset)
    if dist:
        # When model is :obj:`DistributedDataParallel`,
        # `batch_size` of :obj:`dataloader` is the
        # number of training samples on each GPU.
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        # When model is :obj:`DataParallel`,
        # the batch size is the number of samples on all the GPUs.
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    if runner_type == 'IterBasedRunner':
        # this is a batch sampler, which can yield
        # a mini-batch of indices each time.
        # it can be used in both `DataParallel` and
        # `DistributedDataParallel`
        if shuffle:
            batch_sampler = InfiniteGroupBatchSampler(
                dataset, batch_size, world_size, rank, seed=seed)
        else:
            batch_sampler = InfiniteBatchSampler(
                dataset,
                batch_size,
                world_size,
                rank,
                seed=seed,
                shuffle=False)
        batch_size = 1
        sampler = None
    else:
        if dist:
            # DistributedGroupSampler will definitely shuffle the data to
            # satisfy that images on each GPU are in the same group
            if shuffle:
                sampler = DistributedGroupSampler(
                    dataset, samples_per_gpu, world_size, rank, seed=seed)
            else:
                sampler = DistributedSampler(
                    dataset, world_size, rank, shuffle=False, seed=seed)
        else:
            sampler = GroupSampler(dataset,
                                   samples_per_gpu) if shuffle else None
        batch_sampler = None

    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed) if seed is not None else None

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        pin_memory=False,
        worker_init_fn=init_fn,
        **kwargs)

    return data_loader


def worker_init_fn(worker_id, num_workers, rank, seed):
    # The seed of each worker equals
    # num_workers * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
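
For reference, a minimal usage sketch of the two builders above, assuming the package exposes them from mmdet.datasets as upstream MMDetection does. The dataset type, file paths and pipeline below are placeholders, not values taken from this repository; substitute whatever is registered in DATASETS in your project.

# Usage sketch (not part of builder.py); all config values are placeholders.
from mmdet.datasets import build_dataset, build_dataloader  # assumed export path

train_cfg = dict(
    type='CocoDataset',                                  # any dataset registered in DATASETS
    ann_file='data/coco/annotations/instances_train2017.json',
    img_prefix='data/coco/train2017/',
    pipeline=[])                                         # replace with a real training pipeline

# Wrapper configs are dispatched by build_dataset, e.g. class-balanced resampling:
balanced_cfg = dict(
    type='ClassBalancedDataset', oversample_thr=1e-3, dataset=train_cfg)

dataset = build_dataset(balanced_cfg)
data_loader = build_dataloader(
    dataset,
    samples_per_gpu=2,      # per-GPU batch size
    workers_per_gpu=2,      # dataloader workers per GPU
    num_gpus=1,
    dist=False,             # single-process, non-distributed
    shuffle=True,           # uses GroupSampler in the non-distributed branch
    seed=42)

for data_batch in data_loader:
    pass  # feed data_batch to the model

With dist=False and shuffle=True this goes through the GroupSampler branch, so the dataset must provide the aspect-ratio group flags that the standard MMDetection datasets define.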
