You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_sampler.py 1.1 kB

12345678910111213141516171819202122232425262728293031
  1. import unittest
  2. import random
  3. from fastNLP.core.samplers import SequentialSampler, RandomSampler, BucketSampler
  4. from fastNLP.core.dataset import DataSet
  5. from array import array
  6. import torch
  7. from fastNLP.core.samplers.sampler import ReproduceBatchSampler
  8. from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
  9. from tests.helpers.datasets.torch_data import TorchNormalDataset
  10. class SamplerTest(unittest.TestCase):
  def test_sequentialsampler(self):
      """Check that a SequentialSampler yields values equal to their position.

      Each value produced while iterating the sampler is compared with its
      zero-based enumeration position.
      """
      dataset = DataSet({'x': [1, 2, 3, 4] * 10})
      sequential = SequentialSampler(dataset)
      for position, sampled in enumerate(sequential):
          self.assertEqual(position, sampled)
  def test_randomsampler(self):
      """Check that a RandomSampler yields as many indices as the dataset has rows.

      Every yielded index is used to fetch an item from the dataset; the
      number of fetched items must equal ``len(ds)``.
      """
      dataset = DataSet({'x': [1, 2, 3, 4] * 10})
      shuffled = RandomSampler(dataset)
      fetched = []
      for index in shuffled:
          # Indexing also verifies that each yielded index is valid.
          fetched.append(dataset[index])
      self.assertEqual(len(fetched), len(dataset))
  21. def test_bucketsampler(self):
  22. data_set = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10})
  23. sampler = BucketSampler(data_set, num_buckets=3, batch_size=16, seq_len_field_name="seq_len")