You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 1.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import os, sys
  2. import socket
  3. from typing import Union
  4. import torch
  5. from torch import distributed
  6. import numpy as np
  7. def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
  8. """Setup ddp environment."""
  9. os.environ["MASTER_ADDR"] = "localhost"
  10. os.environ["MASTER_PORT"] = str(master_port)
  11. if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
  12. torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
  13. def find_free_network_port() -> int:
  14. """Finds a free port on localhost.
  15. It is useful in single-node training when we don't want to connect to a real master node but have to set the
  16. `MASTER_PORT` environment variable.
  17. """
  18. s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  19. s.bind(("", 0))
  20. s.listen(1)
  21. port = s.getsockname()[1]
  22. s.close()
  23. return port
  24. def _assert_allclose(my_result: Union[float, np.ndarray], sklearn_result: Union[float, np.ndarray],
  25. atol: float = 1e-8) -> None:
  26. """
  27. 测试对比结果,这里不用非得是必须数组且维度对应,一些其他情况例如 np.allclose(np.array([[1e10, ], ]), 1e10+1) 也是 True
  28. :param my_result: 可以不限设备等
  29. :param sklearn_result:
  30. :param atol:
  31. :return:
  32. """
  33. assert np.allclose(a=my_result, b=sklearn_result, atol=atol)