You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hooks.py 15 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  3. import datetime
  4. import logging
  5. import os
  6. import tempfile
  7. import time
  8. from collections import Counter
  9. import torch
  10. from fvcore.common.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
  11. from fvcore.common.file_io import PathManager
  12. from fvcore.common.timer import Timer
  13. from fvcore.nn.precise_bn import get_bn_modules, update_bn_stats
  14. import detectron2.utils.comm as comm
  15. from detectron2.evaluation.testing import flatten_results_dict
  16. from detectron2.utils.events import EventStorage, EventWriter
  17. from .train_loop import HookBase
# Public names of this module: these are the hook classes exported by
# ``from <module> import *``. Each one is defined further down in this file.
__all__ = [
    "CallbackHook",
    "IterationTimer",
    "PeriodicWriter",
    "PeriodicCheckpointer",
    "LRScheduler",
    "AutogradProfiler",
    "EvalHook",
    "PreciseBN",
]

# NOTE: the string literal below comes after the imports and ``__all__``, so it
# is a plain expression statement and is not picked up as the module docstring.
"""
Implement some common hooks.
"""
  31. class CallbackHook(HookBase):
  32. """
  33. Create a hook using callback functions provided by the user.
  34. """
  35. def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
  36. """
  37. Each argument is a function that takes one argument: the trainer.
  38. """
  39. self._before_train = before_train
  40. self._before_step = before_step
  41. self._after_step = after_step
  42. self._after_train = after_train
  43. def before_train(self):
  44. if self._before_train:
  45. self._before_train(self.trainer)
  46. def after_train(self):
  47. if self._after_train:
  48. self._after_train(self.trainer)
  49. # The functions may be closures that hold reference to the trainer
  50. # Therefore, delete them to avoid circular reference.
  51. del self._before_train, self._after_train
  52. del self._before_step, self._after_step
  53. def before_step(self):
  54. if self._before_step:
  55. self._before_step(self.trainer)
  56. def after_step(self):
  57. if self._after_step:
  58. self._after_step(self.trainer)
  59. class IterationTimer(HookBase):
  60. """
  61. Track the time spent for each iteration (each run_step call in the trainer).
  62. Print a summary in the end of training.
  63. This hook uses the time between the call to its :meth:`before_step`
  64. and :meth:`after_step` methods.
  65. Under the convention that :meth:`before_step` of all hooks should only
  66. take negligible amount of time, the :class:`IterationTimer` hook should be
  67. placed at the beginning of the list of hooks to obtain accurate timing.
  68. """
  69. def __init__(self, warmup_iter=3):
  70. """
  71. Args:
  72. warmup_iter (int): the number of iterations at the beginning to exclude
  73. from timing.
  74. """
  75. self._warmup_iter = warmup_iter
  76. self._step_timer = Timer()
  77. def before_train(self):
  78. self._start_time = time.perf_counter()
  79. self._total_timer = Timer()
  80. self._total_timer.pause()
  81. def after_train(self):
  82. logger = logging.getLogger(__name__)
  83. total_time = time.perf_counter() - self._start_time
  84. total_time_minus_hooks = self._total_timer.seconds()
  85. hook_time = total_time - total_time_minus_hooks
  86. num_iter = self.trainer.iter + 1 - self.trainer.start_iter - self._warmup_iter
  87. if num_iter > 0 and total_time_minus_hooks > 0:
  88. # Speed is meaningful only after warmup
  89. # NOTE this format is parsed by grep in some scripts
  90. logger.info(
  91. "Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
  92. num_iter,
  93. str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
  94. total_time_minus_hooks / num_iter,
  95. )
  96. )
  97. logger.info(
  98. "Total training time: {} ({} on hooks)".format(
  99. str(datetime.timedelta(seconds=int(total_time))),
  100. str(datetime.timedelta(seconds=int(hook_time))),
  101. )
  102. )
  103. def before_step(self):
  104. self._step_timer.reset()
  105. self._total_timer.resume()
  106. def after_step(self):
  107. # +1 because we're in after_step
  108. iter_done = self.trainer.iter - self.trainer.start_iter + 1
  109. if iter_done >= self._warmup_iter:
  110. sec = self._step_timer.seconds()
  111. self.trainer.storage.put_scalars(time=sec)
  112. else:
  113. self._start_time = time.perf_counter()
  114. self._total_timer.reset()
  115. self._total_timer.pause()
  116. class PeriodicWriter(HookBase):
  117. """
  118. Write events to EventStorage periodically.
  119. It is executed every ``period`` iterations and after the last iteration.
  120. """
  121. def __init__(self, writers, period=20):
  122. """
  123. Args:
  124. writers (list[EventWriter]): a list of EventWriter objects
  125. period (int):
  126. """
  127. self._writers = writers
  128. for w in writers:
  129. assert isinstance(w, EventWriter), w
  130. self._period = period
  131. def after_step(self):
  132. if (self.trainer.iter + 1) % self._period == 0 or (
  133. self.trainer.iter == self.trainer.max_iter - 1
  134. ):
  135. for writer in self._writers:
  136. writer.write()
  137. def after_train(self):
  138. for writer in self._writers:
  139. writer.close()
  140. class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
  141. """
  142. Same as :class:`detectron2.checkpoint.PeriodicCheckpointer`, but as a hook.
  143. Note that when used as a hook,
  144. it is unable to save additional data other than what's defined
  145. by the given `checkpointer`.
  146. It is executed every ``period`` iterations and after the last iteration.
  147. """
  148. def before_train(self):
  149. self.max_iter = self.trainer.max_iter
  150. def after_step(self):
  151. # No way to use **kwargs
  152. self.step(self.trainer.iter)
  153. class LRScheduler(HookBase):
  154. """
  155. A hook which executes a torch builtin LR scheduler and summarizes the LR.
  156. It is executed after every iteration.
  157. """
  158. def __init__(self, optimizer, scheduler):
  159. """
  160. Args:
  161. optimizer (torch.optim.Optimizer):
  162. scheduler (torch.optim._LRScheduler)
  163. """
  164. self._optimizer = optimizer
  165. self._scheduler = scheduler
  166. # NOTE: some heuristics on what LR to summarize
  167. # summarize the param group with most parameters
  168. largest_group = max(len(g["params"]) for g in optimizer.param_groups)
  169. if largest_group == 1:
  170. # If all groups have one parameter,
  171. # then find the most common initial LR, and use it for summary
  172. lr_count = Counter([g["lr"] for g in optimizer.param_groups])
  173. lr = lr_count.most_common()[0][0]
  174. for i, g in enumerate(optimizer.param_groups):
  175. if g["lr"] == lr:
  176. self._best_param_group_id = i
  177. break
  178. else:
  179. for i, g in enumerate(optimizer.param_groups):
  180. if len(g["params"]) == largest_group:
  181. self._best_param_group_id = i
  182. break
  183. def after_step(self):
  184. lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
  185. self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
  186. self._scheduler.step()
  187. class AutogradProfiler(HookBase):
  188. """
  189. A hook which runs `torch.autograd.profiler.profile`.
  190. Examples:
  191. .. code-block:: python
  192. hooks.AutogradProfiler(
  193. lambda trainer: trainer.iter > 10 and trainer.iter < 20, self.cfg.OUTPUT_DIR
  194. )
  195. The above example will run the profiler for iteration 10~20 and dump
  196. results to ``OUTPUT_DIR``. We did not profile the first few iterations
  197. because they are typically slower than the rest.
  198. The result files can be loaded in the ``chrome://tracing`` page in chrome browser.
  199. Note:
  200. When used together with NCCL on older version of GPUs,
  201. autograd profiler may cause deadlock because it unnecessarily allocates
  202. memory on every device it sees. The memory management calls, if
  203. interleaved with NCCL calls, lead to deadlock on GPUs that do not
  204. support `cudaLaunchCooperativeKernelMultiDevice`.
  205. """
  206. def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
  207. """
  208. Args:
  209. enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
  210. and returns whether to enable the profiler.
  211. It will be called once every step, and can be used to select which steps to profile.
  212. output_dir (str): the output directory to dump tracing files.
  213. use_cuda (bool): same as in `torch.autograd.profiler.profile`.
  214. """
  215. self._enable_predicate = enable_predicate
  216. self._use_cuda = use_cuda
  217. self._output_dir = output_dir
  218. def before_step(self):
  219. if self._enable_predicate(self.trainer):
  220. self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
  221. self._profiler.__enter__()
  222. else:
  223. self._profiler = None
  224. def after_step(self):
  225. if self._profiler is None:
  226. return
  227. self._profiler.__exit__(None, None, None)
  228. out_file = os.path.join(
  229. self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter)
  230. )
  231. if "://" not in out_file:
  232. self._profiler.export_chrome_trace(out_file)
  233. else:
  234. # Support non-posix filesystems
  235. with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d:
  236. tmp_file = os.path.join(d, "tmp.json")
  237. self._profiler.export_chrome_trace(tmp_file)
  238. with open(tmp_file) as f:
  239. content = f.read()
  240. with PathManager.open(out_file, "w") as f:
  241. f.write(content)
  242. class EvalHook(HookBase):
  243. """
  244. Run an evaluation function periodically, and at the end of training.
  245. It is executed every ``eval_period`` iterations and after the last iteration.
  246. """
  247. def __init__(self, eval_period, eval_function):
  248. """
  249. Args:
  250. eval_period (int): the period to run `eval_function`.
  251. eval_function (callable): a function which takes no arguments, and
  252. returns a nested dict of evaluation metrics.
  253. Note:
  254. This hook must be enabled in all or none workers.
  255. If you would like only certain workers to perform evaluation,
  256. give other workers a no-op function (`eval_function=lambda: None`).
  257. """
  258. self._period = eval_period
  259. self._func = eval_function
  260. def after_step(self):
  261. next_iter = self.trainer.iter + 1
  262. is_final = next_iter == self.trainer.max_iter
  263. if is_final or (self._period > 0 and next_iter % self._period == 0):
  264. results = self._func()
  265. if results:
  266. assert isinstance(
  267. results, dict
  268. ), "Eval function must return a dict. Got {} instead.".format(results)
  269. flattened_results = flatten_results_dict(results)
  270. for k, v in flattened_results.items():
  271. try:
  272. v = float(v)
  273. except Exception:
  274. raise ValueError(
  275. "[EvalHook] eval_function should return a nested dict of float. "
  276. "Got '{}: {}' instead.".format(k, v)
  277. )
  278. self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
  279. # Evaluation may take different time among workers.
  280. # A barrier make them start the next iteration together.
  281. comm.synchronize()
  282. def after_train(self):
  283. # func is likely a closure that holds reference to the trainer
  284. # therefore we clean it to avoid circular reference in the end
  285. del self._func
  286. class PreciseBN(HookBase):
  287. """
  288. The standard implementation of BatchNorm uses EMA in inference, which is
  289. sometimes suboptimal.
  290. This class computes the true average of statistics rather than the moving average,
  291. and put true averages to every BN layer in the given model.
  292. It is executed every ``period`` iterations and after the last iteration.
  293. """
  294. def __init__(self, period, model, data_loader, num_iter):
  295. """
  296. Args:
  297. period (int): the period this hook is run, or 0 to not run during training.
  298. The hook will always run in the end of training.
  299. model (nn.Module): a module whose all BN layers in training mode will be
  300. updated by precise BN.
  301. Note that user is responsible for ensuring the BN layers to be
  302. updated are in training mode when this hook is triggered.
  303. data_loader (iterable): it will produce data to be run by `model(data)`.
  304. num_iter (int): number of iterations used to compute the precise
  305. statistics.
  306. """
  307. self._logger = logging.getLogger(__name__)
  308. if len(get_bn_modules(model)) == 0:
  309. self._logger.info(
  310. "PreciseBN is disabled because model does not contain BN layers in training mode."
  311. )
  312. self._disabled = True
  313. return
  314. self._model = model
  315. self._data_loader = data_loader
  316. self._num_iter = num_iter
  317. self._period = period
  318. self._disabled = False
  319. self._data_iter = None
  320. def after_step(self):
  321. next_iter = self.trainer.iter + 1
  322. is_final = next_iter == self.trainer.max_iter
  323. if is_final or (self._period > 0 and next_iter % self._period == 0):
  324. self.update_stats()
  325. def update_stats(self):
  326. """
  327. Update the model with precise statistics. Users can manually call this method.
  328. """
  329. if self._disabled:
  330. return
  331. if self._data_iter is None:
  332. self._data_iter = iter(self._data_loader)
  333. num_iter = 0
  334. def data_loader():
  335. nonlocal num_iter
  336. while True:
  337. num_iter += 1
  338. if num_iter % 100 == 0:
  339. self._logger.info(
  340. "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
  341. )
  342. # This way we can reuse the same iterator
  343. yield next(self._data_iter)
  344. with EventStorage(): # capture events in a new storage to discard them
  345. self._logger.info(
  346. "Running precise-BN for {} iterations... ".format(self._num_iter)
  347. + "Note that this could produce different statistics every time."
  348. )
  349. update_bn_stats(self._model, data_loader(), self._num_iter)

No Description