
Fix dist judgement when torch.distributed.is_available is always False

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10976015
master^2
yuze.zyz, wenmeng.zwm · 3 years ago
parent commit bf97dd7501
2 changed files with 10 additions and 6 deletions
  1. modelscope/trainers/trainer.py (+3, -3)
  2. modelscope/utils/torch_utils.py (+7, -3)
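The pattern this commit replaces is a bare dist.is_initialized() check. On PyTorch builds compiled without distributed support, torch.distributed.is_available() is always False and, on some versions, the rest of the torch.distributed API (including is_initialized) is not defined at all, so the availability check has to come first. A minimal standalone sketch of the guard the commit centralizes (the rank fallback is illustrative, not from the repository):

    import torch.distributed as dist

    def is_dist() -> bool:
        # Check availability before initialization: on non-distributed
        # builds, dist.is_initialized may be missing or meaningless.
        return dist.is_available() and dist.is_initialized()

    # Illustrative use: fall back to rank 0 when not running distributed.
    rank = dist.get_rank() if is_dist() else 0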

modelscope/trainers/trainer.py (+3, -3)

@@ -37,7 +37,7 @@ from modelscope.utils.file_utils import func_receive_dict_inputs
 from modelscope.utils.logger import get_logger
 from modelscope.utils.registry import build_from_cfg
 from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
-                                          init_dist, is_master,
+                                          init_dist, is_dist, is_master,
                                           set_random_seed)
 from .base import BaseTrainer
 from .builder import TRAINERS
@@ -236,7 +236,7 @@ class EpochBasedTrainer(BaseTrainer):
             device_name: The final device name.
         """
         device_name = device if device is not None else 'gpu'
-        if dist.is_initialized():
+        if is_dist():
             local_rank = get_local_rank()
             device_name = f'cuda:{local_rank}'


@@ -603,7 +603,7 @@ class EpochBasedTrainer(BaseTrainer):
         for key in match_keys:
             value = train_outputs.get(key, None)
             if value is not None:
-                if dist.is_available() and dist.is_initialized():
+                if is_dist():
                     value = value.data.clone().to('cuda')
                     dist.all_reduce(value.div_(dist.get_world_size()))
                 log_vars.update({key: value.item()})
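The hunks above route the device-selection and metric-averaging guards through the new helper. For illustration, a self-contained sketch of the averaging pattern, assuming the modelscope package is installed; average_metric itself is a made-up name, not a trainer method:

    import torch
    import torch.distributed as dist

    from modelscope.utils.torch_utils import is_dist  # helper added by this commit


    def average_metric(value: torch.Tensor) -> float:
        # Average a scalar metric across ranks when a process group is
        # available and initialized; otherwise return it unchanged.
        if is_dist():
            value = value.data.clone().to('cuda')
            dist.all_reduce(value.div_(dist.get_world_size()))
        return value.item()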


modelscope/utils/torch_utils.py (+7, -3)

@@ -106,7 +106,7 @@ def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
 
 
 def get_dist_info() -> Tuple[int, int]:
-    if dist.is_available() and dist.is_initialized():
+    if is_dist():
         try:
             from megatron import mpu
             assert mpu.model_parallel_is_initialized()
@@ -125,8 +125,12 @@ def get_local_rank():
     return int(os.environ.get('LOCAL_RANK', 0))
 
 
+def is_dist():
+    return dist.is_available() and dist.is_initialized()
+
+
 def is_master():
-    return dist.get_rank() == 0 if dist.is_initialized() else True
+    return dist.get_rank() == 0 if is_dist() else True
 
 
 def master_only(func: Callable) -> Callable:
@@ -142,7 +146,7 @@ def master_only(func: Callable) -> Callable:
 def make_tmp_dir():
     """Make sure each rank has the same temporary directory on the distributed mode.
     """
-    if not dist.is_initialized():
+    if not is_dist():
         return tempfile.mkdtemp()
 
     tmpdir = None
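With is_dist() in place, none of the helpers touch dist.get_rank() or other collectives unless a process group is both available and initialized, so is_master() and make_tmp_dir() degrade cleanly to single-process behavior. A short usage sketch; save_checkpoint is a hypothetical caller, not part of torch_utils:

    import torch
    import torch.distributed as dist

    from modelscope.utils.torch_utils import is_dist, is_master


    def save_checkpoint(state: dict, path: str) -> None:
        # Only rank 0 writes in distributed mode; a single process always
        # writes, because is_master() falls back to True when not distributed.
        if is_master():
            torch.save(state, path)

        # Keep other ranks from reading the file before it exists.
        if is_dist():
            dist.barrier()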

