|
|
@@ -44,6 +44,9 @@ def parse_args(): |
|
|
help="Hccl config path, it is better to use absolute path") |
|
|
help="Hccl config path, it is better to use absolute path") |
|
|
parser.add_argument("--cmd_file", type=str, default="distributed_cmd.sh", |
|
|
parser.add_argument("--cmd_file", type=str, default="distributed_cmd.sh", |
|
|
help="Path of the generated cmd file.") |
|
|
help="Path of the generated cmd file.") |
|
|
|
|
|
parser.add_argument("--hccl_time_out", type=int, default=120, |
|
|
|
|
|
help="Seconds to determine the hccl time out," |
|
|
|
|
|
"default: 120, which is the same as hccl default config") |
|
|
|
|
|
|
|
|
args = parser.parse_args() |
|
|
args = parser.parse_args() |
|
|
return args |
|
|
return args |
|
|
@@ -73,6 +76,8 @@ def distribute_pretrain(): |
|
|
cfg = dict(cf.items("config")) |
|
|
cfg = dict(cf.items("config")) |
|
|
|
|
|
|
|
|
print("hccl_config_dir:", args.hccl_config_dir) |
|
|
print("hccl_config_dir:", args.hccl_config_dir) |
|
|
|
|
|
print("hccl_time_out:", args.hccl_time_out) |
|
|
|
|
|
cmd = append_cmd_env(cmd, 'HCCL_CONNECTION_TIMEOUT', args.hccl_time_out) |
|
|
cmd = append_cmd_env(cmd, 'RANK_TABLE_FILE', args.hccl_config_dir) |
|
|
cmd = append_cmd_env(cmd, 'RANK_TABLE_FILE', args.hccl_config_dir) |
|
|
|
|
|
|
|
|
cores = multiprocessing.cpu_count() |
|
|
cores = multiprocessing.cpu_count() |
|
|
|