|
|
|
@@ -421,6 +421,7 @@ class TorchDDPDriver(TorchDriver): |
|
|
|
os.environ['MASTER_ADDR'] = self.master_address |
|
|
|
os.environ['MASTER_PORT'] = self.master_port |
|
|
|
|
|
|
|
os.environ["RANK"] = "0" |
|
|
|
os.environ["LOCAL_RANK"] = str(self.local_rank) |
|
|
|
os.environ["WORLD_SIZE"] = f"{self.world_size}" |
|
|
|
|
|
|
|
@@ -433,6 +434,7 @@ class TorchDDPDriver(TorchDriver): |
|
|
|
for rank in range(1, len(self.parallel_device)): |
|
|
|
env_copy = os.environ.copy() |
|
|
|
env_copy["LOCAL_RANK"] = f"{rank}" |
|
|
|
env_copy["RANK"] = f"{rank}" |
|
|
|
|
|
|
|
# 如果是多机,一定需要用户自己拉起,因此我们自己使用 open_subprocesses 开启的进程的 FASTNLP_GLOBAL_RANK 一定是 LOCAL_RANK; |
|
|
|
env_copy[FASTNLP_GLOBAL_RANK] = str(rank) |
|
|
|
|