You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

eval_hooks.py 2.2 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. import torch.distributed as dist
  4. from mmcv.runner import DistEvalHook as BaseDistEvalHook
  5. from mmcv.runner import EvalHook as BaseEvalHook
  6. from torch.nn.modules.batchnorm import _BatchNorm
  7. class EvalHook(BaseEvalHook):
  8. def _do_evaluate(self, runner):
  9. """perform evaluation and save ckpt."""
  10. if not self._should_evaluate(runner):
  11. return
  12. from mmdet.apis import single_gpu_test
  13. results = single_gpu_test(runner.model, self.dataloader, show=False)
  14. runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
  15. key_score = self.evaluate(runner, results)
  16. if self.save_best:
  17. self._save_ckpt(runner, key_score)
  18. class DistEvalHook(BaseDistEvalHook):
  19. def _do_evaluate(self, runner):
  20. """perform evaluation and save ckpt."""
  21. # Synchronization of BatchNorm's buffer (running_mean
  22. # and running_var) is not supported in the DDP of pytorch,
  23. # which may cause the inconsistent performance of models in
  24. # different ranks, so we broadcast BatchNorm's buffers
  25. # of rank 0 to other ranks to avoid this.
  26. if self.broadcast_bn_buffer:
  27. model = runner.model
  28. for name, module in model.named_modules():
  29. if isinstance(module,
  30. _BatchNorm) and module.track_running_stats:
  31. dist.broadcast(module.running_var, 0)
  32. dist.broadcast(module.running_mean, 0)
  33. if not self._should_evaluate(runner):
  34. return
  35. tmpdir = self.tmpdir
  36. if tmpdir is None:
  37. tmpdir = osp.join(runner.work_dir, '.eval_hook')
  38. from mmdet.apis import multi_gpu_test
  39. results = multi_gpu_test(
  40. runner.model,
  41. self.dataloader,
  42. tmpdir=tmpdir,
  43. gpu_collect=self.gpu_collect)
  44. if runner.rank == 0:
  45. print('\n')
  46. runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
  47. key_score = self.evaluate(runner, results)
  48. if self.save_best:
  49. self._save_ckpt(runner, key_score)

No Description

Contributors (2)