
test_logger.py

import os
import tempfile
import datetime
from pathlib import Path
import logging
import re

from fastNLP.envs.env import FASTNLP_LAUNCH_TIME
from fastNLP.core import synchronize_safe_rm
from fastNLP.core.log.logger import logger

from tests.helpers.utils import magic_argv_env_context, recover_logger
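
# Notes on the shared test helpers (inferred from their names and from how the tests
# below use them, not from their definitions, which live elsewhere in the repo):
# - synchronize_safe_rm presumably removes a file or directory in a way that is safe
#   when several distributed ranks call it at the same time;
# - magic_argv_env_context and recover_logger presumably set up the launch argv/env
#   for the distributed tests and restore the global logger state after each test.
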
# Tests for TorchDDPDriver;
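# The four DDP tests below each build a TorchDDPDriver over two devices ("cuda:0" and
# "cuda:1"), so they only make sense on a machine with at least two visible GPUs;
# output_from_new_proc="all" presumably keeps the output of every spawned rank instead
# of silencing the non-zero ranks.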
@magic_argv_env_context
@recover_logger
def test_add_file_ddp_1_torch():
    """
    Test the case where path points to a file and the directory containing that file already exists;

    Building the file name from the current time is badly broken on multiple GPUs: the processes
    start at slightly different moments, so each of them ends up writing to its own log file;
    """
    import torch
    import torch.distributed as dist
    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    path = Path.cwd()
    filepath = path.joinpath('log.txt')
    handler = logger.add_file(filepath, mode="w")
    logger.info(msg)
    logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")

    for h in logger.handlers:
        if isinstance(h, logging.FileHandler):
            h.flush()
    dist.barrier()

    with open(filepath, 'r') as f:
        line = ''.join([l for l in f])

    assert msg in line
    assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line
    pattern = re.compile(msg)
    assert len(pattern.findall(line)) == 1

    synchronize_safe_rm(filepath)
    dist.barrier()
    dist.destroy_process_group()
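

# All four DDP tests follow the same pattern: every rank logs, every FileHandler attached
# to the logger is flushed, and a dist.barrier() makes sure no rank reads (or removes) the
# log file before all ranks have finished writing. The `len(pattern.findall(line)) == 1`
# assertion then checks that the test message appears exactly once in the file each rank
# reads, and the process group is torn down at the end of the test.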
@magic_argv_env_context
@recover_logger
def test_add_file_ddp_2_torch():
    """
    Test the case where path points to a file but the directory containing that file does not exist;
    """
    import torch
    import torch.distributed as dist
    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    origin_path = Path.cwd()
    try:
        path = origin_path.joinpath("not_existed")
        filepath = path.joinpath('log.txt')
        handler = logger.add_file(filepath)
        logger.info(msg)
        logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")
        for h in logger.handlers:
            if isinstance(h, logging.FileHandler):
                h.flush()
        dist.barrier()

        with open(filepath, 'r') as f:
            line = ''.join([l for l in f])

        assert msg in line
        assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line
        pattern = re.compile(msg)
        assert len(pattern.findall(line)) == 1
    finally:
        synchronize_safe_rm(path)
        dist.barrier()
        dist.destroy_process_group()
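

# The next two tests exercise the auto-generated file name: when no explicit file path is
# given, the log file name is built from the FASTNLP_LAUNCH_TIME environment variable,
# which appears to hold a single launch timestamp shared by every rank, so that all ranks
# agree on one file name instead of each creating its own time-stamped file.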
@magic_argv_env_context
@recover_logger
def test_add_file_ddp_3_torch():
    """
    path = None;

    Building the file name from the current time is badly broken on multiple GPUs: the processes
    start at slightly different moments, so each of them ends up writing to its own log file;
    """
    import torch
    import torch.distributed as dist
    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    handler = logger.add_file()
    logger.info(msg)
    logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")
    for h in logger.handlers:
        if isinstance(h, logging.FileHandler):
            h.flush()
    dist.barrier()

    file = Path.cwd().joinpath(os.environ.get(FASTNLP_LAUNCH_TIME) + ".log")
    with open(file, 'r') as f:
        line = ''.join([l for l in f])
    # print(f"\nrank: {driver.get_local_rank()} line, {line}\n")

    assert msg in line
    assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line
    pattern = re.compile(msg)
    assert len(pattern.findall(line)) == 1

    synchronize_safe_rm(file)
    dist.barrier()
    dist.destroy_process_group()


@magic_argv_env_context
@recover_logger
def test_add_file_ddp_4_torch():
    """
    Test the case where path is a directory;
    """
    import torch
    import torch.distributed as dist
    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    path = Path.cwd().joinpath("not_existed")
    try:
        handler = logger.add_file(path)
        logger.info(msg)
        logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")
        for h in logger.handlers:
            if isinstance(h, logging.FileHandler):
                h.flush()
        dist.barrier()

        file = path.joinpath(os.environ.get(FASTNLP_LAUNCH_TIME) + ".log")
        with open(file, 'r') as f:
            line = ''.join([l for l in f])

        assert msg in line
        assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line
        pattern = re.compile(msg)
        assert len(pattern.findall(line)) == 1
    finally:
        synchronize_safe_rm(path)
        dist.barrier()
        dist.destroy_process_group()
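

# Taken together, the tests in this file suggest that logger.add_file accepts:
#   - a file path whose parent directory already exists,
#   - a file path whose parent directory does not exist yet (it is created),
#   - a directory (a "<launch time>.log" file is created inside it), or
#   - nothing at all (a log file named after the launch date/time is created in the
#     current working directory).
# This is inferred from the assertions in this file rather than from the logger's docs.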
class TestLogger:
    msg = 'some test log msg'

    @recover_logger
    def test_add_file_1(self):
        """
        Test the case where path points to a file and the directory containing that file already exists;
        """
        path = Path(tempfile.mkdtemp())
        try:
            filepath = path.joinpath('log.txt')
            handler = logger.add_file(filepath)
            logger.info(self.msg)
            with open(filepath, 'r') as f:
                line = ''.join([l for l in f])
            assert self.msg in line
        finally:
            synchronize_safe_rm(path)

    @recover_logger
    def test_add_file_2(self):
        """
        Test the case where path points to a file but the directory containing that file does not exist;
        """
        origin_path = Path(tempfile.mkdtemp())
        try:
            path = origin_path.joinpath("not_existed")
            path = path.joinpath('log.txt')
            handler = logger.add_file(path)
            logger.info(self.msg)
            with open(path, 'r') as f:
                line = ''.join([l for l in f])
            assert self.msg in line
        finally:
            synchronize_safe_rm(origin_path)

    @recover_logger
    def test_add_file_3(self):
        """
        Test the case where path is None;
        """
        handler = logger.add_file()
        logger.info(self.msg)

        path = Path.cwd()
        cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d'))
        for file in path.iterdir():
            if file.name.startswith(cur_datetime):
                with open(file, 'r') as f:
                    line = ''.join([l for l in f])
                assert self.msg in line
                file.unlink()

    @recover_logger
    def test_add_file_4(self):
        """
        Test the case where path is a directory;
        """
        path = Path(tempfile.mkdtemp())
        try:
            handler = logger.add_file(path)
            logger.info(self.msg)

            cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d'))
            for file in path.iterdir():
                if file.name.startswith(cur_datetime):
                    with open(file, 'r') as f:
                        line = ''.join([l for l in f])
                    assert self.msg in line
        finally:
            synchronize_safe_rm(path)
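
    # The remaining tests cover the console side of the logger. set_stdout(stdout="raw")
    # presumably installs a plain stdout handler with no formatting prefix, which is why
    # the captured output below is expected to equal the bare message plus a newline; the
    # debug() call is expected to be filtered out by the logger's default level.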
    @recover_logger
    def test_stdout(self, capsys):
        handler = logger.set_stdout(stdout="raw")
        logger.info(self.msg)
        logger.debug('aabbc')
        captured = capsys.readouterr()
        assert "some test log msg\n" == captured.out
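
    # warning_once appears to deduplicate on the message text: each distinct message is
    # emitted once and repeats of an already-seen message are dropped.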
    @recover_logger
    def test_warning_once(self, capsys):
        logger.warning_once('#')
        logger.warning_once('#')
        logger.warning_once('@')
        captured = capsys.readouterr()
        assert captured.out.count('#') == 1
        assert captured.out.count('@') == 1