You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_processor.py 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import random
  2. import unittest
  3. from fastNLP import Vocabulary
  4. from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor, PreAppendProcessor, SliceProcessor, Num2TagProcessor, \
  5. IndexerProcessor, VocabProcessor, SeqLenProcessor
  6. from fastNLP.core.dataset import DataSet
  7. class TestProcessor(unittest.TestCase):
  8. def test_FullSpaceToHalfSpaceProcessor(self):
  9. ds = DataSet({"word": ["00, u1, u), (u2, u2"]})
  10. proc = FullSpaceToHalfSpaceProcessor("word")
  11. ds = proc(ds)
  12. self.assertEqual(ds.field_arrays["word"].content, ["00, u1, u), (u2, u2"])
  13. def test_PreAppendProcessor(self):
  14. ds = DataSet({"word": [["1234", "3456"], ["8789", "3464"]]})
  15. proc = PreAppendProcessor(data="abc", field_name="word")
  16. ds = proc(ds)
  17. self.assertEqual(ds.field_arrays["word"].content, [["abc", "1234", "3456"], ["abc", "8789", "3464"]])
  18. def test_SliceProcessor(self):
  19. ds = DataSet({"xx": [[random.randint(0, 10) for _ in range(30)]] * 40})
  20. proc = SliceProcessor(10, 20, 2, "xx", new_added_field_name="yy")
  21. ds = proc(ds)
  22. self.assertEqual(len(ds.field_arrays["yy"].content[0]), 5)
  23. def test_Num2TagProcessor(self):
  24. ds = DataSet({"num": [["99.9982", "2134.0"], ["0.002", "234"]]})
  25. proc = Num2TagProcessor("<num>", "num")
  26. ds = proc(ds)
  27. for data in ds.field_arrays["num"].content:
  28. for d in data:
  29. self.assertEqual(d, "<num>")
  30. def test_VocabProcessor_and_IndexerProcessor(self):
  31. ds = DataSet({"xx": [[str(random.randint(0, 10)) for _ in range(30)]] * 40})
  32. vocab_proc = VocabProcessor("xx")
  33. vocab_proc(ds)
  34. vocab = vocab_proc.vocab
  35. self.assertTrue(isinstance(vocab, Vocabulary))
  36. self.assertTrue(len(vocab) > 5)
  37. proc = IndexerProcessor(vocab, "xx", "yy")
  38. ds = proc(ds)
  39. for data in ds.field_arrays["yy"].content[0]:
  40. self.assertTrue(isinstance(data, int))
  41. def test_SeqLenProcessor(self):
  42. ds = DataSet({"xx": [[str(random.randint(0, 10)) for _ in range(30)]] * 10})
  43. proc = SeqLenProcessor("xx", "len")
  44. ds = proc(ds)
  45. for data in ds.field_arrays["len"].content:
  46. self.assertEqual(data, 30)