# -*- coding: utf-8 -*- import unittest from torch.utils.data.sampler import SequentialSampler from detectron2.data.samplers import GroupedBatchSampler class TestGroupedBatchSampler(unittest.TestCase): def test_missing_group_id(self): sampler = SequentialSampler(list(range(100))) group_ids = [1] * 100 s = GroupedBatchSampler(sampler, group_ids, 2) for k in s: self.assertEqual(len(k), 2) def test_groups(self): sampler = SequentialSampler(list(range(100))) group_ids = [1, 0] * 50 s = GroupedBatchSampler(sampler, group_ids, 2) for k in s: self.assertTrue((k[0] + k[1]) % 2 == 0)