You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cache_results.py 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. import os
  2. import pytest
  3. import subprocess
  4. from io import StringIO
  5. import sys
  6. sys.path.append(os.path.join(os.path.dirname(__file__), '../../..'))
  7. from fastNLP.core.utils.cache_results import cache_results
  8. from fastNLP.core import rank_zero_rm
  9. def get_subprocess_results(cmd):
  10. output = subprocess.check_output(cmd, shell=True)
  11. return output.decode('utf8')
  12. class Capturing(list):
  13. # 用来捕获当前环境中的stdout和stderr,会将其中stderr的输出拼接在stdout的输出后面
  14. def __enter__(self):
  15. self._stdout = sys.stdout
  16. self._stderr = sys.stderr
  17. sys.stdout = self._stringio = StringIO()
  18. sys.stderr = self._stringioerr = StringIO()
  19. return self
  20. def __exit__(self, *args):
  21. self.append(self._stringio.getvalue() + self._stringioerr.getvalue())
  22. del self._stringio, self._stringioerr # free up some memory
  23. sys.stdout = self._stdout
  24. sys.stderr = self._stderr
  25. class TestCacheResults:
  26. def test_cache_save(self):
  27. cache_fp = 'demo.pkl'
  28. try:
  29. @cache_results(cache_fp)
  30. def demo():
  31. print("¥")
  32. return 1
  33. res = demo()
  34. with Capturing() as output:
  35. res = demo()
  36. assert '¥' not in output[0]
  37. finally:
  38. rank_zero_rm(cache_fp)
  39. def test_cache_save_refresh(self):
  40. cache_fp = 'demo.pkl'
  41. try:
  42. @cache_results(cache_fp, _refresh=True)
  43. def demo():
  44. print("¥")
  45. return 1
  46. res = demo()
  47. with Capturing() as output:
  48. res = demo()
  49. assert '¥' in output[0]
  50. finally:
  51. rank_zero_rm(cache_fp)
  52. def test_cache_no_func_change(self):
  53. cache_fp = os.path.abspath('demo.pkl')
  54. try:
  55. @cache_results(cache_fp)
  56. def demo():
  57. print('¥')
  58. return 1
  59. with Capturing() as output:
  60. res = demo()
  61. assert '¥' in output[0]
  62. @cache_results(cache_fp)
  63. def demo():
  64. print('¥')
  65. return 1
  66. with Capturing() as output:
  67. res = demo()
  68. assert '¥' not in output[0]
  69. finally:
  70. rank_zero_rm('demo.pkl')
  71. def test_cache_func_change(self, capsys):
  72. cache_fp = 'demo.pkl'
  73. try:
  74. @cache_results(cache_fp)
  75. def demo():
  76. print('¥')
  77. return 1
  78. with Capturing() as output:
  79. res = demo()
  80. assert '¥' in output[0]
  81. @cache_results(cache_fp)
  82. def demo():
  83. print('¥¥')
  84. return 1
  85. with Capturing() as output:
  86. res = demo()
  87. assert 'different' in output[0]
  88. assert '¥' not in output[0]
  89. # 关闭check_hash应该不warning的
  90. with Capturing() as output:
  91. res = demo(_check_hash=0)
  92. assert 'different' not in output[0]
  93. assert '¥' not in output[0]
  94. finally:
  95. rank_zero_rm('demo.pkl')
  96. def test_cache_check_hash(self):
  97. cache_fp = 'demo.pkl'
  98. try:
  99. @cache_results(cache_fp, _check_hash=False)
  100. def demo():
  101. print('¥')
  102. return 1
  103. with Capturing() as output:
  104. res = demo(_check_hash=0)
  105. assert '¥' in output[0]
  106. @cache_results(cache_fp, _check_hash=False)
  107. def demo():
  108. print('¥¥')
  109. return 1
  110. # 默认不会check
  111. with Capturing() as output:
  112. res = demo()
  113. assert 'different' not in output[0]
  114. assert '¥' not in output[0]
  115. # check也可以
  116. with Capturing() as output:
  117. res = demo(_check_hash=True)
  118. assert 'different' in output[0]
  119. assert '¥' not in output[0]
  120. finally:
  121. rank_zero_rm('demo.pkl')
  122. # 外部 function 改变也会 导致改变
  123. def test_refer_fun_change(self):
  124. cache_fp = 'demo.pkl'
  125. test_type = 'func_refer_fun_change'
  126. try:
  127. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  128. res = get_subprocess_results(cmd)
  129. assert "¥" in res
  130. # 引用的function没有变化
  131. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  132. res = get_subprocess_results(cmd)
  133. assert "¥" not in res
  134. assert 'Read' in res
  135. assert 'different' not in res
  136. # 引用的function有变化
  137. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
  138. res = get_subprocess_results(cmd)
  139. assert "¥" not in res
  140. assert 'different' in res
  141. finally:
  142. rank_zero_rm(cache_fp)
  143. # 外部 method 改变也会 导致改变
  144. def test_refer_class_method_change(self):
  145. cache_fp = 'demo.pkl'
  146. test_type = 'refer_class_method_change'
  147. try:
  148. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  149. res = get_subprocess_results(cmd)
  150. assert "¥" in res
  151. # 引用的class没有变化
  152. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  153. res = get_subprocess_results(cmd)
  154. assert 'Read' in res
  155. assert 'different' not in res
  156. assert "¥" not in res
  157. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
  158. res = get_subprocess_results(cmd)
  159. assert 'different' in res
  160. assert "¥" not in res
  161. finally:
  162. rank_zero_rm(cache_fp)
  163. def test_duplicate_keyword(self):
  164. with pytest.raises(RuntimeError):
  165. @cache_results(None)
  166. def func_verbose(a, _verbose):
  167. pass
  168. func_verbose(0, 1)
  169. with pytest.raises(RuntimeError):
  170. @cache_results(None)
  171. def func_cache(a, _cache_fp):
  172. pass
  173. func_cache(1, 2)
  174. with pytest.raises(RuntimeError):
  175. @cache_results(None)
  176. def func_refresh(a, _refresh):
  177. pass
  178. func_refresh(1, 2)
  179. with pytest.raises(RuntimeError):
  180. @cache_results(None)
  181. def func_refresh(a, _check_hash):
  182. pass
  183. func_refresh(1, 2)
  184. def test_create_cache_dir(self):
  185. @cache_results('demo/demo.pkl')
  186. def cache():
  187. return 1, 2
  188. try:
  189. results = cache()
  190. assert (1, 2) == results
  191. finally:
  192. rank_zero_rm('demo/')
  193. def test_result_none_error(self):
  194. @cache_results('demo.pkl')
  195. def cache():
  196. pass
  197. try:
  198. with pytest.raises(RuntimeError):
  199. results = cache()
  200. finally:
  201. rank_zero_rm('demo.pkl')
  202. if __name__ == '__main__':
  203. import argparse
  204. parser = argparse.ArgumentParser()
  205. parser.add_argument('--test_type', type=str, default='refer_class_method_change')
  206. parser.add_argument('--turn', type=int, default=1)
  207. parser.add_argument('--cache_fp', type=str, default='demo.pkl')
  208. args = parser.parse_args()
  209. test_type = args.test_type
  210. cache_fp = args.cache_fp
  211. turn = args.turn
  212. if test_type == 'func_refer_fun_change':
  213. if turn == 0:
  214. def demo():
  215. b = 1
  216. return b
  217. else:
  218. def demo():
  219. b = 2
  220. return b
  221. @cache_results(cache_fp)
  222. def demo_refer_other_func():
  223. b = demo()
  224. print("¥")
  225. return b
  226. res = demo_refer_other_func()
  227. if test_type == 'refer_class_method_change':
  228. print(f"Turn:{turn}")
  229. if turn == 0:
  230. from helper_for_cache_results_1 import Demo
  231. else:
  232. from helper_for_cache_results_2 import Demo
  233. demo = Demo()
  234. # import pdb
  235. # pdb.set_trace()
  236. @cache_results(cache_fp)
  237. def demo_func():
  238. print("¥")
  239. b = demo.demo()
  240. return b
  241. res = demo_func()