# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
from torch.utils.data.sampler import BatchSampler, Sampler


class GroupedBatchSampler(BatchSampler):
- """
- Wraps another sampler to yield a mini-batch of indices.
- It enforces that the batch only contain elements from the same group.
- It also tries to provide mini-batches which follows an ordering which is
- as close as possible to the ordering from the original sampler.
- """

    def __init__(self, sampler, group_ids, batch_size):
        """
        Args:
            sampler (Sampler): Base sampler.
            group_ids (list[int]): If the sampler produces indices in range [0, N),
                `group_ids` must be a list of `N` ints which contains the group id of each sample.
                The group ids must be a set of integers in the range [0, num_groups).
            batch_size (int): Size of mini-batch.
        """
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        self.sampler = sampler
        self.group_ids = np.asarray(group_ids)
        assert self.group_ids.ndim == 1
        self.batch_size = batch_size
        groups = np.unique(self.group_ids).tolist()

        # buffer the indices of each group until batch size is reached
        self.buffer_per_group = {k: [] for k in groups}

    def __iter__(self):
        for idx in self.sampler:
            group_id = self.group_ids[idx]
            group_buffer = self.buffer_per_group[group_id]
            group_buffer.append(idx)
            if len(group_buffer) == self.batch_size:
                yield group_buffer[:]  # yield a copy of the list
                del group_buffer[:]
        # note: indices left over in partially filled buffers are dropped at
        # the end of the epoch, which is why __len__ below is not well-defined

    def __len__(self):
        raise NotImplementedError("len() of GroupedBatchSampler is not well-defined.")
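

# A minimal usage sketch (illustrative, not part of the original module). The
# group ids and batch size below are made-up values; a common real use is
# grouping images by aspect ratio (e.g. group 0 = landscape, group 1 = portrait)
# so that every batch can be collated with minimal padding.
if __name__ == "__main__":
    from torch.utils.data.sampler import SequentialSampler

    group_ids = [0, 1, 0, 0, 1, 1, 0, 1, 0, 1]  # one group id per sample
    sampler = SequentialSampler(range(len(group_ids)))
    batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size=2)

    # Each yielded batch holds indices from a single group, in an order close
    # to the base sampler's: [0, 2], [1, 4], [3, 6], [5, 7]. Indices 8 and 9
    # are dropped because their group buffers never fill.
    for batch in batch_sampler:
        print(batch)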