You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.py 4.9 kB

6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import unittest
  2. import _pickle
  3. from fastNLP import cache_results
  4. from fastNLP.io.embed_loader import EmbedLoader
  5. from fastNLP import DataSet
  6. from fastNLP import Instance
  7. import time
  8. import os
  9. @cache_results('test/demo1.pkl')
  10. def process_data_1(embed_file, cws_train):
  11. embed, vocab = EmbedLoader.load_without_vocab(embed_file)
  12. time.sleep(1) # 测试是否通过读取cache获得结果
  13. with open(cws_train, 'r', encoding='utf-8') as f:
  14. d = DataSet()
  15. for line in f:
  16. line = line.strip()
  17. if len(line)>0:
  18. d.append(Instance(raw=line))
  19. return embed, vocab, d
  20. class TestCache(unittest.TestCase):
  21. def test_cache_save(self):
  22. try:
  23. start_time = time.time()
  24. embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train')
  25. end_time = time.time()
  26. pre_time = end_time - start_time
  27. with open('test/demo1.pkl', 'rb') as f:
  28. _embed, _vocab, _d = _pickle.load(f)
  29. self.assertEqual(embed.shape, _embed.shape)
  30. for i in range(embed.shape[0]):
  31. self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
  32. start_time = time.time()
  33. embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train')
  34. end_time = time.time()
  35. read_time = end_time - start_time
  36. print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
  37. self.assertGreater(pre_time-0.5, read_time)
  38. finally:
  39. os.remove('test/demo1.pkl')
  40. def test_cache_save_overwrite_path(self):
  41. try:
  42. start_time = time.time()
  43. embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
  44. cache_filepath='test/demo_overwrite.pkl')
  45. end_time = time.time()
  46. pre_time = end_time - start_time
  47. with open('test/demo_overwrite.pkl', 'rb') as f:
  48. _embed, _vocab, _d = _pickle.load(f)
  49. self.assertEqual(embed.shape, _embed.shape)
  50. for i in range(embed.shape[0]):
  51. self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
  52. start_time = time.time()
  53. embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
  54. cache_filepath='test/demo_overwrite.pkl')
  55. end_time = time.time()
  56. read_time = end_time - start_time
  57. print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
  58. self.assertGreater(pre_time-0.5, read_time)
  59. finally:
  60. os.remove('test/demo_overwrite.pkl')
  61. def test_cache_refresh(self):
  62. try:
  63. start_time = time.time()
  64. embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
  65. refresh=True)
  66. end_time = time.time()
  67. pre_time = end_time - start_time
  68. with open('test/demo1.pkl', 'rb') as f:
  69. _embed, _vocab, _d = _pickle.load(f)
  70. self.assertEqual(embed.shape, _embed.shape)
  71. for i in range(embed.shape[0]):
  72. self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
  73. start_time = time.time()
  74. embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
  75. refresh=True)
  76. end_time = time.time()
  77. read_time = end_time - start_time
  78. print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
  79. self.assertGreater(0.1, pre_time-read_time)
  80. finally:
  81. os.remove('test/demo1.pkl')
  82. def test_duplicate_keyword(self):
  83. with self.assertRaises(RuntimeError):
  84. @cache_results(None)
  85. def func_verbose(a, _verbose):
  86. pass
  87. func_verbose(0, 1)
  88. with self.assertRaises(RuntimeError):
  89. @cache_results(None)
  90. def func_cache(a, _cache_fp):
  91. pass
  92. func_cache(1, 2)
  93. with self.assertRaises(RuntimeError):
  94. @cache_results(None)
  95. def func_refresh(a, _refresh):
  96. pass
  97. func_refresh(1, 2)
  98. def test_create_cache_dir(self):
  99. @cache_results('test/demo1/demo.pkl')
  100. def cache():
  101. return 1, 2
  102. try:
  103. results = cache()
  104. print(results)
  105. finally:
  106. os.remove('test/demo1/demo.pkl')
  107. os.rmdir('test/demo1')