launch.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from detectron2.utils import comm

__all__ = ["launch"]


def _find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def launch(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url=None, args=()):
    """
    Launch multi-GPU or distributed training.

    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_gpus_per_machine (int): the number of GPUs per machine
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine (one per machine)
        dist_url (str): url to connect to for distributed training, including protocol,
            e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost.
        args (tuple): arguments passed to main_func
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        # https://github.com/pytorch/pytorch/pull/14391
        # TODO prctl in spawned processes
        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto cannot work with distributed training."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"

        mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args),
            daemon=False,
        )
    else:
        main_func(*args)


def _distributed_worker(
    local_rank, main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args
):
    assert torch.cuda.is_available(), "cuda is not available. Please check your installation."
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    try:
        dist.init_process_group(
            backend="NCCL", init_method=dist_url, world_size=world_size, rank=global_rank
        )
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        raise e
    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # Setup the local process group (which contains ranks within the same machine)
    assert comm._LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comm._LOCAL_PROCESS_GROUP = pg

    main_func(*args)
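
For reference, below is a minimal usage sketch of launch (not part of launch.py). The worker function run_training, its arguments, and the GPU count are illustrative only; the import path assumes the standard detectron2 package layout, where this module is re-exported from detectron2.engine.

# usage_sketch.py -- illustrative only; `run_training` and its arguments are made up.
from detectron2.engine import launch  # assumes the detectron2 package layout


def run_training(lr, max_iter):
    # Every spawned worker process runs this function with the same `args`;
    # rank and device information is available through detectron2.utils.comm.
    print(f"training with lr={lr}, max_iter={max_iter}")


if __name__ == "__main__":
    launch(
        run_training,
        num_gpus_per_machine=2,   # spawns 2 worker processes on this machine
        num_machines=1,
        machine_rank=0,
        dist_url="auto",          # pick a free localhost port (single-machine only)
        args=(0.02, 90000),
    )

Note that with num_gpus_per_machine=1 and num_machines=1 the world size is 1, so launch simply calls main_func(*args) in the current process, which is convenient for debugging without any process group setup.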
