You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_preprocess.py 2.8 kB

Lines 1–72
  1. import os
  2. import unittest
  3. from fastNLP.core.dataset import DataSet
  4. from fastNLP.core.preprocess import SeqLabelPreprocess
  # Toy fixture for SeqLabelPreprocess: ten samples, each a token list paired
  # with a same-length label list. The two sentences alternate, five times
  # each. The inner literal is re-evaluated on every outer pass, so no two
  # samples share list objects.
  data = [
      sample
      for _ in range(5)
      for sample in (
          [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
          [['Hello', 'world', '!'], ['a', 'n', '.']],
      )
  ]
  class TestCase1(unittest.TestCase):
      """Run SeqLabelPreprocess on train/dev data only and check its output.

      The call uses ``train_dev_split=0.4`` and expects exactly two DataSet
      objects back (presumably the train and dev splits).
      """

      # Directory where SeqLabelPreprocess writes its pickles.
      pickle_path = "./save"

      def setUp(self):
          """Start from an empty pickle directory and remove it after the test."""
          import shutil

          # Remove leftovers from an earlier (possibly aborted) run so stale
          # pickles cannot influence this one. run() presumably creates the
          # directory when it is missing (a fresh checkout starts without it)
          # -- TODO confirm.
          if os.path.exists(self.pickle_path):
              shutil.rmtree(self.pickle_path)
          # addCleanup runs even if an assertion fails and needs no shell, so
          # the directory is removed portably after every run.
          self.addCleanup(shutil.rmtree, self.pickle_path, ignore_errors=True)

      def test(self):
          result = SeqLabelPreprocess().run(
              train_dev_data=data,
              train_dev_split=0.4,
              pickle_path=self.pickle_path,
          )
          self.assertEqual(len(result), 2)
          for data_set in result:
              self.assertIsInstance(data_set, DataSet)
  class TestCase2(unittest.TestCase):
      """Run SeqLabelPreprocess with separate test data, without cross-validation.

      The call passes both ``test_data`` and ``train_dev_data`` and expects
      exactly three DataSet objects back (presumably train, dev and test).
      """

      # Directory where SeqLabelPreprocess writes its pickles.
      pickle_path = "./save"

      def setUp(self):
          """Start from an empty pickle directory and remove it after the test."""
          import shutil

          # Remove leftovers from an earlier (possibly aborted) run so stale
          # pickles cannot influence this one. run() presumably creates the
          # directory when it is missing -- TODO confirm.
          if os.path.exists(self.pickle_path):
              shutil.rmtree(self.pickle_path)
          # addCleanup runs even if an assertion fails and needs no shell, so
          # the directory is removed portably after every run.
          self.addCleanup(shutil.rmtree, self.pickle_path, ignore_errors=True)

      def test(self):
          result = SeqLabelPreprocess().run(
              test_data=data,
              train_dev_data=data,
              pickle_path=self.pickle_path,
              train_dev_split=0.4,
              cross_val=False,
          )
          self.assertEqual(len(result), 3)
          for data_set in result:
              self.assertIsInstance(data_set, DataSet)
  class TestCase3(unittest.TestCase):
      """Run SeqLabelPreprocess in cross-validation mode and check the folds.

      With ``cross_val=True`` and ``n_fold=num_folds``, the result is expected
      to be a pair of lists, each holding ``num_folds`` DataSet objects
      (presumably the per-fold train and dev sets).
      """

      # Directory where SeqLabelPreprocess writes its pickles.
      pickle_path = "./save"

      def setUp(self):
          """Start from an empty pickle directory and remove it after the test."""
          import shutil

          # Like the other cases, clear leftovers first: without this, pickles
          # from an earlier run could be picked up by this one. run()
          # presumably creates the directory when it is missing -- TODO confirm.
          if os.path.exists(self.pickle_path):
              shutil.rmtree(self.pickle_path)
          # addCleanup runs even if an assertion fails and needs no shell, so
          # the directory is removed portably after every run.
          self.addCleanup(shutil.rmtree, self.pickle_path, ignore_errors=True)

      def test(self):
          num_folds = 2
          result = SeqLabelPreprocess().run(
              test_data=None,
              train_dev_data=data,
              pickle_path=self.pickle_path,
              train_dev_split=0.4,
              cross_val=True,
              n_fold=num_folds,
          )
          self.assertEqual(len(result), 2)
          self.assertEqual(len(result[0]), num_folds)
          self.assertEqual(len(result[1]), num_folds)
          for data_set in result[0] + result[1]:
              self.assertIsInstance(data_set, DataSet)