
sampling.py

```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch

__all__ = ["subsample_labels"]


def subsample_labels(labels, num_samples, positive_fraction, bg_label):
    """
    Return `num_samples` random samples from `labels`, with a fraction of
    positives no larger than `positive_fraction`.

    Args:
        labels (Tensor): (N, ) label vector with values:
            * -1: ignore
            * bg_label: background ("negative") class
            * otherwise: one or more foreground ("positive") classes
        num_samples (int): The total number of labels with value >= 0 to return.
            Labels that are not sampled are not included in the returned indices;
            callers typically mark them as -1 (ignore).
        positive_fraction (float): The number of subsampled foreground labels
            is `min(num_positives, int(positive_fraction * num_samples))`. The number
            of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`.
            In other words, if there are not enough positives, the sample is filled with
            negatives. If there are also not enough negatives, then as many elements are
            sampled as possible.
        bg_label (int): label index of background ("negative") class.

    Returns:
        pos_idx, neg_idx (Tensor):
            1D indices. The total number of indices is `num_samples` if possible.
            The fraction of positive indices is `positive_fraction` if possible.
    """
    positive = torch.nonzero((labels != -1) & (labels != bg_label)).squeeze(1)
    negative = torch.nonzero(labels == bg_label).squeeze(1)

    num_pos = int(num_samples * positive_fraction)
    # protect against not enough positive examples
    num_pos = min(positive.numel(), num_pos)
    num_neg = num_samples - num_pos
    # protect against not enough negative examples
    num_neg = min(negative.numel(), num_neg)

    # randomly select positive and negative examples
    perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
    perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

    pos_idx = positive[perm1]
    neg_idx = negative[perm2]
    return pos_idx, neg_idx
```
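The quota arithmetic described for `positive_fraction` can be checked without any tensors. The helper below, `sample_counts`, is a hypothetical name introduced only for illustration; it mirrors the two `min(...)` steps in the function body and shows the three regimes the docstring describes.

```python
def sample_counts(num_positives, num_negatives, num_samples, positive_fraction):
    """Hypothetical helper: (num_pos, num_neg) under the subsample_labels quota rule."""
    # Positives are capped by the requested fraction and by availability.
    num_pos = min(num_positives, int(num_samples * positive_fraction))
    # The remaining budget goes to negatives, capped by availability.
    num_neg = min(num_negatives, num_samples - num_pos)
    return num_pos, num_neg


# Enough of both: int(256 * 0.5) = 128 positives, 128 negatives.
print(sample_counts(1000, 5000, 256, 0.5))  # (128, 128)
# Too few positives: the shortfall is filled with negatives.
print(sample_counts(10, 5000, 256, 0.5))    # (10, 246)
# Too few of both: everything available is taken.
print(sample_counts(10, 20, 256, 0.5))      # (10, 20)
```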

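For readers without PyTorch at hand, the following is a minimal list-based sketch of the same selection logic, assuming plain Python lists in place of tensors. `subsample_labels_py` is a hypothetical helper, not part of this module; `random.Random.sample` stands in for `torch.randperm(n)[:k]`, since both draw without replacement. The last lines show one way a caller might turn the returned indices back into a label vector with unsampled entries set to -1.

```python
import random


def subsample_labels_py(labels, num_samples, positive_fraction, bg_label, rng=random):
    """Hypothetical list-based analogue of subsample_labels (illustration only)."""
    # Foreground ("positive"): neither ignored (-1) nor background.
    positive = [i for i, y in enumerate(labels) if y != -1 and y != bg_label]
    # Background ("negative"): exactly bg_label.
    negative = [i for i, y in enumerate(labels) if y == bg_label]

    # Cap positives by the requested fraction and by availability.
    num_pos = min(len(positive), int(num_samples * positive_fraction))
    # Give the remaining budget to negatives, capped by availability.
    num_neg = min(len(negative), num_samples - num_pos)

    # rng.sample draws without replacement, like torch.randperm(n)[:k].
    pos_idx = rng.sample(positive, num_pos)
    neg_idx = rng.sample(negative, num_neg)
    return pos_idx, neg_idx


labels = [-1, 0, 0, 3, 0, 1, -1, 0, 2, 0]
pos_idx, neg_idx = subsample_labels_py(
    labels, num_samples=4, positive_fraction=0.25, bg_label=0, rng=random.Random(0)
)

# Hypothetical caller: keep sampled labels, mark everything else as -1 (ignore).
sampled = [-1] * len(labels)
for i in pos_idx + neg_idx:
    sampled[i] = labels[i]
print(pos_idx, neg_idx, sampled)
```

Passing an explicit `random.Random` instance keeps the draw reproducible in this sketch; the tensor version above draws from PyTorch's default generator instead.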