|
|
|
@@ -61,7 +61,7 @@ def parse_args(cloud_args=None): |
|
|
|
parser.add_argument('--lr_gamma', type=float, default=0.1, |
|
|
|
help='decrease lr by a factor of exponential lr_scheduler') |
|
|
|
parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler') |
|
|
|
parser.add_argument('--T_max', type=int, default=150, help='T-max in cosine_annealing scheduler') |
|
|
|
parser.add_argument('--T_max', type=int, default=90, help='T-max in cosine_annealing scheduler') |
|
|
|
|
|
|
|
# logging and checkpoint related |
|
|
|
parser.add_argument('--log_interval', type=int, default=100, help='logging interval') |
|
|
|
@@ -140,7 +140,7 @@ if __name__ == '__main__': |
|
|
|
device_num = args.group_size |
|
|
|
context.reset_auto_parallel_context() |
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, |
|
|
|
gradients_mean=True, all_reduce_fusion_config=[3, 10, 12, 15]) |
|
|
|
gradients_mean=True, all_reduce_fusion_config=[2, 18]) |
|
|
|
else: |
|
|
|
if args.device_target == "Ascend": |
|
|
|
context.set_context(device_id=args.device_id) |
|
|
|
|