
distributed_sampler.py

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
from torch.utils.data import DistributedSampler as _DistributedSampler


class DistributedSampler(_DistributedSampler):

    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 seed=0):
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        # for compatibility with PyTorch 1.3+, which added a `seed` attribute
        self.seed = seed if seed is not None else 0

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible,
        # in case that indices is shorter than half of total_size
        indices = (indices *
                   math.ceil(self.total_size / len(indices)))[:self.total_size]
        assert len(indices) == self.total_size

        # subsample: each rank takes every num_replicas-th index, offset by its rank
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
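
Below is a minimal usage sketch (not part of the file above) showing how this sampler is typically wired into a DataLoader. The toy dataset and the explicit num_replicas/rank values are assumptions made for illustration; in real distributed training they are usually left as None so the sampler reads them from the initialized process group.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # hypothetical tiny dataset, just to make the sketch self-contained
    toy_dataset = TensorDataset(torch.arange(10))

    # pretend we are rank 0 of 2 replicas so no process group is needed here
    sampler = DistributedSampler(
        toy_dataset, num_replicas=2, rank=0, shuffle=True, seed=42)
    loader = DataLoader(toy_dataset, batch_size=2, sampler=sampler)

    for epoch in range(2):
        # set_epoch() feeds into `self.epoch + self.seed`, so each epoch
        # gets a different but reproducible shuffle on every rank
        sampler.set_epoch(epoch)
        for batch in loader:
            pass  # training step goes here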
