|
|
|
@@ -126,6 +126,7 @@ def merge_args(args_opt, cloud_args): |
|
|
|
if __name__ == '__main__': |
|
|
|
args = parse_args() |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) |
|
|
|
device_num = int(os.environ.get("DEVICE_NUM", 1)) |
|
|
|
if args.is_distributed: |
|
|
|
if args.device_target == "Ascend": |
|
|
|
@@ -143,7 +144,6 @@ if __name__ == '__main__': |
|
|
|
else: |
|
|
|
if args.device_target == "Ascend": |
|
|
|
context.set_context(device_id=args.device_id) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) |
|
|
|
|
|
|
|
# select for master rank save ckpt or all rank save, compatible for model parallel |
|
|
|
args.rank_save_ckpt_flag = 0 |
|
|
|
|