
test_batch.py

import unittest

import torch

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet, create_dataset_from_lists
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance

raw_texts = ["i am a cat",
             "this is a test of new batch",
             "ha ha",
             "I am a good boy .",
             "This is the most beautiful girl ."
             ]
texts = [text.strip().split() for text in raw_texts]
labels = [0, 1, 0, 0, 1]

# prepare vocabulary: map every token to a unique integer id
vocab = {}
for text in texts:
    for tokens in text:
        if tokens not in vocab:
            vocab[tokens] = len(vocab)


class TestCase1(unittest.TestCase):
    def test(self):
        # build a DataSet of (text, label) instances
        data = DataSet()
        for text, label in zip(texts, labels):
            x = TextField(text, is_target=False)
            y = LabelField(label, is_target=True)
            ins = Instance(text=x, label=y)
            data.append(ins)

        # use vocabulary to index data
        data.index_field("text", vocab)

        # define naive sampler for batch class
        class SeqSampler:
            def __call__(self, dataset):
                return list(range(len(dataset)))

        # use batch to iterate dataset
        data_iterator = Batch(data, 2, SeqSampler(), False)
        for batch_x, batch_y in data_iterator:
            self.assertEqual(len(batch_x), 2)
            self.assertTrue(isinstance(batch_x, dict))
            self.assertTrue(isinstance(batch_x["text"], torch.LongTensor))
            self.assertTrue(isinstance(batch_y, dict))
            self.assertTrue(isinstance(batch_y["label"], torch.LongTensor))


class TestCase2(unittest.TestCase):
    def test(self):
        # build a DataSet with text fields only (no targets)
        data = DataSet()
        for text in texts:
            x = TextField(text, is_target=False)
            ins = Instance(text=x)
            data.append(ins)

        # the convenience constructor should produce the same kind of object
        data_set = create_dataset_from_lists(texts, vocab, has_target=False)
        self.assertTrue(type(data) == type(data_set))
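
The test cases can be run with any standard unittest runner. A minimal sketch, assuming the module is saved as test_batch.py and executed directly (the guard below is an illustrative addition, not necessarily part of the original file):

if __name__ == "__main__":
    # Run TestCase1 and TestCase2 with the standard unittest runner.
    unittest.main()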