You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_batch.py 1.2 kB

123456789101112131415161718192021222324252627282930313233
import unittest

import numpy as np

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.dataset import construct_dataset
from fastNLP.core.sampler import SequentialSampler
  7. class TestCase1(unittest.TestCase):
  8. def test_simple(self):
  9. dataset = construct_dataset(
  10. [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)])
  11. dataset.set_target()
  12. batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
  13. cnt = 0
  14. for _, _ in batch:
  15. cnt += 1
  16. self.assertEqual(cnt, 10)
  17. def test_dataset_batching(self):
  18. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  19. ds.set_input(x=True)
  20. ds.set_target(y=True)
  21. iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
  22. for x, y in iter:
  23. self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray))
  24. self.assertEqual(len(x["x"]), 4)
  25. self.assertEqual(len(y["y"]), 4)
  26. self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4])
  27. self.assertListEqual(list(y["y"][-1]), [5, 6])