checkloss_hook.py

```python
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner.hooks import HOOKS, Hook


@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook regularly checks whether the loss is valid (finite)
    during training.

    Args:
        interval (int): Checking interval (every k iterations).
            Default: 50.
    """

    def __init__(self, interval=50):
        self.interval = interval

    def after_train_iter(self, runner):
        # Run the check only once every `interval` iterations.
        if self.every_n_iters(runner, self.interval):
            loss = runner.outputs['loss']
            # An explicit check (instead of `assert`) still runs under
            # `python -O`. Log the problem, then stop training.
            if not torch.isfinite(loss):
                runner.logger.info('loss became infinite or NaN!')
                raise AssertionError('loss became infinite or NaN!')
```
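To make the checking schedule easier to see without an MMCV runner, here is a framework-free sketch of the same idea. The class name `FiniteLossChecker` and its methods are hypothetical. The schedule assumes MMCV's `Hook.every_n_iters` fires when `(runner.iter + 1) % n == 0` with a 0-based iteration counter. The sketch also swaps the hook's `AssertionError` for `FloatingPointError` and uses `math.isfinite` on a plain float in place of `torch.isfinite` on a tensor.

```python
import math


class FiniteLossChecker:
    """Hypothetical, framework-free sketch of a periodic finite-loss check."""

    def __init__(self, interval: int = 50) -> None:
        self.interval = interval

    def should_check(self, iteration: int) -> bool:
        # Assumed to mirror every_n_iters: `iteration` is 0-based, so the
        # check fires on iterations interval-1, 2*interval-1, ...
        return self.interval > 0 and (iteration + 1) % self.interval == 0

    def after_iter(self, iteration: int, loss: float) -> bool:
        """Return True if the loss was checked; raise if it is NaN or inf."""
        if not self.should_check(iteration):
            return False
        if not math.isfinite(loss):
            raise FloatingPointError(
                f'loss became infinite or NaN at iteration {iteration}: {loss}')
        return True


if __name__ == '__main__':
    checker = FiniteLossChecker(interval=3)
    # The loss turns NaN at iteration 3 and stays NaN.
    losses = [1.2, 0.9, 0.7, float('nan'), float('nan'), float('nan')]
    try:
        for it, value in enumerate(losses):
            checker.after_iter(it, value)
    except FloatingPointError as err:
        print(err)  # reported at iteration 5, the next scheduled check
```

The demo shows the tradeoff behind `interval`: a loss that goes bad between checks is only reported at the next scheduled check. A larger interval means less overhead but later detection. In an MMDetection 2.x config, the real hook is typically enabled through a `custom_hooks` entry such as `dict(type='CheckInvalidLossHook', interval=50)`.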
