"""Unit tests for ``SeqLabelPreprocess.run``.

Each test case runs the preprocessor on a tiny in-memory tagged corpus and
checks the number and type of the objects it returns for one call mode:
train/dev split only, train/dev split plus test data, and cross-validation.
"""
import os
import shutil
import unittest

from fastNLP.core.dataset import DataSet
from fastNLP.core.preprocess import SeqLabelPreprocess

# Directory handed to ``SeqLabelPreprocess.run`` as ``pickle_path``.
PICKLE_PATH = "./save"

# The two distinct (tokens, tags) samples that make up the toy corpus.
_SAMPLES = (
    (['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']),
    (['Hello', 'world', '!'], ['a', 'n', '.']),
)

# Ten rows: the two samples alternating, repeated five times. Every row gets
# its own fresh lists so no row aliases another if the preprocessor happens
# to modify its input in place.
data = [
    [list(words), list(tags)]
    for _ in range(5)
    for words, tags in _SAMPLES
]


class _PickleDirTestCase(unittest.TestCase):
    """Fixture base class that isolates each test from ``PICKLE_PATH``.

    The directory is removed before and after every test. Doing the removal
    in ``tearDown`` (instead of on the last line of the test body) means it
    also happens when an assertion fails, so a failed test cannot leave
    files behind for the next one to pick up.
    """

    def setUp(self):
        # ignore_errors=True: the directory usually does not exist yet.
        shutil.rmtree(PICKLE_PATH, ignore_errors=True)

    def tearDown(self):
        shutil.rmtree(PICKLE_PATH, ignore_errors=True)
        print("pickle path deleted")


class TestCase1(_PickleDirTestCase):
    """Train/dev split only: expects two ``DataSet`` objects back."""

    def test(self):
        result = SeqLabelPreprocess().run(
            train_dev_data=data,
            train_dev_split=0.4,
            pickle_path=PICKLE_PATH,
        )
        self.assertEqual(len(result), 2)
        self.assertIsInstance(result[0], DataSet)
        self.assertIsInstance(result[1], DataSet)


class TestCase2(_PickleDirTestCase):
    """Train/dev split plus test data: expects three ``DataSet`` objects."""

    def test(self):
        result = SeqLabelPreprocess().run(
            test_data=data,
            train_dev_data=data,
            pickle_path=PICKLE_PATH,
            train_dev_split=0.4,
            cross_val=False,
        )
        self.assertEqual(len(result), 3)
        self.assertIsInstance(result[0], DataSet)
        self.assertIsInstance(result[1], DataSet)
        self.assertIsInstance(result[2], DataSet)


class TestCase3(_PickleDirTestCase):
    """Cross-validation: expects two sequences of ``n_fold`` ``DataSet``s."""

    def test(self):
        num_folds = 2
        result = SeqLabelPreprocess().run(
            test_data=None,
            train_dev_data=data,
            pickle_path=PICKLE_PATH,
            train_dev_split=0.4,
            cross_val=True,
            n_fold=num_folds,
        )
        self.assertEqual(len(result), 2)
        self.assertEqual(len(result[0]), num_folds)
        self.assertEqual(len(result[1]), num_folds)
        # Every fold in both sequences must be a DataSet.
        for data_set in result[0] + result[1]:
            self.assertIsInstance(data_set, DataSet)