
group_sampler.py
# Copyright (c) OpenMMLab. All rights reserved.
import math

import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch.utils.data import Sampler


class GroupSampler(Sampler):
    # Yields indices such that every consecutive chunk of `samples_per_gpu`
    # indices is drawn from a single group, where groups are defined by the
    # integer labels in `dataset.flag`.

    def __init__(self, dataset, samples_per_gpu=1):
        assert hasattr(dataset, 'flag')
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.flag = dataset.flag.astype(np.int64)
        self.group_sizes = np.bincount(self.flag)
        self.num_samples = 0
        for i, size in enumerate(self.group_sizes):
            self.num_samples += int(np.ceil(
                size / self.samples_per_gpu)) * self.samples_per_gpu

    def __iter__(self):
        indices = []
        for i, size in enumerate(self.group_sizes):
            if size == 0:
                continue
            indice = np.where(self.flag == i)[0]
            assert len(indice) == size
            np.random.shuffle(indice)
            # Pad each group to a multiple of samples_per_gpu by re-sampling
            # indices from the same group.
            num_extra = int(np.ceil(size / self.samples_per_gpu)
                            ) * self.samples_per_gpu - len(indice)
            indice = np.concatenate(
                [indice, np.random.choice(indice, num_extra)])
            indices.append(indice)
        indices = np.concatenate(indices)
        # Shuffle at the granularity of whole batches so that each batch
        # stays within one group.
        indices = [
            indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu]
            for i in np.random.permutation(
                range(len(indices) // self.samples_per_gpu))
        ]
        indices = np.concatenate(indices)
        indices = indices.astype(np.int64).tolist()
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples


class DistributedGroupSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        seed (int, optional): random seed used to shuffle the sampler if
            ``shuffle=True``. This number should be identical across all
            processes in the distributed group. Default: 0.
    """

    def __init__(self,
                 dataset,
                 samples_per_gpu=1,
                 num_replicas=None,
                 rank=None,
                 seed=0):
        _rank, _num_replicas = get_dist_info()
        if num_replicas is None:
            num_replicas = _num_replicas
        if rank is None:
            rank = _rank
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.seed = seed if seed is not None else 0

        assert hasattr(self.dataset, 'flag')
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)

        self.num_samples = 0
        for i, j in enumerate(self.group_sizes):
            self.num_samples += int(
                math.ceil(self.group_sizes[i] * 1.0 / self.samples_per_gpu /
                          self.num_replicas)) * self.samples_per_gpu
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch + self.seed)

        indices = []
        for i, size in enumerate(self.group_sizes):
            if size > 0:
                indice = np.where(self.flag == i)[0]
                assert len(indice) == size
                # add .numpy() to avoid bug when selecting indice in parrots.
                # TODO: check whether torch.randperm() can be replaced by
                # numpy.random.permutation().
                indice = indice[list(
                    torch.randperm(int(size), generator=g).numpy())].tolist()
                extra = int(
                    math.ceil(
                        size * 1.0 / self.samples_per_gpu / self.num_replicas)
                ) * self.samples_per_gpu * self.num_replicas - len(indice)
                # pad indice
                tmp = indice.copy()
                for _ in range(extra // size):
                    indice.extend(tmp)
                indice.extend(tmp[:extra % size])
                indices.extend(indice)

        assert len(indices) == self.total_size

        indices = [
            indices[j] for i in list(
                torch.randperm(
                    len(indices) // self.samples_per_gpu, generator=g))
            for j in range(i * self.samples_per_gpu, (i + 1) *
                           self.samples_per_gpu)
        ]

        # subsample
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
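
Usage note: the sketch below shows one way these samplers could be wired into a plain PyTorch DataLoader. The toy dataset, its group encoding, and the hard-coded batch size and rank/replica values are assumptions for illustration only and are not part of group_sampler.py; in MMDetection the samplers are typically constructed by the data-loading utilities rather than by hand.

# --- Minimal usage sketch (illustrative, not part of the module) ---------
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Placeholder dataset exposing the `flag` array the samplers require."""

    def __init__(self, num_samples=10):
        # Two groups, 0 and 1; in MMDetection `flag` usually encodes the
        # image aspect-ratio group, but any integer grouping works here.
        self.flag = np.array([i % 2 for i in range(num_samples)],
                             dtype=np.uint8)

    def __len__(self):
        return len(self.flag)

    def __getitem__(self, idx):
        return torch.tensor(idx)


dataset = ToyDataset()
samples_per_gpu = 2

# Non-distributed: every batch of `samples_per_gpu` samples is drawn from a
# single group.
loader = DataLoader(
    dataset,
    batch_size=samples_per_gpu,
    sampler=GroupSampler(dataset, samples_per_gpu))
for batch in loader:
    print(batch)

# Distributed: each rank receives a disjoint, equally sized slice of the
# padded index list; num_replicas and rank are hard-coded for illustration.
dist_sampler = DistributedGroupSampler(
    dataset, samples_per_gpu, num_replicas=2, rank=0)
dist_loader = DataLoader(
    dataset, batch_size=samples_per_gpu, sampler=dist_sampler)
for epoch in range(2):
    # Re-seed the deterministic shuffle so all ranks agree each epoch.
    dist_sampler.set_epoch(epoch)
    for batch in dist_loader:
        print(epoch, batch)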
