
test_distributed.py

import os

import pytest

from fastNLP.envs.distributed import rank_zero_call, all_rank_call_context
from tests.helpers.utils import re_run_current_cmd_for_torch, Capturing, magic_argv_env_context


@rank_zero_call
def write_something():
    # Prints the rank repeated five times, e.g. '00000' on rank 0, '11111' on rank 1.
    print(os.environ.get('RANK', '0') * 5, flush=True)


def write_other_thing():
    # Same output as write_something, but undecorated so it can be wrapped at the call site.
    print(os.environ.get('RANK', '0') * 5, flush=True)


class PaddleTest:
    # @x54-729
    def test_rank_zero_call(self):
        pass

    def test_all_rank_run(self):
        pass


class JittorTest:
    # @x54-729
    def test_rank_zero_call(self):
        pass

    def test_all_rank_run(self):
        pass


@pytest.mark.torch
class TestTorch:
    @magic_argv_env_context
    def test_rank_zero_call(self):
        # Provide the rendezvous address/port; default to rank 0 of a 2-process world
        # when the test is not launched through a distributed launcher.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ:
            os.environ['LOCAL_RANK'] = '0'
            os.environ['RANK'] = '0'
            os.environ['WORLD_SIZE'] = '2'
        re_run_current_cmd_for_torch(1, output_from_new_proc='all')

        # A function decorated with rank_zero_call should only print on rank 0.
        with Capturing() as output:
            write_something()
        output = output[0]
        if os.environ['LOCAL_RANK'] == '0':
            assert '00000' in output and '11111' not in output
        else:
            assert '00000' not in output and '11111' not in output

        # Wrapping an undecorated function at the call site should behave the same way.
        with Capturing() as output:
            rank_zero_call(write_other_thing)()
        output = output[0]
        if os.environ['LOCAL_RANK'] == '0':
            assert '00000' in output and '11111' not in output
        else:
            assert '00000' not in output and '11111' not in output

    @magic_argv_env_context
    def test_all_rank_run(self):
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ:
            os.environ['LOCAL_RANK'] = '0'
            os.environ['RANK'] = '0'
            os.environ['WORLD_SIZE'] = '2'
        re_run_current_cmd_for_torch(1, output_from_new_proc='all')
        # torch.distributed.init_process_group(backend='nccl')
        # torch.distributed.barrier()

        # Inside all_rank_call_context, every rank executes the otherwise rank-0-only call.
        with all_rank_call_context():
            with Capturing(no_del=True) as output:
                write_something()
            output = output[0]
            if os.environ['LOCAL_RANK'] == '0':
                assert '00000' in output
            else:
                assert '11111' in output

        with all_rank_call_context():
            with Capturing(no_del=True) as output:
                rank_zero_call(write_other_thing)()
            output = output[0]
            if os.environ['LOCAL_RANK'] == '0':
                assert '00000' in output
            else:
                assert '11111' in output
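
For reference, below is a minimal sketch of the behavior these tests rely on. It is an assumption for illustration only, not fastNLP's actual implementation: the names rank_zero_call_sketch, all_rank_call_context_sketch, and _ALL_RANKS_ALLOWED are hypothetical, and the sketch simply assumes the decorator decides based on the RANK environment variable.

import os
from contextlib import contextmanager
from functools import wraps

# Hypothetical module-level switch flipped by the context manager below.
_ALL_RANKS_ALLOWED = False


def rank_zero_call_sketch(fn):
    """Run the wrapped function only on rank 0, unless all ranks are allowed."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if _ALL_RANKS_ALLOWED or int(os.environ.get('RANK', '0')) == 0:
            return fn(*args, **kwargs)
        return None  # other ranks skip the call entirely
    return wrapper


@contextmanager
def all_rank_call_context_sketch():
    """Temporarily let every rank execute rank-zero-only calls."""
    global _ALL_RANKS_ALLOWED
    previous = _ALL_RANKS_ALLOWED
    _ALL_RANKS_ALLOWED = True
    try:
        yield
    finally:
        _ALL_RANKS_ALLOWED = previous

Under this sketch, a decorated write_something() prints only on rank 0, while calling it inside the context manager prints on every rank, which mirrors the assertions in TestTorch.test_rank_zero_call and TestTorch.test_all_rank_run above.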