You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cache_results.py 9.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. import time
  2. import os
  3. import pytest
  4. from subprocess import Popen, PIPE
  5. from io import StringIO
  6. import sys
  7. from fastNLP.core.utils.cache_results import cache_results
  8. from tests.helpers.common.utils import check_time_elapse
  9. from fastNLP.core import synchronize_safe_rm
  10. def get_subprocess_results(cmd):
  11. pipe = Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
  12. output, err = pipe.communicate()
  13. if output:
  14. output = output.decode('utf8')
  15. else:
  16. output = ''
  17. if err:
  18. err = err.decode('utf8')
  19. else:
  20. err = ''
  21. res = output + err
  22. return res
  23. class Capturing(list):
  24. # 用来捕获当前环境中的stdout和stderr,会将其中stderr的输出拼接在stdout的输出后面
  25. def __enter__(self):
  26. self._stdout = sys.stdout
  27. self._stderr = sys.stderr
  28. sys.stdout = self._stringio = StringIO()
  29. sys.stderr = self._stringioerr = StringIO()
  30. return self
  31. def __exit__(self, *args):
  32. self.append(self._stringio.getvalue() + self._stringioerr.getvalue())
  33. del self._stringio, self._stringioerr # free up some memory
  34. sys.stdout = self._stdout
  35. sys.stderr = self._stderr
  36. class TestCacheResults:
  37. def test_cache_save(self):
  38. cache_fp = 'demo.pkl'
  39. try:
  40. @cache_results(cache_fp)
  41. def demo():
  42. time.sleep(1)
  43. return 1
  44. res = demo()
  45. with check_time_elapse(1, op='lt'):
  46. res = demo()
  47. finally:
  48. synchronize_safe_rm(cache_fp)
  49. def test_cache_save_refresh(self):
  50. cache_fp = 'demo.pkl'
  51. try:
  52. @cache_results(cache_fp, _refresh=True)
  53. def demo():
  54. time.sleep(1.5)
  55. return 1
  56. res = demo()
  57. with check_time_elapse(1, op='ge'):
  58. res = demo()
  59. finally:
  60. synchronize_safe_rm(cache_fp)
  61. def test_cache_no_func_change(self):
  62. cache_fp = os.path.abspath('demo.pkl')
  63. try:
  64. @cache_results(cache_fp)
  65. def demo():
  66. time.sleep(2)
  67. return 1
  68. with check_time_elapse(1, op='gt'):
  69. res = demo()
  70. @cache_results(cache_fp)
  71. def demo():
  72. time.sleep(2)
  73. return 1
  74. with check_time_elapse(1, op='lt'):
  75. res = demo()
  76. finally:
  77. synchronize_safe_rm('demo.pkl')
  78. def test_cache_func_change(self, capsys):
  79. cache_fp = 'demo.pkl'
  80. try:
  81. @cache_results(cache_fp)
  82. def demo():
  83. time.sleep(2)
  84. return 1
  85. with check_time_elapse(1, op='gt'):
  86. res = demo()
  87. @cache_results(cache_fp)
  88. def demo():
  89. time.sleep(1)
  90. return 1
  91. with check_time_elapse(1, op='lt'):
  92. with Capturing() as output:
  93. res = demo()
  94. assert 'is different from its last cache' in output[0]
  95. # 关闭check_hash应该不warning的
  96. with check_time_elapse(1, op='lt'):
  97. with Capturing() as output:
  98. res = demo(_check_hash=0)
  99. assert 'is different from its last cache' not in output[0]
  100. finally:
  101. synchronize_safe_rm('demo.pkl')
  102. def test_cache_check_hash(self):
  103. cache_fp = 'demo.pkl'
  104. try:
  105. @cache_results(cache_fp, _check_hash=False)
  106. def demo():
  107. time.sleep(2)
  108. return 1
  109. with check_time_elapse(1, op='gt'):
  110. res = demo()
  111. @cache_results(cache_fp, _check_hash=False)
  112. def demo():
  113. time.sleep(1)
  114. return 1
  115. # 默认不会check
  116. with check_time_elapse(1, op='lt'):
  117. with Capturing() as output:
  118. res = demo()
  119. assert 'is different from its last cache' not in output[0]
  120. # check也可以
  121. with check_time_elapse(1, op='lt'):
  122. with Capturing() as output:
  123. res = demo(_check_hash=True)
  124. assert 'is different from its last cache' in output[0]
  125. finally:
  126. synchronize_safe_rm('demo.pkl')
  127. # 外部 function 改变也会 导致改变
  128. def test_refer_fun_change(self):
  129. cache_fp = 'demo.pkl'
  130. test_type = 'func_refer_fun_change'
  131. try:
  132. with check_time_elapse(3, op='gt'):
  133. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  134. res = get_subprocess_results(cmd)
  135. # 引用的function没有变化
  136. with check_time_elapse(2, op='lt'):
  137. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  138. res = get_subprocess_results(cmd)
  139. assert 'Read cache from' in res
  140. assert 'is different from its last cache' not in res
  141. # 引用的function有变化
  142. with check_time_elapse(2, op='lt'):
  143. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
  144. res = get_subprocess_results(cmd)
  145. assert 'is different from its last cache' in res
  146. finally:
  147. synchronize_safe_rm(cache_fp)
  148. # 外部 method 改变也会 导致改变
  149. def test_refer_class_method_change(self):
  150. cache_fp = 'demo.pkl'
  151. test_type = 'refer_class_method_change'
  152. try:
  153. with check_time_elapse(3, op='gt'):
  154. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  155. res = get_subprocess_results(cmd)
  156. # 引用的class没有变化
  157. with check_time_elapse(2, op='lt'):
  158. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
  159. res = get_subprocess_results(cmd)
  160. assert 'Read cache from' in res
  161. assert 'is different from its last cache' not in res
  162. # 引用的class有变化
  163. with check_time_elapse(2, op='lt'):
  164. cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
  165. res = get_subprocess_results(cmd)
  166. assert 'is different from its last cache' in res
  167. finally:
  168. synchronize_safe_rm(cache_fp)
  169. def test_duplicate_keyword(self):
  170. with pytest.raises(RuntimeError):
  171. @cache_results(None)
  172. def func_verbose(a, _verbose):
  173. pass
  174. func_verbose(0, 1)
  175. with pytest.raises(RuntimeError):
  176. @cache_results(None)
  177. def func_cache(a, _cache_fp):
  178. pass
  179. func_cache(1, 2)
  180. with pytest.raises(RuntimeError):
  181. @cache_results(None)
  182. def func_refresh(a, _refresh):
  183. pass
  184. func_refresh(1, 2)
  185. with pytest.raises(RuntimeError):
  186. @cache_results(None)
  187. def func_refresh(a, _check_hash):
  188. pass
  189. func_refresh(1, 2)
  190. def test_create_cache_dir(self):
  191. @cache_results('demo/demo.pkl')
  192. def cache():
  193. return 1, 2
  194. try:
  195. results = cache()
  196. assert (1, 2) == results
  197. finally:
  198. synchronize_safe_rm('demo/')
  199. def test_result_none_error(self):
  200. @cache_results('demo.pkl')
  201. def cache():
  202. pass
  203. try:
  204. with pytest.raises(RuntimeError):
  205. results = cache()
  206. finally:
  207. synchronize_safe_rm('demo.pkl')
  208. if __name__ == '__main__':
  209. import argparse
  210. parser = argparse.ArgumentParser()
  211. parser.add_argument('--test_type', type=str, default='refer_class_method_change')
  212. parser.add_argument('--turn', type=int, default=1)
  213. parser.add_argument('--cache_fp', type=str, default='demo.pkl')
  214. args = parser.parse_args()
  215. test_type = args.test_type
  216. cache_fp = args.cache_fp
  217. turn = args.turn
  218. if test_type == 'func_refer_fun_change':
  219. if turn == 0:
  220. def demo():
  221. b = 1
  222. return b
  223. else:
  224. def demo():
  225. b = 2
  226. return b
  227. @cache_results(cache_fp)
  228. def demo_refer_other_func():
  229. time.sleep(3)
  230. b = demo()
  231. return b
  232. res = demo_refer_other_func()
  233. if test_type == 'refer_class_method_change':
  234. print(f"Turn:{turn}")
  235. if turn == 0:
  236. from helper_for_cache_results_1 import Demo
  237. else:
  238. from helper_for_cache_results_2 import Demo
  239. demo = Demo()
  240. # import pdb
  241. # pdb.set_trace()
  242. @cache_results(cache_fp)
  243. def demo_func():
  244. time.sleep(3)
  245. b = demo.demo()
  246. return b
  247. res = demo_func()