|
|
|
@@ -127,7 +127,7 @@ def parse_args(cloud_args=None): |
|
|
|
# logging and checkpoint related |
|
|
|
parser.add_argument('--log_interval', type=int, default=100, help='logging interval') |
|
|
|
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location') |
|
|
|
parser.add_argument('--ckpt_interval', type=int, default=2, help='ckpt_interval') |
|
|
|
parser.add_argument('--ckpt_interval', type=int, default=5, help='ckpt_interval') |
|
|
|
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank') |
|
|
|
|
|
|
|
# distributed related |
|
|
|
@@ -200,12 +200,12 @@ if __name__ == '__main__': |
|
|
|
device_num = args.group_size |
|
|
|
context.reset_auto_parallel_context() |
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, |
|
|
|
mirror_mean=True) |
|
|
|
parameter_broadcast=True, mirror_mean=True) |
|
|
|
else: |
|
|
|
context.set_context(device_id=args.device_id) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) |
|
|
|
|
|
|
|
# select for master rank save ckpt or all rank save, compatiable for model parallel |
|
|
|
# select for master rank save ckpt or all rank save, compatible for model parallel |
|
|
|
args.rank_save_ckpt_flag = 0 |
|
|
|
if args.is_save_on_master: |
|
|
|
if args.rank == 0: |
|
|
|
|