|
|
|
@@ -191,12 +191,13 @@ if __name__ == '__main__': |
|
|
|
if args.is_distributed: |
|
|
|
if args.device_target == "Ascend": |
|
|
|
init() |
|
|
|
context.set_context(device_id=args.device_id) |
|
|
|
elif args.device_target == "GPU": |
|
|
|
init("nccl") |
|
|
|
args.rank = get_rank() |
|
|
|
args.group_size = get_group_size() |
|
|
|
device_num = args.group_size |
|
|
|
|
|
|
|
args.rank = get_rank() |
|
|
|
args.group_size = get_group_size() |
|
|
|
device_num = args.group_size |
|
|
|
context.reset_auto_parallel_context() |
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, |
|
|
|
mirror_mean=True) |
|
|
|
|