import unittest

import _pickle
from fastNLP import cache_results
from fastNLP.io.embed_loader import EmbedLoader
from fastNLP import DataSet
from fastNLP import Instance

import time
import os

# Fixture files shared by every cache test in this module.
EMBED_FILE = 'test/data_for_tests/word2vec_test.txt'
CWS_TRAIN_FILE = 'test/data_for_tests/cws_train'


def _remove_if_exists(path):
    """Delete the file at ``path``, doing nothing if it is already absent.

    Used for test cleanup so that a test which failed before its cache file
    was written reports its real failure rather than a ``FileNotFoundError``
    raised from the cleanup step.
    """
    try:
        os.remove(path)
    except FileNotFoundError:
        # Nothing to clean up; the interesting error (if any) is elsewhere.
        pass


@cache_results('test/demo1.pkl')
def process_data_1(embed_file, cws_train):
    """Load an embedding and build a DataSet with one Instance per non-blank line.

    :param embed_file: path handed to ``EmbedLoader.load_without_vocab``.
    :param cws_train: path of a UTF-8 text file; every non-empty stripped line
        becomes ``Instance(raw=line)``.
    :return: tuple ``(embed, vocab, dataset)``.
    """
    embed, vocab = EmbedLoader.load_without_vocab(embed_file)
    # Artificial delay so that a result served from the cache is measurably
    # faster than one that was recomputed.
    time.sleep(1)
    with open(cws_train, 'r', encoding='utf-8') as f:
        d = DataSet()
        for line in f:
            line = line.strip()
            if line:
                d.append(Instance(raw=line))
    return embed, vocab, d


class TestCache(unittest.TestCase):
    """Tests for the ``cache_results`` decorator."""

    def _time_two_calls(self, cache_path, **cache_kwargs):
        """Call ``process_data_1`` twice and compare the result with the cache file.

        The cache file is removed before the first call (so a stale file from
        an interrupted run cannot turn the first call into a cache hit) and
        again afterwards, whether or not the checks passed.

        :param cache_path: file the pickled result is expected to be written to.
        :param cache_kwargs: extra keyword arguments forwarded to both calls.
        :return: ``(pre_time, read_time)``, the wall-clock seconds taken by the
            first and the second call respectively.
        """
        _remove_if_exists(cache_path)
        try:
            start_time = time.time()
            embed, _, _ = process_data_1(EMBED_FILE, CWS_TRAIN_FILE, **cache_kwargs)
            pre_time = time.time() - start_time

            # The file was written by the call above, so unpickling it is safe.
            with open(cache_path, 'rb') as f:
                _embed, _vocab, _d = _pickle.load(f)
            self.assertEqual(embed.shape, _embed.shape)
            for i in range(embed.shape[0]):
                self.assertListEqual(embed[i].tolist(), _embed[i].tolist())

            start_time = time.time()
            _, _, _ = process_data_1(EMBED_FILE, CWS_TRAIN_FILE, **cache_kwargs)
            read_time = time.time() - start_time

            print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
            return pre_time, read_time
        finally:
            _remove_if_exists(cache_path)

    def test_cache_save(self):
        """The second call must be at least 0.5s faster than the first."""
        pre_time, read_time = self._time_two_calls('test/demo1.pkl')
        self.assertGreater(pre_time - 0.5, read_time)

    def test_cache_save_overwrite_path(self):
        """A per-call cache path replaces the one given to the decorator."""
        # NOTE(review): this test passes `cache_filepath=` while
        # test_duplicate_keyword exercises `_cache_fp`; confirm which keyword
        # the installed fastNLP version actually honours.
        pre_time, read_time = self._time_two_calls(
            'test/demo_overwrite.pkl', cache_filepath='test/demo_overwrite.pkl')
        self.assertGreater(pre_time - 0.5, read_time)

    def test_cache_refresh(self):
        """With refresh requested, both calls should take about the same time."""
        # NOTE(review): this test passes `refresh=` while test_duplicate_keyword
        # exercises `_refresh`; confirm against the installed fastNLP version.
        pre_time, read_time = self._time_two_calls('test/demo1.pkl', refresh=True)
        self.assertGreater(0.1, pre_time - read_time)

    def test_duplicate_keyword(self):
        """Parameters named like the decorator's own keywords must raise RuntimeError."""
        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_verbose(a, _verbose):
                pass

            func_verbose(0, 1)

        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_cache(a, _cache_fp):
                pass

            func_cache(1, 2)

        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_refresh(a, _refresh):
                pass

            func_refresh(1, 2)

    def test_create_cache_dir(self):
        """The cache file must be created even if its directory is missing."""
        @cache_results('test/demo1/demo.pkl')
        def cache():
            return 1, 2

        try:
            results = cache()
            print(results)
            # Explicit check; previously only the cleanup step noticed a
            # missing cache file.
            self.assertTrue(os.path.isfile('test/demo1/demo.pkl'))
        finally:
            _remove_if_exists('test/demo1/demo.pkl')
            if os.path.isdir('test/demo1'):
                os.rmdir('test/demo1')