Browse Source

ddp添加环境变量RANK的设置

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
dca3377129
1 changed file with 2 additions and 0 deletions
  1. +2
    -0
      fastNLP/core/drivers/torch_driver/ddp.py

+ 2
- 0
fastNLP/core/drivers/torch_driver/ddp.py View File

@@ -421,6 +421,7 @@ class TorchDDPDriver(TorchDriver):
  os.environ['MASTER_ADDR'] = self.master_address
  os.environ['MASTER_PORT'] = self.master_port

+ os.environ["RANK"] = "0"
  os.environ["LOCAL_RANK"] = str(self.local_rank)
  os.environ["WORLD_SIZE"] = f"{self.world_size}"


@@ -433,6 +434,7 @@ class TorchDDPDriver(TorchDriver):
  for rank in range(1, len(self.parallel_device)):
      env_copy = os.environ.copy()
      env_copy["LOCAL_RANK"] = f"{rank}"
+     env_copy["RANK"] = f"{rank}"

      # 如果是多机,一定需要用户自己拉起,因此我们自己使用 open_subprocesses 开启的进程的 FASTNLP_GLOBAL_RANK 一定是 LOCAL_RANK;
      env_copy[FASTNLP_GLOBAL_RANK] = str(rank)


Loading…
Cancel
Save