You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_utils.py 10 kB

6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
import _pickle
import contextlib
import os
import time
import unittest

import numpy as np
import torch
from torch import nn

from fastNLP import cache_results
from fastNLP import DataSet
from fastNLP import Instance
from fastNLP.core.utils import _move_model_to_device, _get_model_device
from fastNLP.core.utils import seq_len_to_mask
from fastNLP.io import EmbedLoader
  14. class Model(nn.Module):
  15. def __init__(self):
  16. super().__init__()
  17. self.param = nn.Parameter(torch.zeros(0))
  18. class TestMoveModelDeivce(unittest.TestCase):
  19. def test_case1(self):
  20. # 测试str
  21. model = Model()
  22. model = _move_model_to_device(model, 'cpu')
  23. assert model.param.device == torch.device('cpu')
  24. # 测试不存在的device报错
  25. with self.assertRaises(Exception):
  26. _move_model_to_device(model, 'cpuu')
  27. # 测试gpu
  28. if torch.cuda.is_available():
  29. model = _move_model_to_device(model, 'cuda')
  30. assert model.param.is_cuda
  31. model = _move_model_to_device(model, 'cuda:0')
  32. assert model.param.device == torch.device('cuda:0')
  33. with self.assertRaises(Exception):
  34. _move_model_to_device(model, 'cuda:1000')
  35. # 测试None
  36. model = _move_model_to_device(model, None)
  37. def test_case2(self):
  38. # 测试使用int初始化
  39. model = Model()
  40. if torch.cuda.is_available():
  41. model = _move_model_to_device(model, 0)
  42. assert model.param.device == torch.device('cuda:0')
  43. assert model.param.device == torch.device('cuda:0'), "The model should be in "
  44. with self.assertRaises(Exception):
  45. _move_model_to_device(model, 100)
  46. with self.assertRaises(Exception):
  47. _move_model_to_device(model, -1)
  48. def test_case3(self):
  49. # 测试None
  50. model = Model()
  51. device = _get_model_device(model)
  52. model = _move_model_to_device(model, None)
  53. assert device == _get_model_device(model), "The device should not change."
  54. if torch.cuda.is_available():
  55. model.cuda()
  56. device = _get_model_device(model)
  57. model = _move_model_to_device(model, None)
  58. assert device == _get_model_device(model), "The device should not change."
  59. model = nn.DataParallel(model, device_ids=[0])
  60. _move_model_to_device(model, None)
  61. with self.assertRaises(Exception):
  62. _move_model_to_device(model, 'cpu')
  63. def test_case4(self):
  64. # 测试传入list的内容
  65. model = Model()
  66. device = ['cpu']
  67. with self.assertRaises(Exception):
  68. _move_model_to_device(model, device)
  69. if torch.cuda.is_available():
  70. device = [0]
  71. _model = _move_model_to_device(model, device)
  72. assert isinstance(_model, nn.DataParallel)
  73. device = [torch.device('cuda:0'), torch.device('cuda:0')]
  74. with self.assertRaises(Exception):
  75. _model = _move_model_to_device(model, device)
  76. if torch.cuda.device_count() > 1:
  77. device = [0, 1]
  78. _model = _move_model_to_device(model, device)
  79. assert isinstance(_model, nn.DataParallel)
  80. device = ['cuda', 'cuda:1']
  81. with self.assertRaises(Exception):
  82. _move_model_to_device(model, device)
  83. def test_case5(self):
  84. if not torch.cuda.is_available():
  85. return
  86. # torch.device()
  87. device = torch.device('cpu')
  88. model = Model()
  89. _move_model_to_device(model, device)
  90. device = torch.device('cuda')
  91. model = _move_model_to_device(model, device)
  92. assert model.param.device == torch.device('cuda:0')
  93. with self.assertRaises(Exception):
  94. _move_model_to_device(model, torch.device('cuda:100'))
  95. @cache_results('test/demo1.pkl')
  96. def process_data_1(embed_file, cws_train):
  97. embed, vocab = EmbedLoader.load_without_vocab(embed_file)
  98. time.sleep(1) # 测试是否通过读取cache获得结果
  99. with open(cws_train, 'r', encoding='utf-8') as f:
  100. d = DataSet()
  101. for line in f:
  102. line = line.strip()
  103. if len(line) > 0:
  104. d.append(Instance(raw=line))
  105. return embed, vocab, d
  106. class TestCache(unittest.TestCase):
  107. def test_cache_save(self):
  108. try:
  109. start_time = time.time()
  110. embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train')
  111. end_time = time.time()
  112. pre_time = end_time - start_time
  113. with open('test/demo1.pkl', 'rb') as f:
  114. _embed, _vocab, _d = _pickle.load(f)
  115. self.assertEqual(embed.shape, _embed.shape)
  116. for i in range(embed.shape[0]):
  117. self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
  118. start_time = time.time()
  119. embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train')
  120. end_time = time.time()
  121. read_time = end_time - start_time
  122. print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
  123. self.assertGreater(pre_time - 0.5, read_time)
  124. finally:
  125. os.remove('test/demo1.pkl')
  126. def test_cache_save_overwrite_path(self):
  127. try:
  128. start_time = time.time()
  129. embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
  130. cache_filepath='test/demo_overwrite.pkl')
  131. end_time = time.time()
  132. pre_time = end_time - start_time
  133. with open('test/demo_overwrite.pkl', 'rb') as f:
  134. _embed, _vocab, _d = _pickle.load(f)
  135. self.assertEqual(embed.shape, _embed.shape)
  136. for i in range(embed.shape[0]):
  137. self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
  138. start_time = time.time()
  139. embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
  140. cache_filepath='test/demo_overwrite.pkl')
  141. end_time = time.time()
  142. read_time = end_time - start_time
  143. print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
  144. self.assertGreater(pre_time - 0.5, read_time)
  145. finally:
  146. os.remove('test/demo_overwrite.pkl')
  147. def test_cache_refresh(self):
  148. try:
  149. start_time = time.time()
  150. embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
  151. refresh=True)
  152. end_time = time.time()
  153. pre_time = end_time - start_time
  154. with open('test/demo1.pkl', 'rb') as f:
  155. _embed, _vocab, _d = _pickle.load(f)
  156. self.assertEqual(embed.shape, _embed.shape)
  157. for i in range(embed.shape[0]):
  158. self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
  159. start_time = time.time()
  160. embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
  161. refresh=True)
  162. end_time = time.time()
  163. read_time = end_time - start_time
  164. print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
  165. self.assertGreater(0.1, pre_time - read_time)
  166. finally:
  167. os.remove('test/demo1.pkl')
  168. def test_duplicate_keyword(self):
  169. with self.assertRaises(RuntimeError):
  170. @cache_results(None)
  171. def func_verbose(a, _verbose):
  172. pass
  173. func_verbose(0, 1)
  174. with self.assertRaises(RuntimeError):
  175. @cache_results(None)
  176. def func_cache(a, _cache_fp):
  177. pass
  178. func_cache(1, 2)
  179. with self.assertRaises(RuntimeError):
  180. @cache_results(None)
  181. def func_refresh(a, _refresh):
  182. pass
  183. func_refresh(1, 2)
  184. def test_create_cache_dir(self):
  185. @cache_results('test/demo1/demo.pkl')
  186. def cache():
  187. return 1, 2
  188. try:
  189. results = cache()
  190. print(results)
  191. finally:
  192. os.remove('test/demo1/demo.pkl')
  193. os.rmdir('test/demo1')
  194. class TestSeqLenToMask(unittest.TestCase):
  195. def evaluate_mask_seq_len(self, seq_len, mask):
  196. max_len = int(max(seq_len))
  197. for i in range(len(seq_len)):
  198. length = seq_len[i]
  199. mask_i = mask[i]
  200. for j in range(max_len):
  201. self.assertEqual(mask_i[j], j<length)
  202. def test_numpy_seq_len(self):
  203. # 测试能否转换numpy类型的seq_len
  204. # 1. 随机测试
  205. seq_len = np.random.randint(1, 10, size=(10, ))
  206. mask = seq_len_to_mask(seq_len)
  207. max_len = seq_len.max()
  208. self.assertEqual(max_len, mask.shape[1])
  209. self.evaluate_mask_seq_len(seq_len, mask)
  210. # 2. 异常检测
  211. seq_len = np.random.randint(10, size=(10, 1))
  212. with self.assertRaises(AssertionError):
  213. mask = seq_len_to_mask(seq_len)
  214. def test_pytorch_seq_len(self):
  215. # 1. 随机测试
  216. seq_len = torch.randint(1, 10, size=(10, ))
  217. max_len = seq_len.max()
  218. mask = seq_len_to_mask(seq_len)
  219. self.assertEqual(max_len, mask.shape[1])
  220. self.evaluate_mask_seq_len(seq_len.tolist(), mask)
  221. # 2. 异常检测
  222. seq_len = torch.randn(3, 4)
  223. with self.assertRaises(AssertionError):
  224. mask = seq_len_to_mask(seq_len)