You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ema.py 5.2 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import math
  3. from mmcv.parallel import is_module_wrapper
  4. from mmcv.runner.hooks import HOOKS, Hook
  5. class BaseEMAHook(Hook):
  6. """Exponential Moving Average Hook.
  7. Use Exponential Moving Average on all parameters of model in training
  8. process. All parameters have a ema backup, which update by the formula
  9. as below. EMAHook takes priority over EvalHook and CheckpointHook. Note,
  10. the original model parameters are actually saved in ema field after train.
  11. Args:
  12. momentum (float): The momentum used for updating ema parameter.
  13. Ema's parameter are updated with the formula:
  14. `ema_param = (1-momentum) * ema_param + momentum * cur_param`.
  15. Defaults to 0.0002.
  16. skip_buffers (bool): Whether to skip the model buffers, such as
  17. batchnorm running stats (running_mean, running_var), it does not
  18. perform the ema operation. Default to False.
  19. interval (int): Update ema parameter every interval iteration.
  20. Defaults to 1.
  21. resume_from (str, optional): The checkpoint path. Defaults to None.
  22. momentum_fun (func, optional): The function to change momentum
  23. during early iteration (also warmup) to help early training.
  24. It uses `momentum` as a constant. Defaults to None.
  25. """
  26. def __init__(self,
  27. momentum=0.0002,
  28. interval=1,
  29. skip_buffers=False,
  30. resume_from=None,
  31. momentum_fun=None):
  32. assert 0 < momentum < 1
  33. self.momentum = momentum
  34. self.skip_buffers = skip_buffers
  35. self.interval = interval
  36. self.checkpoint = resume_from
  37. self.momentum_fun = momentum_fun
  38. def before_run(self, runner):
  39. """To resume model with it's ema parameters more friendly.
  40. Register ema parameter as ``named_buffer`` to model.
  41. """
  42. model = runner.model
  43. if is_module_wrapper(model):
  44. model = model.module
  45. self.param_ema_buffer = {}
  46. if self.skip_buffers:
  47. self.model_parameters = dict(model.named_parameters())
  48. else:
  49. self.model_parameters = model.state_dict()
  50. for name, value in self.model_parameters.items():
  51. # "." is not allowed in module's buffer name
  52. buffer_name = f"ema_{name.replace('.', '_')}"
  53. self.param_ema_buffer[name] = buffer_name
  54. model.register_buffer(buffer_name, value.data.clone())
  55. self.model_buffers = dict(model.named_buffers())
  56. if self.checkpoint is not None:
  57. runner.resume(self.checkpoint)
  58. def get_momentum(self, runner):
  59. return self.momentum_fun(runner.iter) if self.momentum_fun else \
  60. self.momentum
  61. def after_train_iter(self, runner):
  62. """Update ema parameter every self.interval iterations."""
  63. if (runner.iter + 1) % self.interval != 0:
  64. return
  65. momentum = self.get_momentum(runner)
  66. for name, parameter in self.model_parameters.items():
  67. # exclude num_tracking
  68. if parameter.dtype.is_floating_point:
  69. buffer_name = self.param_ema_buffer[name]
  70. buffer_parameter = self.model_buffers[buffer_name]
  71. buffer_parameter.mul_(1 - momentum).add_(
  72. parameter.data, alpha=momentum)
  73. def after_train_epoch(self, runner):
  74. """We load parameter values from ema backup to model before the
  75. EvalHook."""
  76. self._swap_ema_parameters()
  77. def before_train_epoch(self, runner):
  78. """We recover model's parameter from ema backup after last epoch's
  79. EvalHook."""
  80. self._swap_ema_parameters()
  81. def _swap_ema_parameters(self):
  82. """Swap the parameter of model with parameter in ema_buffer."""
  83. for name, value in self.model_parameters.items():
  84. temp = value.data.clone()
  85. ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
  86. value.data.copy_(ema_buffer.data)
  87. ema_buffer.data.copy_(temp)
  88. @HOOKS.register_module()
  89. class ExpMomentumEMAHook(BaseEMAHook):
  90. """EMAHook using exponential momentum strategy.
  91. Args:
  92. total_iter (int): The total number of iterations of EMA momentum.
  93. Defaults to 2000.
  94. """
  95. def __init__(self, total_iter=2000, **kwargs):
  96. super(ExpMomentumEMAHook, self).__init__(**kwargs)
  97. self.momentum_fun = lambda x: (1 - self.momentum) * math.exp(-(
  98. 1 + x) / total_iter) + self.momentum
  99. @HOOKS.register_module()
  100. class LinearMomentumEMAHook(BaseEMAHook):
  101. """EMAHook using linear momentum strategy.
  102. Args:
  103. warm_up (int): During first warm_up steps, we may use smaller decay
  104. to update ema parameters more slowly. Defaults to 100.
  105. """
  106. def __init__(self, warm_up=100, **kwargs):
  107. super(LinearMomentumEMAHook, self).__init__(**kwargs)
  108. self.momentum_fun = lambda x: min(self.momentum**self.interval,
  109. (1 + x) / (warm_up + x))

No Description

Contributors (3)