train.py
# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings

import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
                         Fp16OptimizerHook, OptimizerHook, build_optimizer,
                         build_runner, get_dist_info)
from mmcv.utils import build_from_cfg

from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.utils import get_root_logger


def init_random_seed(seed=None, device='cuda'):
    """Initialize random seed.

    If the seed is not set, the seed will be automatically randomized,
    and then broadcast to all processes to prevent some potential bugs.

    Args:
        seed (int, Optional): The seed. Defaults to None.
        device (str): The device where the seed will be put on.
            Defaults to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()


def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Defaults to False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    logger = get_root_logger(log_level=cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    if 'imgs_per_gpu' in cfg.data:
        logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
                       'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f'={cfg.data.imgs_per_gpu} is used in this experiment')
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f'{cfg.data.imgs_per_gpu} in this experiment')
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu

    runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
        'type']
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # `num_gpus` will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed,
            runner_type=runner_type) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

    if 'runner' not in cfg:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    else:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        # Support batch_size > 1 in validation
        val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
        if val_samples_per_gpu > 1:
            # Replace 'ImageToTensor' with 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
        # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expects list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
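
For context, a minimal sketch of how these helpers are typically wired together by a driver script (in the spirit of MMDetection's tools/train.py). This is not part of train.py; the config path, work_dir, and single-GPU settings below are placeholder assumptions.

# Hypothetical driver script (assumption, not part of this file): build a model
# and dataset from a config, seed every RNG, then hand off to train_detector.
from mmcv import Config

from mmdet.apis import init_random_seed, set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector

# Placeholder config and output directory; substitute your own.
cfg = Config.fromfile('configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')
cfg.work_dir = './work_dirs/example'
cfg.gpu_ids = range(1)  # single GPU, non-distributed run

# Draw one seed (broadcast across ranks when distributed) and fix all RNGs.
seed = init_random_seed(cfg.get('seed', None))
set_random_seed(seed, deterministic=False)
cfg.seed = seed

model = build_detector(
    cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
model.init_weights()
datasets = [build_dataset(cfg.data.train)]

train_detector(model, datasets, cfg, distributed=False, validate=True)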
