|
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- import torch
-
- __all__ = ["subsample_labels"]
-
-
- def subsample_labels(labels, num_samples, positive_fraction, bg_label):
- """
- Return `num_samples` random samples from `labels`, with a fraction of
- positives no larger than `positive_fraction`.
-
- Args:
- labels (Tensor): (N, ) label vector with values:
- * -1: ignore
- * bg_label: background ("negative") class
- * otherwise: one or more foreground ("positive") classes
- num_samples (int): The total number of labels with value >= 0 to return.
- Values that are not sampled will be filled with -1 (ignore).
- positive_fraction (float): The number of subsampled labels with values > 0
- is `min(num_positives, int(positive_fraction * num_samples))`. The number
- of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`.
- In order words, if there are not enough positives, the sample is filled with
- negatives. If there are also not enough negatives, then as many elements are
- sampled as is possible.
- bg_label (int): label index of background ("negative") class.
-
- Returns:
- pos_idx, neg_idx (Tensor):
- 1D indices. The total number of indices is `num_samples` if possible.
- The fraction of positive indices is `positive_fraction` if possible.
- """
- positive = torch.nonzero((labels != -1) & (labels != bg_label)).squeeze(1)
- negative = torch.nonzero(labels == bg_label).squeeze(1)
-
- num_pos = int(num_samples * positive_fraction)
- # protect against not enough positive examples
- num_pos = min(positive.numel(), num_pos)
- num_neg = num_samples - num_pos
- # protect against not enough negative examples
- num_neg = min(negative.numel(), num_neg)
-
- # randomly select positive and negative examples
- perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
- perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
-
- pos_idx = positive[perm1]
- neg_idx = negative[perm2]
- return pos_idx, neg_idx
|