You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 1.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import os, sys
  2. import socket
  3. from typing import Union
  4. import numpy as np
  5. from fastNLP.envs.imports import _NEED_IMPORT_TORCH
  6. if _NEED_IMPORT_TORCH:
  7. import torch
  8. from torch import distributed
  def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
      """Prepare the DDP environment for the current process.

      Exports ``MASTER_ADDR`` (always ``"localhost"``) and ``MASTER_PORT`` into
      ``os.environ`` and then, when ``torch.distributed`` is available and the
      platform is neither ``win32`` nor ``cygwin``, initialises a ``gloo``
      process group.

      :param rank: rank forwarded to ``torch.distributed.init_process_group``.
      :param world_size: world size forwarded to ``init_process_group``.
      :param master_port: port number exported (as a string) as ``MASTER_PORT``.
      """
      env = os.environ
      env["MASTER_ADDR"] = "localhost"
      env["MASTER_PORT"] = str(master_port)

      # The environment variables are always exported; only the process-group
      # initialisation is skipped on unsupported setups.
      if not torch.distributed.is_available() or sys.platform in ("win32", "cygwin"):
          return

      torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
  def find_free_network_port() -> int:
      """Find a free TCP port on localhost.

      Useful in single-node training when we do not want to connect to a real
      master node but still have to set the ``MASTER_PORT`` environment variable.

      The port is obtained by binding a temporary socket to port ``0`` so the
      operating system picks an unused one. The socket is closed before the
      function returns, so the port is only known to have been free at the time
      of the call.

      :return: the port number chosen by the operating system.
      """
      # The context manager guarantees the socket is closed even if bind(),
      # listen() or getsockname() raises, which the manual close() did not.
      with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
          sock.bind(("", 0))
          sock.listen(1)
          return sock.getsockname()[1]
  def _assert_allclose(
      my_result: Union[float, np.ndarray],
      sklearn_result: Union[float, np.ndarray],
      atol: float = 1e-8,
  ) -> None:
      """Assert that two results are numerically close.

      The inputs do not have to be arrays of matching dimensions; other
      combinations are accepted as well, e.g.
      ``np.allclose(np.array([[1e10, ], ]), 1e10+1)`` is also ``True``.

      :param my_result: the value under test. The original note says it is not
          restricted to a particular device.
          NOTE(review): it is passed straight to ``np.allclose`` - confirm that
          callers convert non-CPU values beforehand.
      :param sklearn_result: the reference value to compare against.
      :param atol: absolute tolerance forwarded to ``np.allclose``.
      :return: ``None``; raises ``AssertionError`` when the values differ.
      """
      is_close = np.allclose(a=my_result, b=sklearn_result, atol=atol)
      assert is_close