You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

train_loop.py 8.9 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  3. import logging
  4. import numpy as np
  5. import time
  6. import weakref
  7. import torch
  8. import detectron2.utils.comm as comm
  9. from detectron2.utils.events import EventStorage
  10. __all__ = ["HookBase", "TrainerBase", "SimpleTrainer"]
  11. class HookBase:
  12. """
  13. Base class for hooks that can be registered with :class:`TrainerBase`.
  14. Each hook can implement 4 methods. The way they are called is demonstrated
  15. in the following snippet:
  16. .. code-block:: python
  17. hook.before_train()
  18. for iter in range(start_iter, max_iter):
  19. hook.before_step()
  20. trainer.run_step()
  21. hook.after_step()
  22. hook.after_train()
  23. Notes:
  24. 1. In the hook method, users can access `self.trainer` to access more
  25. properties about the context (e.g., current iteration).
  26. 2. A hook that does something in :meth:`before_step` can often be
  27. implemented equivalently in :meth:`after_step`.
  28. If the hook takes non-trivial time, it is strongly recommended to
  29. implement the hook in :meth:`after_step` instead of :meth:`before_step`.
  30. The convention is that :meth:`before_step` should only take negligible time.
  31. Following this convention will allow hooks that do care about the difference
  32. between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
  33. function properly.
  34. Attributes:
  35. trainer: A weak reference to the trainer object. Set by the trainer when the hook is
  36. registered.
  37. """
  38. def before_train(self):
  39. """
  40. Called before the first iteration.
  41. """
  42. pass
  43. def after_train(self):
  44. """
  45. Called after the last iteration.
  46. """
  47. pass
  48. def before_step(self):
  49. """
  50. Called before each iteration.
  51. """
  52. pass
  53. def after_step(self):
  54. """
  55. Called after each iteration.
  56. """
  57. pass
  58. class TrainerBase:
  59. """
  60. Base class for iterative trainer with hooks.
  61. The only assumption we made here is: the training runs in a loop.
  62. A subclass can implement what the loop is.
  63. We made no assumptions about the existence of dataloader, optimizer, model, etc.
  64. Attributes:
  65. iter(int): the current iteration.
  66. start_iter(int): The iteration to start with.
  67. By convention the minimum possible value is 0.
  68. max_iter(int): The iteration to end training.
  69. storage(EventStorage): An EventStorage that's opened during the course of training.
  70. """
  71. def __init__(self):
  72. self._hooks = []
  73. def register_hooks(self, hooks):
  74. """
  75. Register hooks to the trainer. The hooks are executed in the order
  76. they are registered.
  77. Args:
  78. hooks (list[Optional[HookBase]]): list of hooks
  79. """
  80. hooks = [h for h in hooks if h is not None]
  81. for h in hooks:
  82. assert isinstance(h, HookBase)
  83. # To avoid circular reference, hooks and trainer cannot own each other.
  84. # This normally does not matter, but will cause memory leak if the
  85. # involved objects contain __del__:
  86. # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
  87. h.trainer = weakref.proxy(self)
  88. self._hooks.extend(hooks)
  89. def train(self, start_iter: int, max_iter: int):
  90. """
  91. Args:
  92. start_iter, max_iter (int): See docs above
  93. """
  94. logger = logging.getLogger(__name__)
  95. logger.info("Starting training from iteration {}".format(start_iter))
  96. self.iter = self.start_iter = start_iter
  97. self.max_iter = max_iter
  98. with EventStorage(start_iter) as self.storage:
  99. try:
  100. self.before_train()
  101. for self.iter in range(start_iter, max_iter):
  102. self.before_step()
  103. self.run_step()
  104. self.after_step()
  105. finally:
  106. self.after_train()
  107. def before_train(self):
  108. for h in self._hooks:
  109. h.before_train()
  110. def after_train(self):
  111. for h in self._hooks:
  112. h.after_train()
  113. def before_step(self):
  114. for h in self._hooks:
  115. h.before_step()
  116. def after_step(self):
  117. for h in self._hooks:
  118. h.after_step()
  119. # this guarantees, that in each hook's after_step, storage.iter == trainer.iter
  120. self.storage.step()
  121. def run_step(self):
  122. raise NotImplementedError
  123. class SimpleTrainer(TrainerBase):
  124. """
  125. A simple trainer for the most common type of task:
  126. single-cost single-optimizer single-data-source iterative optimization.
  127. It assumes that every step, you:
  128. 1. Compute the loss with a data from the data_loader.
  129. 2. Compute the gradients with the above loss.
  130. 3. Update the model with the optimizer.
  131. If you want to do anything fancier than this,
  132. either subclass TrainerBase and implement your own `run_step`,
  133. or write your own training loop.
  134. """
  135. def __init__(self, model, data_loader, optimizer):
  136. """
  137. Args:
  138. model: a torch Module. Takes a data from data_loader and returns a
  139. dict of losses.
  140. data_loader: an iterable. Contains data to be used to call model.
  141. optimizer: a torch optimizer.
  142. """
  143. super().__init__()
  144. """
  145. We set the model to training mode in the trainer.
  146. However it's valid to train a model that's in eval mode.
  147. If you want your model (or a submodule of it) to behave
  148. like evaluation during training, you can overwrite its train() method.
  149. """
  150. model.train()
  151. self.model = model
  152. self.data_loader = data_loader
  153. self._data_loader_iter = iter(data_loader)
  154. self.optimizer = optimizer
  155. def run_step(self):
  156. """
  157. Implement the standard training logic described above.
  158. """
  159. assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
  160. start = time.perf_counter()
  161. """
  162. If your want to do something with the data, you can wrap the dataloader.
  163. """
  164. data = next(self._data_loader_iter)
  165. data_time = time.perf_counter() - start
  166. """
  167. If your want to do something with the losses, you can wrap the model.
  168. """
  169. loss_dict = self.model(data)
  170. losses = sum(loss for loss in loss_dict.values())
  171. self._detect_anomaly(losses, loss_dict)
  172. metrics_dict = loss_dict
  173. metrics_dict["data_time"] = data_time
  174. self._write_metrics(metrics_dict)
  175. """
  176. If you need accumulate gradients or something similar, you can
  177. wrap the optimizer with your custom `zero_grad()` method.
  178. """
  179. self.optimizer.zero_grad()
  180. losses.backward()
  181. """
  182. If you need gradient clipping/scaling or other processing, you can
  183. wrap the optimizer with your custom `step()` method.
  184. """
  185. self.optimizer.step()
  186. def _detect_anomaly(self, losses, loss_dict):
  187. if not torch.isfinite(losses).all():
  188. raise FloatingPointError(
  189. "Loss became infinite or NaN at iteration={}!\nloss_dict = {}".format(
  190. self.iter, loss_dict
  191. )
  192. )
  193. def _write_metrics(self, metrics_dict: dict):
  194. """
  195. Args:
  196. metrics_dict (dict): dict of scalar metrics
  197. """
  198. metrics_dict = {
  199. k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
  200. for k, v in metrics_dict.items()
  201. }
  202. # gather metrics among all workers for logging
  203. # This assumes we do DDP-style training, which is currently the only
  204. # supported method in detectron2.
  205. all_metrics_dict = comm.gather(metrics_dict)
  206. if comm.is_main_process():
  207. if "data_time" in all_metrics_dict[0]:
  208. # data_time among workers can have high variance. The actual latency
  209. # caused by data_time is the maximum among workers.
  210. data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
  211. self.storage.put_scalar("data_time", data_time)
  212. # average the rest metrics
  213. metrics_dict = {
  214. k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
  215. }
  216. total_losses_reduced = sum(loss for loss in metrics_dict.values())
  217. self.storage.put_scalar("total_loss", total_losses_reduced)
  218. if len(metrics_dict) > 1:
  219. self.storage.put_scalars(**metrics_dict)

No Description