You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_sampler.py 684 B

3 years ago
1234567891011121314151617181920212223
  1. # -*- coding: utf-8 -*-
  2. import unittest
  3. from torch.utils.data.sampler import SequentialSampler
  4. from detectron2.data.samplers import GroupedBatchSampler
  class TestGroupedBatchSampler(unittest.TestCase):
      """Unit tests for ``GroupedBatchSampler``.

      Each test drives the sampler with a ``SequentialSampler`` over the
      indices 0..99 and a batch size of 2, materializes every batch it
      yields, and checks properties of those batches.
      """

      def test_missing_group_id(self):
          """Every index has group id 1; each yielded batch must be full."""
          sampler = SequentialSampler(list(range(100)))
          group_ids = [1] * 100
          batches = list(GroupedBatchSampler(sampler, group_ids, 2))

          # Guard against a vacuous pass: the per-batch loop below checks
          # nothing at all if the sampler yields no batches.
          self.assertTrue(batches, "GroupedBatchSampler yielded no batches")
          for batch in batches:
              self.assertEqual(len(batch), 2)

      def test_groups(self):
          """Alternating group ids (1, 0, 1, 0, ...): batches must not mix groups."""
          sampler = SequentialSampler(list(range(100)))
          group_ids = [1, 0] * 50
          batches = list(GroupedBatchSampler(sampler, group_ids, 2))

          self.assertTrue(batches, "GroupedBatchSampler yielded no batches")
          for batch in batches:
              # group_ids[i] is 1 for even i and 0 for odd i, so two indices
              # belong to the same group exactly when their sum is even.
              self.assertEqual((batch[0] + batch[1]) % 2, 0)

          # Each group holds 50 indices, an exact multiple of the batch size,
          # so every index should be emitted exactly once.
          emitted = sorted(idx for batch in batches for idx in batch)
          self.assertEqual(emitted, list(range(100)))

No Description