
infinite_sampler.py 6.2 kB

import itertools

import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch.utils.data.sampler import Sampler


class InfiniteGroupBatchSampler(Sampler):
    """Similar to `BatchSampler` wrapping a `GroupSampler`. It is designed
    for iteration-based runners like `IterBasedRunner` and yields mini-batch
    indices each time; all indices in a batch belong to the same group.

    The implementation logic is referred to
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py

    Args:
        dataset (object): The dataset.
        batch_size (int): When the model is :obj:`DistributedDataParallel`,
            it is the number of training samples on each GPU.
            When the model is :obj:`DataParallel`, it is
            `num_gpus * samples_per_gpu`. Default: 1.
        world_size (int, optional): Number of processes participating in
            distributed training. Default: None.
        rank (int, optional): Rank of the current process. Default: None.
        seed (int): Random seed. Default: 0.
        shuffle (bool): Whether to shuffle the indices of a dummy `epoch`.
            Note that `shuffle` cannot guarantee sequential indices, because
            the sampler must ensure that all indices in a batch belong to
            the same group. Default: True.
    """  # noqa: W605

    def __init__(self,
                 dataset,
                 batch_size=1,
                 world_size=None,
                 rank=None,
                 seed=0,
                 shuffle=True):
        _rank, _world_size = get_dist_info()
        if world_size is None:
            world_size = _world_size
        if rank is None:
            rank = _rank
        self.rank = rank
        self.world_size = world_size
        self.dataset = dataset
        self.batch_size = batch_size
        self.seed = seed if seed is not None else 0
        self.shuffle = shuffle

        assert hasattr(self.dataset, 'flag')
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)
        # buffer used to save indices of each group
        self.buffer_per_group = {k: [] for k in range(len(self.group_sizes))}

        self.size = len(dataset)
        self.indices = self._indices_of_rank()

    def _infinite_indices(self):
        """Infinitely yield a sequence of indices."""
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(self.size, generator=g).tolist()
            else:
                yield from torch.arange(self.size).tolist()

    def _indices_of_rank(self):
        """Slice the infinite indices by rank."""
        yield from itertools.islice(self._infinite_indices(), self.rank, None,
                                    self.world_size)

    def __iter__(self):
        # once a group buffer reaches the batch size, yield its indices
        for idx in self.indices:
            flag = self.flag[idx]
            group_buffer = self.buffer_per_group[flag]
            group_buffer.append(idx)
            if len(group_buffer) == self.batch_size:
                yield group_buffer[:]
                del group_buffer[:]

    def __len__(self):
        """Length of the base dataset."""
        return self.size

    def set_epoch(self, epoch):
        """Not supported in `IterationBased` runners."""
        raise NotImplementedError


class InfiniteBatchSampler(Sampler):
    """Similar to `BatchSampler` wrapping a `DistributedSampler`. It is
    designed for iteration-based runners like `IterBasedRunner` and yields
    mini-batch indices each time.

    The implementation logic is referred to
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py

    Args:
        dataset (object): The dataset.
        batch_size (int): When the model is :obj:`DistributedDataParallel`,
            it is the number of training samples on each GPU.
            When the model is :obj:`DataParallel`, it is
            `num_gpus * samples_per_gpu`. Default: 1.
        world_size (int, optional): Number of processes participating in
            distributed training. Default: None.
        rank (int, optional): Rank of the current process. Default: None.
        seed (int): Random seed. Default: 0.
        shuffle (bool): Whether to shuffle the dataset or not. Default: True.
    """  # noqa: W605

    def __init__(self,
                 dataset,
                 batch_size=1,
                 world_size=None,
                 rank=None,
                 seed=0,
                 shuffle=True):
        _rank, _world_size = get_dist_info()
        if world_size is None:
            world_size = _world_size
        if rank is None:
            rank = _rank
        self.rank = rank
        self.world_size = world_size
        self.dataset = dataset
        self.batch_size = batch_size
        self.seed = seed if seed is not None else 0
        self.shuffle = shuffle
        self.size = len(dataset)
        self.indices = self._indices_of_rank()

    def _infinite_indices(self):
        """Infinitely yield a sequence of indices."""
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(self.size, generator=g).tolist()
            else:
                yield from torch.arange(self.size).tolist()

    def _indices_of_rank(self):
        """Slice the infinite indices by rank."""
        yield from itertools.islice(self._infinite_indices(), self.rank, None,
                                    self.world_size)

    def __iter__(self):
        # once the buffer reaches the batch size, yield the indices
        batch_buffer = []
        for idx in self.indices:
            batch_buffer.append(idx)
            if len(batch_buffer) == self.batch_size:
                yield batch_buffer
                batch_buffer = []

    def __len__(self):
        """Length of the base dataset."""
        return self.size

    def set_epoch(self, epoch):
        """Not supported in `IterationBased` runners."""
        raise NotImplementedError
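A minimal usage sketch, not part of the file above: these samplers are meant to be passed to a PyTorch `DataLoader` through its `batch_sampler` argument. The dataset class `MyAspectRatioDataset`, the chosen batch size, and the import path `infinite_sampler` (the file on PYTHONPATH) are assumptions made for illustration; the only hard requirement from the source is that the dataset used with `InfiniteGroupBatchSampler` exposes a `flag` array giving each sample's group.

    # Hedged usage sketch: `MyAspectRatioDataset` is a hypothetical stand-in,
    # not part of infinite_sampler.py.
    import numpy as np
    import torch
    from torch.utils.data import DataLoader, Dataset

    from infinite_sampler import InfiniteGroupBatchSampler  # assumed import path


    class MyAspectRatioDataset(Dataset):
        """Toy dataset exposing the `flag` attribute the sampler asserts on.

        In mmdetection, `flag[i]` marks the aspect-ratio group of image i;
        here it is simply randomized.
        """

        def __init__(self, num_samples=100):
            self.flag = np.random.randint(0, 2, size=num_samples).astype(np.int64)

        def __len__(self):
            return len(self.flag)

        def __getitem__(self, idx):
            return torch.randn(3, 32, 32), int(self.flag[idx])


    dataset = MyAspectRatioDataset()
    batch_sampler = InfiniteGroupBatchSampler(
        dataset, batch_size=4, world_size=1, rank=0, seed=0, shuffle=True)

    # The sampler is infinite, so iterate for a fixed number of steps
    # instead of looping over epochs.
    loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0)
    for step, (images, flags) in zip(range(5), loader):
        # every index in a batch comes from the same group, so all flags match
        assert len(set(flags.tolist())) == 1

`InfiniteBatchSampler` is used the same way, but it does not require the `flag` attribute and does not constrain a batch to a single group.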
