dist_utils.py

# Copyright (c) OpenMMLab. All rights reserved.
import functools
import pickle
import warnings
from collections import OrderedDict

import torch
import torch.distributed as dist
from mmcv.runner import OptimizerHook, get_dist_info
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)


def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
    if bucket_size_mb > 0:
        bucket_size_bytes = bucket_size_mb * 1024 * 1024
        buckets = _take_tensors(tensors, bucket_size_bytes)
    else:
        # Group tensors by dtype so each group can be flattened and reduced
        # in a single all_reduce call.
        buckets = OrderedDict()
        for tensor in tensors:
            tp = tensor.type()
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(tensor)
        buckets = buckets.values()

    for bucket in buckets:
        flat_tensors = _flatten_dense_tensors(bucket)
        dist.all_reduce(flat_tensors)
        flat_tensors.div_(world_size)
        for tensor, synced in zip(
                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
            tensor.copy_(synced)


def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
    """Allreduce gradients.

    Args:
        params (list[torch.Parameters]): List of parameters of a model.
        coalesce (bool, optional): Whether allreduce parameters as a whole.
            Defaults to True.
        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
            Defaults to -1.
    """
    grads = [
        param.grad.data for param in params
        if param.requires_grad and param.grad is not None
    ]
    world_size = dist.get_world_size()
    if coalesce:
        _allreduce_coalesced(grads, world_size, bucket_size_mb)
    else:
        for tensor in grads:
            dist.all_reduce(tensor.div_(world_size))
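
# Usage sketch: manually syncing gradients after a backward pass, assuming
# torch.distributed is already initialized and `model` / `optimizer` exist on
# each rank (`bucket_size_mb=32` is an illustrative value):
#
#     loss.backward()
#     allreduce_grads(model.parameters(), coalesce=True, bucket_size_mb=32)
#     optimizer.step()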


class DistOptimizerHook(OptimizerHook):
    """Deprecated optimizer hook for distributed training."""

    def __init__(self, *args, **kwargs):
        warnings.warn('"DistOptimizerHook" is deprecated, please switch to '
                      '"mmcv.runner.OptimizerHook".')
        super().__init__(*args, **kwargs)
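
# Migration sketch: the deprecated hook can be swapped for the mmcv hook it
# wraps; `grad_clip` below is an illustrative setting, not a required one:
#
#     from mmcv.runner import OptimizerHook
#     optimizer_hook = OptimizerHook(grad_clip=dict(max_norm=35, norm_type=2))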


def reduce_mean(tensor):
    """Obtain the mean of tensor on different GPUs."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor
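
# Usage sketch: averaging a scalar metric across GPUs so every rank logs the
# same value, assuming the default process group is initialized (`loss` is an
# illustrative tensor living on the current rank):
#
#     mean_loss = reduce_mean(loss.detach())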


def obj2tensor(pyobj, device='cuda'):
    """Serialize picklable python object to tensor."""
    storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj))
    return torch.ByteTensor(storage).to(device=device)


def tensor2obj(tensor):
    """Deserialize tensor to picklable python object."""
    return pickle.loads(tensor.cpu().numpy().tobytes())
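
# Round-trip sketch: any picklable object can be shipped through a byte
# tensor (placed on CUDA by default), which is how `all_reduce_dict` below
# broadcasts its key list; `meta` is an illustrative object:
#
#     meta = {'epoch': 3, 'classes': ['cat', 'dog']}
#     assert tensor2obj(obj2tensor(meta)) == meta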


@functools.lru_cache()
def _get_global_gloo_group():
    """Return a process group based on gloo backend, containing all the
    ranks. The result is cached.
    """
    if dist.get_backend() == 'nccl':
        return dist.new_group(backend='gloo')
    else:
        return dist.group.WORLD


def all_reduce_dict(py_dict, op='sum', group=None, to_float=True):
    """Apply all reduce function for python dict object.

    The code is modified from
    https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py.

    NOTE: make sure that py_dict in different ranks has the same keys and
    the values are in the same shape.

    Args:
        py_dict (dict): Dict to be applied all reduce op.
        op (str): Operator, could be 'sum' or 'mean'. Default: 'sum'.
        group (:obj:`torch.distributed.group`, optional): Distributed group.
            Default: None.
        to_float (bool): Whether to convert all values of dict to float.
            Default: True.

    Returns:
        OrderedDict: reduced python dict object.
    """
    _, world_size = get_dist_info()
    if world_size == 1:
        return py_dict
    if group is None:
        # TODO: May try not to use gloo in the future
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return py_dict

    # All-reduce logic across different devices: broadcast the key order from
    # rank 0 so every rank flattens the values in the same order.
    py_key = list(py_dict.keys())
    py_key_tensor = obj2tensor(py_key)
    dist.broadcast(py_key_tensor, src=0)
    py_key = tensor2obj(py_key_tensor)

    tensor_shapes = [py_dict[k].shape for k in py_key]
    tensor_numels = [py_dict[k].numel() for k in py_key]

    if to_float:
        flatten_tensor = torch.cat(
            [py_dict[k].flatten().float() for k in py_key])
    else:
        flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])

    dist.all_reduce(flatten_tensor, op=dist.ReduceOp.SUM)
    if op == 'mean':
        flatten_tensor /= world_size

    split_tensors = [
        x.reshape(shape) for x, shape in zip(
            torch.split(flatten_tensor, tensor_numels), tensor_shapes)
    ]
    return OrderedDict({k: v for k, v in zip(py_key, split_tensors)})
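

# Usage sketch: averaging per-rank statistics such as BN buffers (the YOLOX
# use case referenced above), assuming every rank builds the dict with the
# same keys and shapes; `model` is an illustrative nn.Module:
#
#     stats = OrderedDict(
#         (name, buf) for name, buf in model.named_buffers()
#         if 'running_mean' in name or 'running_var' in name)
#     model.load_state_dict(all_reduce_dict(stats, op='mean'), strict=False)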
