You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_hook.py 10 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import logging
  3. import shutil
  4. import sys
  5. import tempfile
  6. from unittest.mock import MagicMock, call
  7. import numpy as np
  8. import pytest
  9. import torch
  10. import torch.nn as nn
  11. from mmcv.runner import (CheckpointHook, IterTimerHook, PaviLoggerHook,
  12. build_runner)
  13. from torch.nn.init import constant_
  14. from torch.utils.data import DataLoader, Dataset
  15. from mmdet.core.hook import ExpMomentumEMAHook, YOLOXLrUpdaterHook
  16. from mmdet.core.hook.sync_norm_hook import SyncNormHook
  17. from mmdet.core.hook.sync_random_size_hook import SyncRandomSizeHook
  18. def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
  19. max_epochs=1,
  20. max_iters=None,
  21. multi_optimziers=False):
  22. class Model(nn.Module):
  23. def __init__(self):
  24. super().__init__()
  25. self.linear = nn.Linear(2, 1)
  26. self.conv = nn.Conv2d(3, 3, 3)
  27. def forward(self, x):
  28. return self.linear(x)
  29. def train_step(self, x, optimizer, **kwargs):
  30. return dict(loss=self(x))
  31. def val_step(self, x, optimizer, **kwargs):
  32. return dict(loss=self(x))
  33. model = Model()
  34. if multi_optimziers:
  35. optimizer = {
  36. 'model1':
  37. torch.optim.SGD(model.linear.parameters(), lr=0.02, momentum=0.95),
  38. 'model2':
  39. torch.optim.SGD(model.conv.parameters(), lr=0.01, momentum=0.9),
  40. }
  41. else:
  42. optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
  43. tmp_dir = tempfile.mkdtemp()
  44. runner = build_runner(
  45. dict(type=runner_type),
  46. default_args=dict(
  47. model=model,
  48. work_dir=tmp_dir,
  49. optimizer=optimizer,
  50. logger=logging.getLogger(),
  51. max_epochs=max_epochs,
  52. max_iters=max_iters))
  53. return runner
  54. def _build_demo_runner(runner_type='EpochBasedRunner',
  55. max_epochs=1,
  56. max_iters=None,
  57. multi_optimziers=False):
  58. log_config = dict(
  59. interval=1, hooks=[
  60. dict(type='TextLoggerHook'),
  61. ])
  62. runner = _build_demo_runner_without_hook(runner_type, max_epochs,
  63. max_iters, multi_optimziers)
  64. runner.register_checkpoint_hook(dict(interval=1))
  65. runner.register_logger_hooks(log_config)
  66. return runner
  67. @pytest.mark.parametrize('multi_optimziers', (True, False))
  68. def test_yolox_lrupdater_hook(multi_optimziers):
  69. """xdoctest -m tests/test_hooks.py test_cosine_runner_hook."""
  70. # Only used to prevent program errors
  71. YOLOXLrUpdaterHook(0, min_lr_ratio=0.05)
  72. sys.modules['pavi'] = MagicMock()
  73. loader = DataLoader(torch.ones((10, 2)))
  74. runner = _build_demo_runner(multi_optimziers=multi_optimziers)
  75. hook_cfg = dict(
  76. type='YOLOXLrUpdaterHook',
  77. warmup='exp',
  78. by_epoch=False,
  79. warmup_by_epoch=True,
  80. warmup_ratio=1,
  81. warmup_iters=5, # 5 epoch
  82. num_last_epochs=15,
  83. min_lr_ratio=0.05)
  84. runner.register_hook_from_cfg(hook_cfg)
  85. runner.register_hook_from_cfg(dict(type='IterTimerHook'))
  86. runner.register_hook(IterTimerHook())
  87. # add pavi hook
  88. hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
  89. runner.register_hook(hook)
  90. runner.run([loader], [('train', 1)])
  91. shutil.rmtree(runner.work_dir)
  92. # TODO: use a more elegant way to check values
  93. assert hasattr(hook, 'writer')
  94. if multi_optimziers:
  95. calls = [
  96. call(
  97. 'train', {
  98. 'learning_rate/model1': 8.000000000000001e-06,
  99. 'learning_rate/model2': 4.000000000000001e-06,
  100. 'momentum/model1': 0.95,
  101. 'momentum/model2': 0.9
  102. }, 1),
  103. call(
  104. 'train', {
  105. 'learning_rate/model1': 0.00039200000000000004,
  106. 'learning_rate/model2': 0.00019600000000000002,
  107. 'momentum/model1': 0.95,
  108. 'momentum/model2': 0.9
  109. }, 7),
  110. call(
  111. 'train', {
  112. 'learning_rate/model1': 0.0008000000000000001,
  113. 'learning_rate/model2': 0.0004000000000000001,
  114. 'momentum/model1': 0.95,
  115. 'momentum/model2': 0.9
  116. }, 10)
  117. ]
  118. else:
  119. calls = [
  120. call('train', {
  121. 'learning_rate': 8.000000000000001e-06,
  122. 'momentum': 0.95
  123. }, 1),
  124. call('train', {
  125. 'learning_rate': 0.00039200000000000004,
  126. 'momentum': 0.95
  127. }, 7),
  128. call('train', {
  129. 'learning_rate': 0.0008000000000000001,
  130. 'momentum': 0.95
  131. }, 10)
  132. ]
  133. hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
  134. def test_ema_hook():
  135. """xdoctest -m tests/test_hooks.py test_ema_hook."""
  136. class DemoModel(nn.Module):
  137. def __init__(self):
  138. super().__init__()
  139. self.conv = nn.Conv2d(
  140. in_channels=1,
  141. out_channels=2,
  142. kernel_size=1,
  143. padding=1,
  144. bias=True)
  145. self.bn = nn.BatchNorm2d(2)
  146. self._init_weight()
  147. def _init_weight(self):
  148. constant_(self.conv.weight, 0)
  149. constant_(self.conv.bias, 0)
  150. constant_(self.bn.weight, 0)
  151. constant_(self.bn.bias, 0)
  152. def forward(self, x):
  153. return self.bn(self.conv(x)).sum()
  154. def train_step(self, x, optimizer, **kwargs):
  155. return dict(loss=self(x))
  156. def val_step(self, x, optimizer, **kwargs):
  157. return dict(loss=self(x))
  158. loader = DataLoader(torch.ones((1, 1, 1, 1)))
  159. runner = _build_demo_runner()
  160. demo_model = DemoModel()
  161. runner.model = demo_model
  162. ema_hook = ExpMomentumEMAHook(
  163. momentum=0.0002,
  164. total_iter=1,
  165. skip_buffers=True,
  166. interval=2,
  167. resume_from=None)
  168. checkpointhook = CheckpointHook(interval=1, by_epoch=True)
  169. runner.register_hook(ema_hook, priority='HIGHEST')
  170. runner.register_hook(checkpointhook)
  171. runner.run([loader, loader], [('train', 1), ('val', 1)])
  172. checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
  173. num_eam_params = 0
  174. for name, value in checkpoint['state_dict'].items():
  175. if 'ema' in name:
  176. num_eam_params += 1
  177. value.fill_(1)
  178. assert num_eam_params == 4
  179. torch.save(checkpoint, f'{runner.work_dir}/epoch_1.pth')
  180. work_dir = runner.work_dir
  181. resume_ema_hook = ExpMomentumEMAHook(
  182. momentum=0.5,
  183. total_iter=10,
  184. skip_buffers=True,
  185. interval=1,
  186. resume_from=f'{work_dir}/epoch_1.pth')
  187. runner = _build_demo_runner(max_epochs=2)
  188. runner.model = demo_model
  189. runner.register_hook(resume_ema_hook, priority='HIGHEST')
  190. checkpointhook = CheckpointHook(interval=1, by_epoch=True)
  191. runner.register_hook(checkpointhook)
  192. runner.run([loader, loader], [('train', 1), ('val', 1)])
  193. checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
  194. num_eam_params = 0
  195. desired_output = [0.9094, 0.9094]
  196. for name, value in checkpoint['state_dict'].items():
  197. if 'ema' in name:
  198. num_eam_params += 1
  199. assert value.sum() == 2
  200. else:
  201. if ('weight' in name) or ('bias' in name):
  202. np.allclose(value.data.cpu().numpy().reshape(-1),
  203. desired_output, 1e-4)
  204. assert num_eam_params == 4
  205. shutil.rmtree(runner.work_dir)
  206. shutil.rmtree(work_dir)
  207. def test_sync_norm_hook():
  208. # Only used to prevent program errors
  209. SyncNormHook()
  210. loader = DataLoader(torch.ones((5, 2)))
  211. runner = _build_demo_runner()
  212. runner.register_hook_from_cfg(dict(type='SyncNormHook'))
  213. runner.run([loader, loader], [('train', 1), ('val', 1)])
  214. shutil.rmtree(runner.work_dir)
  215. def test_sync_random_size_hook():
  216. # Only used to prevent program errors
  217. SyncRandomSizeHook()
  218. class DemoDataset(Dataset):
  219. def __getitem__(self, item):
  220. return torch.ones(2)
  221. def __len__(self):
  222. return 5
  223. def update_dynamic_scale(self, dynamic_scale):
  224. pass
  225. loader = DataLoader(DemoDataset())
  226. runner = _build_demo_runner()
  227. runner.register_hook_from_cfg(
  228. dict(type='SyncRandomSizeHook', device='cpu'))
  229. runner.run([loader, loader], [('train', 1), ('val', 1)])
  230. shutil.rmtree(runner.work_dir)
  231. if torch.cuda.is_available():
  232. runner = _build_demo_runner()
  233. runner.register_hook_from_cfg(
  234. dict(type='SyncRandomSizeHook', device='cuda'))
  235. runner.run([loader, loader], [('train', 1), ('val', 1)])
  236. shutil.rmtree(runner.work_dir)
  237. @pytest.mark.parametrize('set_loss', [
  238. dict(set_loss_nan=False, set_loss_inf=False),
  239. dict(set_loss_nan=True, set_loss_inf=False),
  240. dict(set_loss_nan=False, set_loss_inf=True)
  241. ])
  242. def test_check_invalid_loss_hook(set_loss):
  243. # Check whether loss is valid during training.
  244. class DemoModel(nn.Module):
  245. def __init__(self, set_loss_nan=False, set_loss_inf=False):
  246. super().__init__()
  247. self.set_loss_nan = set_loss_nan
  248. self.set_loss_inf = set_loss_inf
  249. self.linear = nn.Linear(2, 1)
  250. def forward(self, x):
  251. return self.linear(x)
  252. def train_step(self, x, optimizer, **kwargs):
  253. if self.set_loss_nan:
  254. return dict(loss=torch.tensor(float('nan')))
  255. elif self.set_loss_inf:
  256. return dict(loss=torch.tensor(float('inf')))
  257. else:
  258. return dict(loss=self(x))
  259. loader = DataLoader(torch.ones((5, 2)))
  260. runner = _build_demo_runner()
  261. demo_model = DemoModel(**set_loss)
  262. runner.model = demo_model
  263. runner.register_hook_from_cfg(
  264. dict(type='CheckInvalidLossHook', interval=1))
  265. if not set_loss['set_loss_nan'] \
  266. and not set_loss['set_loss_inf']:
  267. # check loss is valid
  268. runner.run([loader], [('train', 1)])
  269. else:
  270. # check loss is nan or inf
  271. with pytest.raises(AssertionError):
  272. runner.run([loader], [('train', 1)])
  273. shutil.rmtree(runner.work_dir)

No Description

Contributors (3)