You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_eval_hook.py 8.6 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. import tempfile
  4. import unittest.mock as mock
  5. from collections import OrderedDict
  6. from unittest.mock import MagicMock, patch
  7. import pytest
  8. import torch
  9. import torch.nn as nn
  10. from mmcv.runner import EpochBasedRunner, build_optimizer
  11. from mmcv.utils import get_logger
  12. from torch.utils.data import DataLoader, Dataset
  13. from mmdet.core import DistEvalHook, EvalHook
  14. class ExampleDataset(Dataset):
  15. def __init__(self):
  16. self.index = 0
  17. self.eval_result = [0.1, 0.4, 0.3, 0.7, 0.2, 0.05, 0.4, 0.6]
  18. def __getitem__(self, idx):
  19. results = dict(imgs=torch.tensor([1]))
  20. return results
  21. def __len__(self):
  22. return 1
  23. @mock.create_autospec
  24. def evaluate(self, results, logger=None):
  25. pass
  26. class EvalDataset(ExampleDataset):
  27. def evaluate(self, results, logger=None):
  28. mean_ap = self.eval_result[self.index]
  29. output = OrderedDict(mAP=mean_ap, index=self.index, score=mean_ap)
  30. self.index += 1
  31. return output
  32. class ExampleModel(nn.Module):
  33. def __init__(self):
  34. super().__init__()
  35. self.conv = nn.Linear(1, 1)
  36. self.test_cfg = None
  37. def forward(self, imgs, rescale=False, return_loss=False):
  38. return imgs
  39. def train_step(self, data_batch, optimizer, **kwargs):
  40. outputs = {
  41. 'loss': 0.5,
  42. 'log_vars': {
  43. 'accuracy': 0.98
  44. },
  45. 'num_samples': 1
  46. }
  47. return outputs
  48. @pytest.mark.skipif(
  49. not torch.cuda.is_available(), reason='requires CUDA support')
  50. @patch('mmdet.apis.single_gpu_test', MagicMock)
  51. @patch('mmdet.apis.multi_gpu_test', MagicMock)
  52. @pytest.mark.parametrize('EvalHookCls', (EvalHook, DistEvalHook))
  53. def test_eval_hook(EvalHookCls):
  54. with pytest.raises(TypeError):
  55. # dataloader must be a pytorch DataLoader
  56. test_dataset = ExampleDataset()
  57. data_loader = [
  58. DataLoader(
  59. test_dataset,
  60. batch_size=1,
  61. sampler=None,
  62. num_worker=0,
  63. shuffle=False)
  64. ]
  65. EvalHookCls(data_loader)
  66. with pytest.raises(KeyError):
  67. # rule must be in keys of rule_map
  68. test_dataset = ExampleDataset()
  69. data_loader = DataLoader(
  70. test_dataset,
  71. batch_size=1,
  72. sampler=None,
  73. num_workers=0,
  74. shuffle=False)
  75. EvalHookCls(data_loader, save_best='auto', rule='unsupport')
  76. with pytest.raises(ValueError):
  77. # key_indicator must be valid when rule_map is None
  78. test_dataset = ExampleDataset()
  79. data_loader = DataLoader(
  80. test_dataset,
  81. batch_size=1,
  82. sampler=None,
  83. num_workers=0,
  84. shuffle=False)
  85. EvalHookCls(data_loader, save_best='unsupport')
  86. optimizer_cfg = dict(
  87. type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
  88. test_dataset = ExampleDataset()
  89. loader = DataLoader(test_dataset, batch_size=1)
  90. model = ExampleModel()
  91. optimizer = build_optimizer(model, optimizer_cfg)
  92. data_loader = DataLoader(test_dataset, batch_size=1)
  93. eval_hook = EvalHookCls(data_loader, save_best=None)
  94. with tempfile.TemporaryDirectory() as tmpdir:
  95. logger = get_logger('test_eval')
  96. runner = EpochBasedRunner(
  97. model=model,
  98. batch_processor=None,
  99. optimizer=optimizer,
  100. work_dir=tmpdir,
  101. logger=logger)
  102. runner.register_hook(eval_hook)
  103. runner.run([loader], [('train', 1)], 1)
  104. assert runner.meta is None or 'best_score' not in runner.meta[
  105. 'hook_msgs']
  106. assert runner.meta is None or 'best_ckpt' not in runner.meta[
  107. 'hook_msgs']
  108. # when `save_best` is set to 'auto', first metric will be used.
  109. loader = DataLoader(EvalDataset(), batch_size=1)
  110. model = ExampleModel()
  111. data_loader = DataLoader(EvalDataset(), batch_size=1)
  112. eval_hook = EvalHookCls(data_loader, interval=1, save_best='auto')
  113. with tempfile.TemporaryDirectory() as tmpdir:
  114. logger = get_logger('test_eval')
  115. runner = EpochBasedRunner(
  116. model=model,
  117. batch_processor=None,
  118. optimizer=optimizer,
  119. work_dir=tmpdir,
  120. logger=logger)
  121. runner.register_checkpoint_hook(dict(interval=1))
  122. runner.register_hook(eval_hook)
  123. runner.run([loader], [('train', 1)], 8)
  124. real_path = osp.join(tmpdir, 'best_mAP_epoch_4.pth')
  125. assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(real_path)
  126. assert runner.meta['hook_msgs']['best_score'] == 0.7
  127. loader = DataLoader(EvalDataset(), batch_size=1)
  128. model = ExampleModel()
  129. data_loader = DataLoader(EvalDataset(), batch_size=1)
  130. eval_hook = EvalHookCls(data_loader, interval=1, save_best='mAP')
  131. with tempfile.TemporaryDirectory() as tmpdir:
  132. logger = get_logger('test_eval')
  133. runner = EpochBasedRunner(
  134. model=model,
  135. batch_processor=None,
  136. optimizer=optimizer,
  137. work_dir=tmpdir,
  138. logger=logger)
  139. runner.register_checkpoint_hook(dict(interval=1))
  140. runner.register_hook(eval_hook)
  141. runner.run([loader], [('train', 1)], 8)
  142. real_path = osp.join(tmpdir, 'best_mAP_epoch_4.pth')
  143. assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(real_path)
  144. assert runner.meta['hook_msgs']['best_score'] == 0.7
  145. data_loader = DataLoader(EvalDataset(), batch_size=1)
  146. eval_hook = EvalHookCls(
  147. data_loader, interval=1, save_best='score', rule='greater')
  148. with tempfile.TemporaryDirectory() as tmpdir:
  149. logger = get_logger('test_eval')
  150. runner = EpochBasedRunner(
  151. model=model,
  152. batch_processor=None,
  153. optimizer=optimizer,
  154. work_dir=tmpdir,
  155. logger=logger)
  156. runner.register_checkpoint_hook(dict(interval=1))
  157. runner.register_hook(eval_hook)
  158. runner.run([loader], [('train', 1)], 8)
  159. real_path = osp.join(tmpdir, 'best_score_epoch_4.pth')
  160. assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(real_path)
  161. assert runner.meta['hook_msgs']['best_score'] == 0.7
  162. data_loader = DataLoader(EvalDataset(), batch_size=1)
  163. eval_hook = EvalHookCls(data_loader, save_best='mAP', rule='less')
  164. with tempfile.TemporaryDirectory() as tmpdir:
  165. logger = get_logger('test_eval')
  166. runner = EpochBasedRunner(
  167. model=model,
  168. batch_processor=None,
  169. optimizer=optimizer,
  170. work_dir=tmpdir,
  171. logger=logger)
  172. runner.register_checkpoint_hook(dict(interval=1))
  173. runner.register_hook(eval_hook)
  174. runner.run([loader], [('train', 1)], 8)
  175. real_path = osp.join(tmpdir, 'best_mAP_epoch_6.pth')
  176. assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(real_path)
  177. assert runner.meta['hook_msgs']['best_score'] == 0.05
  178. data_loader = DataLoader(EvalDataset(), batch_size=1)
  179. eval_hook = EvalHookCls(data_loader, save_best='mAP')
  180. with tempfile.TemporaryDirectory() as tmpdir:
  181. logger = get_logger('test_eval')
  182. runner = EpochBasedRunner(
  183. model=model,
  184. batch_processor=None,
  185. optimizer=optimizer,
  186. work_dir=tmpdir,
  187. logger=logger)
  188. runner.register_checkpoint_hook(dict(interval=1))
  189. runner.register_hook(eval_hook)
  190. runner.run([loader], [('train', 1)], 2)
  191. real_path = osp.join(tmpdir, 'best_mAP_epoch_2.pth')
  192. assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(real_path)
  193. assert runner.meta['hook_msgs']['best_score'] == 0.4
  194. resume_from = osp.join(tmpdir, 'latest.pth')
  195. loader = DataLoader(ExampleDataset(), batch_size=1)
  196. eval_hook = EvalHookCls(data_loader, save_best='mAP')
  197. runner = EpochBasedRunner(
  198. model=model,
  199. batch_processor=None,
  200. optimizer=optimizer,
  201. work_dir=tmpdir,
  202. logger=logger)
  203. runner.register_checkpoint_hook(dict(interval=1))
  204. runner.register_hook(eval_hook)
  205. runner.resume(resume_from)
  206. runner.run([loader], [('train', 1)], 8)
  207. real_path = osp.join(tmpdir, 'best_mAP_epoch_4.pth')
  208. assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(real_path)
  209. assert runner.meta['hook_msgs']['best_score'] == 0.7

No Description

Contributors (1)