diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index fbfcf96c..e70ad2b4 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -37,7 +37,7 @@ from modelscope.utils.file_utils import func_receive_dict_inputs
 from modelscope.utils.logger import get_logger
 from modelscope.utils.registry import build_from_cfg
 from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
-                                          init_dist, is_master,
+                                          init_dist, is_dist, is_master,
                                           set_random_seed)
 from .base import BaseTrainer
 from .builder import TRAINERS
@@ -236,7 +236,7 @@ class EpochBasedTrainer(BaseTrainer):
             device_name: The final device name.
         """
         device_name = device if device is not None else 'gpu'
-        if dist.is_initialized():
+        if is_dist():
             local_rank = get_local_rank()
             device_name = f'cuda:{local_rank}'
 
@@ -603,7 +603,7 @@ class EpochBasedTrainer(BaseTrainer):
             for key in match_keys:
                 value = train_outputs.get(key, None)
                 if value is not None:
-                    if dist.is_available() and dist.is_initialized():
+                    if is_dist():
                         value = value.data.clone().to('cuda')
                         dist.all_reduce(value.div_(dist.get_world_size()))
                     log_vars.update({key: value.item()})
diff --git a/modelscope/utils/torch_utils.py b/modelscope/utils/torch_utils.py
index e8c21d86..ed1f94c5 100644
--- a/modelscope/utils/torch_utils.py
+++ b/modelscope/utils/torch_utils.py
@@ -106,7 +106,7 @@ def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
 
 
 def get_dist_info() -> Tuple[int, int]:
-    if dist.is_available() and dist.is_initialized():
+    if is_dist():
         try:
             from megatron import mpu
             assert mpu.model_parallel_is_initialized()
@@ -125,8 +125,12 @@ def get_local_rank():
     return int(os.environ.get('LOCAL_RANK', 0))
 
 
+def is_dist():
+    return dist.is_available() and dist.is_initialized()
+
+
 def is_master():
-    return dist.get_rank() == 0 if dist.is_initialized() else True
+    return dist.get_rank() == 0 if is_dist() else True
 
 
 def master_only(func: Callable) -> Callable:
@@ -142,7 +146,7 @@ def make_tmp_dir():
     """Make sure each rank has the same temporary directory on the distributed mode.
     """
-    if not dist.is_initialized():
+    if not is_dist():
        return tempfile.mkdtemp()
    tmpdir = None
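A minimal sketch (not part of the patch) of the pattern this diff introduces: every "is torch.distributed usable?" check goes through a single is_dist() helper instead of inlining dist.is_available() and dist.is_initialized(). The is_dist body matches the helper added to modelscope/utils/torch_utils.py; the surrounding average_scalar function is a hypothetical example for illustration, not ModelScope code.

    import torch
    import torch.distributed as dist


    def is_dist() -> bool:
        # True only when the distributed package is compiled in *and*
        # a process group has actually been initialized.
        return dist.is_available() and dist.is_initialized()


    def average_scalar(value: torch.Tensor) -> float:
        # Hypothetical helper mirroring the trainer change: all-reduce the
        # value across ranks only when a process group exists, otherwise
        # fall back to the local value.
        if is_dist():
            value = value.clone()
            dist.all_reduce(value.div_(dist.get_world_size()))
        return value.item()

Because is_dist() already folds in the is_available() guard, call sites such as get_device and train_step no longer risk calling dist.is_initialized() on builds without distributed support, and single-process runs take the non-distributed branch automatically.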