You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset.py 2.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import unittest
  2. from fastNLP.io.dataset_loader import convert_seq2seq_dataset, convert_seq_dataset
  3. class TestDataSet(unittest.TestCase):
  4. labeled_data_list = [
  5. [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
  6. [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
  7. [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
  8. ]
  9. unlabeled_data_list = [
  10. ["a", "b", "e", "d"],
  11. ["a", "b", "e", "d"],
  12. ["a", "b", "e", "d"]
  13. ]
  14. word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
  15. label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}
  16. def test_case_1(self):
  17. data_set = convert_seq2seq_dataset(self.labeled_data_list)
  18. data_set.index_field("word_seq", self.word_vocab)
  19. data_set.index_field("label_seq", self.label_vocab)
  20. self.assertEqual(len(data_set), len(self.labeled_data_list))
  21. self.assertTrue(len(data_set) > 0)
  22. self.assertTrue(hasattr(data_set[0], "fields"))
  23. self.assertTrue("word_seq" in data_set[0].fields)
  24. self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
  25. self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
  26. self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
  27. self.assertEqual(data_set[0].fields["word_seq"]._index,
  28. [self.word_vocab[c] for c in self.labeled_data_list[0][0]])
  29. self.assertTrue("label_seq" in data_set[0].fields)
  30. self.assertTrue(hasattr(data_set[0].fields["label_seq"], "text"))
  31. self.assertTrue(hasattr(data_set[0].fields["label_seq"], "_index"))
  32. self.assertEqual(data_set[0].fields["label_seq"].text, self.labeled_data_list[0][1])
  33. self.assertEqual(data_set[0].fields["label_seq"]._index,
  34. [self.label_vocab[c] for c in self.labeled_data_list[0][1]])
  35. def test_case_2(self):
  36. data_set = convert_seq_dataset(self.unlabeled_data_list)
  37. data_set.index_field("word_seq", self.word_vocab)
  38. self.assertEqual(len(data_set), len(self.unlabeled_data_list))
  39. self.assertTrue(len(data_set) > 0)
  40. self.assertTrue(hasattr(data_set[0], "fields"))
  41. self.assertTrue("word_seq" in data_set[0].fields)
  42. self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
  43. self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
  44. self.assertEqual(data_set[0].fields["word_seq"].text, self.unlabeled_data_list[0])
  45. self.assertEqual(data_set[0].fields["word_seq"]._index,
  46. [self.word_vocab[c] for c in self.unlabeled_data_list[0]])