You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_field.py 1.8 kB

7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142
import unittest

from fastNLP.core.field import CharTextField, LabelField, SeqLabelField
  3. class TestField(unittest.TestCase):
  4. def test_char_field(self):
  5. text = "PhD applicants must submit a Research Plan and a resume " \
  6. "specify your class ranking written in English and a list of research" \
  7. " publications if any".split()
  8. max_word_len = max([len(w) for w in text])
  9. field = CharTextField(text, max_word_len, is_target=False)
  10. all_char = set()
  11. for word in text:
  12. all_char.update([ch for ch in word])
  13. char_vocab = {ch: idx + 1 for idx, ch in enumerate(all_char)}
  14. self.assertEqual(field.index(char_vocab),
  15. [[char_vocab[ch] for ch in word] + [0] * (max_word_len - len(word)) for word in text])
  16. self.assertEqual(field.get_length(), len(text))
  17. self.assertEqual(field.contents(), text)
  18. tensor = field.to_tensor(50)
  19. self.assertEqual(tuple(tensor.shape), (50, max_word_len))
  20. def test_label_field(self):
  21. label = LabelField("A", is_target=True)
  22. self.assertEqual(label.get_length(), 1)
  23. self.assertEqual(label.index({"A": 10}), 10)
  24. label = LabelField(30, is_target=True)
  25. self.assertEqual(label.get_length(), 1)
  26. tensor = label.to_tensor(0)
  27. self.assertEqual(tensor.shape, ())
  28. self.assertEqual(int(tensor), 30)
  29. def test_seq_label_field(self):
  30. seq = ["a", "b", "c", "d", "a", "c", "a", "b"]
  31. field = SeqLabelField(seq)
  32. vocab = {"a": 10, "b": 20, "c": 30, "d": 40}
  33. self.assertEqual(field.index(vocab), [vocab[x] for x in seq])
  34. tensor = field.to_tensor(10)
  35. self.assertEqual(tuple(tensor.shape), (10,))