# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from detectron2.utils import comm

__all__ = ["launch"]


def _find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port

def launch(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url=None, args=()):
    """
    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_gpus_per_machine (int): number of GPUs per machine
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine (unique per machine)
        dist_url (str): url to connect to for distributed training, including protocol,
            e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost.
        args (tuple): arguments passed to main_func
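
    Example:
        A minimal usage sketch; ``demo_main`` is a hypothetical user function,
        not part of this module::

            def demo_main(lr):
                ...  # per-process training code; runs once on each GPU

            launch(demo_main, num_gpus_per_machine=2, dist_url="auto", args=(0.02,))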
- """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        # https://github.com/pytorch/pytorch/pull/14391
        # TODO prctl in spawned processes

        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"

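        # mp.spawn calls fn(i, *args) for each i in range(nprocs); the process
        # index i becomes the `local_rank` argument of _distributed_worker below.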
        mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args),
            daemon=False,
        )
    else:
        main_func(*args)


def _distributed_worker(
    local_rank, main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args
):
    assert torch.cuda.is_available(), "cuda is not available. Please check your installation."
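    # Ranks are assigned contiguously per machine: machine 0 owns ranks
    # [0, num_gpus_per_machine), machine 1 the next block, and so on.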
    global_rank = machine_rank * num_gpus_per_machine + local_rank
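    # NCCL is the backend of choice for GPU collectives; PyTorch matches the
    # backend name case-insensitively.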
    try:
        dist.init_process_group(
            backend="NCCL", init_method=dist_url, world_size=world_size, rank=global_rank
        )
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        raise e
    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    assert num_gpus_per_machine <= torch.cuda.device_count()
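    # Pin this process to its own GPU so that subsequent CUDA allocations and
    # NCCL collectives in this process target device `local_rank`.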
    torch.cuda.set_device(local_rank)

    # Setup the local process group (which contains ranks within the same machine)
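    # NOTE: dist.new_group() is a collective call, so every process must create
    # every group, even those it does not belong to; each process keeps only
    # the group for its own machine.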
    assert comm._LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comm._LOCAL_PROCESS_GROUP = pg

    main_func(*args)