import random
import unittest

import torch

from fastNLP import DataSet
from fastNLP import SequentialSampler, RandomSampler, BucketSampler
from fastNLP.core.sampler import k_means_1d, k_means_bucketing, simple_sort_bucketing


class TestSampler(unittest.TestCase):
    """Smoke tests for the fastNLP samplers and the bucketing helpers."""

    def test_sequential_sampler(self):
        """The i-th index produced by SequentialSampler must equal i."""
        sampler = SequentialSampler()
        data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10]
        for position, index in enumerate(sampler(data)):
            self.assertEqual(position, index)

    def test_random_sampler(self):
        """RandomSampler must return one valid index per element of the data."""
        sampler = RandomSampler()
        data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10]
        picked = [data[i] for i in sampler(data)]
        self.assertEqual(len(picked), len(data))
        for value in picked:
            self.assertIn(value, data)

    def test_k_means(self):
        """k_means_1d with k=2 returns 2 centroids and one assignment per point."""
        centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5)
        # Materialise the results as lists so len() works whatever container is returned.
        centroids, assign = list(centroids), list(assign)
        self.assertEqual(len(centroids), 2)
        self.assertEqual(len(assign), 10)

    def test_k_means_bucketing(self):
        """Passing two bucket entries to k_means_bucketing must yield two buckets."""
        res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None])
        self.assertEqual(len(res), 2)

    def test_simple_sort_bucketing(self):
        """simple_sort_bucketing must return one entry per input length."""
        result = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10])
        self.assertEqual(len(result), 10)

    def test_BucketSampler(self):
        """BucketSampler must return one index per sample of the data set."""
        sampler = BucketSampler(num_buckets=3, batch_size=16, seq_len_field_name="seq_len")
        # Draw a fresh length for every sample.  `[[0] * randint(...)] * 10` would
        # evaluate randint once and repeat the *same* list object ten times, so all
        # sequences would share one length and bucketing would never see variety.
        data_set = DataSet({
            "x": [[0] * random.randint(1, 10) for _ in range(10)],
            "y": [[5, 6]] * 10,
        })
        data_set.apply(lambda ins: len(ins["x"]), new_field_name="seq_len")
        indices = sampler(data_set)
        # Smoke test only: it is enough that sampling runs; quality is not verified.
        self.assertEqual(len(indices), 10)