# Copyright (c) OpenMMLab. All rights reserved. import random import warnings import numpy as np import torch import torch.distributed as dist from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner, Fp16OptimizerHook, OptimizerHook, build_optimizer, build_runner, get_dist_info) from mmcv.utils import build_from_cfg from mmdet.core import DistEvalHook, EvalHook from mmdet.datasets import (build_dataloader, build_dataset, replace_ImageToTensor) from mmdet.utils import get_root_logger def init_random_seed(seed=None, device='cuda'): """Initialize random seed. If the seed is not set, the seed will be automatically randomized, and then broadcast to all processes to prevent some potential bugs. Args: seed (int, Optional): The seed. Default to None. device (str): The device where the seed will be put on. Default to 'cuda'. Returns: int: Seed to be used. """ if seed is not None: return seed # Make sure all ranks share the same random seed to prevent # some potential bugs. Please refer to # https://github.com/open-mmlab/mmdetection/issues/6339 rank, world_size = get_dist_info() seed = np.random.randint(2**31) if world_size == 1: return seed if rank == 0: random_num = torch.tensor(seed, dtype=torch.int32, device=device) else: random_num = torch.tensor(0, dtype=torch.int32, device=device) dist.broadcast(random_num, src=0) return random_num.item() def set_random_seed(seed, deterministic=False): """Set random seed. Args: seed (int): Seed to be used. deterministic (bool): Whether to set the deterministic option for CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` to True and `torch.backends.cudnn.benchmark` to False. Default: False. """ random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) if deterministic: torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def train_detector(model, dataset, cfg, distributed=False, validate=False, timestamp=None, meta=None): logger = get_root_logger(log_level=cfg.log_level) # prepare data loaders dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] if 'imgs_per_gpu' in cfg.data: logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. ' 'Please use "samples_per_gpu" instead') if 'samples_per_gpu' in cfg.data: logger.warning( f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and ' f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"' f'={cfg.data.imgs_per_gpu} is used in this experiments') else: logger.warning( 'Automatically set "samples_per_gpu"="imgs_per_gpu"=' f'{cfg.data.imgs_per_gpu} in this experiments') cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[ 'type'] data_loaders = [ build_dataloader( ds, cfg.data.samples_per_gpu, cfg.data.workers_per_gpu, # `num_gpus` will be ignored if distributed num_gpus=len(cfg.gpu_ids), dist=distributed, seed=cfg.seed, runner_type=runner_type) for ds in dataset ] # put model on gpus if distributed: find_unused_parameters = cfg.get('find_unused_parameters', False) # Sets the `find_unused_parameters` parameter in # torch.nn.parallel.DistributedDataParallel model = MMDistributedDataParallel( model.cuda(), device_ids=[torch.cuda.current_device()], broadcast_buffers=False, find_unused_parameters=find_unused_parameters) else: model = MMDataParallel( model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) # build runner optimizer = build_optimizer(model, cfg.optimizer) if 'runner' not in cfg: cfg.runner = { 'type': 'EpochBasedRunner', 'max_epochs': cfg.total_epochs } warnings.warn( 'config is now expected to have a `runner` section, ' 'please set `runner` in your config.', UserWarning) else: if 'total_epochs' in cfg: assert cfg.total_epochs == cfg.runner.max_epochs runner = build_runner( cfg.runner, default_args=dict( model=model, optimizer=optimizer, work_dir=cfg.work_dir, logger=logger, meta=meta)) # an ugly workaround to make .log and .log.json filenames the same runner.timestamp = timestamp # fp16 setting fp16_cfg = cfg.get('fp16', None) if fp16_cfg is not None: optimizer_config = Fp16OptimizerHook( **cfg.optimizer_config, **fp16_cfg, distributed=distributed) elif distributed and 'type' not in cfg.optimizer_config: optimizer_config = OptimizerHook(**cfg.optimizer_config) else: optimizer_config = cfg.optimizer_config # register hooks runner.register_training_hooks(cfg.lr_config, optimizer_config, cfg.checkpoint_config, cfg.log_config, cfg.get('momentum_config', None)) if distributed: if isinstance(runner, EpochBasedRunner): runner.register_hook(DistSamplerSeedHook()) # register eval hooks if validate: # Support batch_size > 1 in validation val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1) if val_samples_per_gpu > 1: # Replace 'ImageToTensor' to 'DefaultFormatBundle' cfg.data.val.pipeline = replace_ImageToTensor( cfg.data.val.pipeline) val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) val_dataloader = build_dataloader( val_dataset, samples_per_gpu=val_samples_per_gpu, workers_per_gpu=cfg.data.workers_per_gpu, dist=distributed, shuffle=False) eval_cfg = cfg.get('evaluation', {}) eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' eval_hook = DistEvalHook if distributed else EvalHook # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'. runner.register_hook( eval_hook(val_dataloader, **eval_cfg), priority='LOW') # user-defined hooks if cfg.get('custom_hooks', None): custom_hooks = cfg.custom_hooks assert isinstance(custom_hooks, list), \ f'custom_hooks expect list type, but got {type(custom_hooks)}' for hook_cfg in cfg.custom_hooks: assert isinstance(hook_cfg, dict), \ 'Each item in custom_hooks expects dict type, but got ' \ f'{type(hook_cfg)}' hook_cfg = hook_cfg.copy() priority = hook_cfg.pop('priority', 'NORMAL') hook = build_from_cfg(hook_cfg, HOOKS) runner.register_hook(hook, priority=priority) if cfg.resume_from: runner.resume(cfg.resume_from) elif cfg.load_from: runner.load_checkpoint(cfg.load_from) runner.run(data_loaders, cfg.workflow)