You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_field.py 2.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import os
  2. import sys
  3. sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
  4. import unittest
  5. import torch
  6. from fastNLP.data.field import TextField, LabelField
  7. from fastNLP.data.instance import Instance
  8. from fastNLP.data.dataset import DataSet
  9. from fastNLP.data.batch import Batch
  10. class TestField(unittest.TestCase):
  11. def check_batched_data_equal(self, data1, data2):
  12. self.assertEqual(len(data1), len(data2))
  13. for i in range(len(data1)):
  14. self.assertTrue(data1[i].keys(), data2[i].keys())
  15. for i in range(len(data1)):
  16. for t1, t2 in zip(data1[i].values(), data2[i].values()):
  17. self.assertTrue(torch.equal(t1, t2))
  18. def test_batchiter(self):
  19. texts = [
  20. "i am a cat",
  21. "this is a test of new batch",
  22. "haha"
  23. ]
  24. labels = [0, 1, 0]
  25. # prepare vocabulary
  26. vocab = {}
  27. for text in texts:
  28. for tokens in text.split():
  29. if tokens not in vocab:
  30. vocab[tokens] = len(vocab)
  31. # prepare input dataset
  32. data = DataSet()
  33. for text, label in zip(texts, labels):
  34. x = TextField(text.split(), False)
  35. y = LabelField(label, is_target=True)
  36. ins = Instance(text=x, label=y)
  37. data.append(ins)
  38. # use vocabulary to index data
  39. data.index_field("text", vocab)
  40. # define naive sampler for batch class
  41. class SeqSampler:
  42. def __call__(self, dataset):
  43. return list(range(len(dataset)))
  44. # use bacth to iterate dataset
  45. batcher = Batch(data, SeqSampler(), 2)
  46. TRUE_X = [{'text': torch.tensor([[0, 1, 2, 3, 0, 0, 0], [4, 5, 2, 6, 7, 8, 9]])}, {'text': torch.tensor([[10]])}]
  47. TRUE_Y = [{'label': torch.tensor([[0], [1]])}, {'label': torch.tensor([[0]])}]
  48. for epoch in range(3):
  49. test_x, test_y = [], []
  50. for batch_x, batch_y in batcher:
  51. test_x.append(batch_x)
  52. test_y.append(batch_y)
  53. self.check_batched_data_equal(TRUE_X, test_x)
  54. self.check_batched_data_equal(TRUE_Y, test_y)
  55. if __name__ == "__main__":
  56. unittest.main()