"""Unit tests for the samplers and tensor helper in ``fastNLP.core.sampler``."""
import torch

from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler


def test_convert_to_torch_tensor():
    """A nested list of ints converts to a 2-D tensor holding the same values."""
    data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]]
    # NOTE(review): the second argument is presumably a "use CUDA" flag;
    # False keeps the test CPU-only — confirm against the helper's signature.
    ans = convert_to_torch_tensor(data, False)
    assert isinstance(ans, torch.Tensor)
    assert tuple(ans.shape) == (3, 5)
    # Shape alone would accept a tensor of garbage; check the contents too.
    assert ans.tolist() == data


def test_sequential_sampler():
    """SequentialSampler yields every index 0..len(data)-1, in order."""
    sampler = SequentialSampler()
    data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10]
    indices = list(sampler(data))
    # Compare the whole sequence rather than item-by-item inside a loop:
    # a per-item check passes vacuously if the sampler yields nothing.
    assert indices == list(range(len(data)))


def test_random_sampler():
    """RandomSampler yields a permutation of the dataset's indices."""
    sampler = RandomSampler()
    data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10]
    indices = list(sampler(data))
    # Every index must appear exactly once; a plain membership check would
    # accept duplicated or missing indices.
    assert sorted(indices) == list(range(len(data)))
    ans = [data[i] for i in indices]
    assert sorted(ans) == sorted(data)


if __name__ == "__main__":
    test_convert_to_torch_tensor()
    test_sequential_sampler()
    test_random_sampler()