|
|
|
@@ -86,7 +86,7 @@ def _get_optimizer(args_opt, network): |
|
|
|
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, |
|
|
|
{'params': other_params, 'weight_decay': 0.0}, |
|
|
|
{'order_params': params}] |
|
|
|
if args_opt.enable_lossscale == "true": |
|
|
|
if args_opt.enable_lossscale == "true" and args_opt.device_target == 'GPU': |
|
|
|
optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) |
|
|
|
else: |
|
|
|
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) |
|
|
|
@@ -214,7 +214,7 @@ def run_pretrain(): |
|
|
|
accumulation_steps = args_opt.accumulation_steps |
|
|
|
enable_global_norm = cfg.enable_global_norm |
|
|
|
if accumulation_steps <= 1: |
|
|
|
if cfg.optimizer == 'AdamWeightDecay': |
|
|
|
if cfg.optimizer == 'AdamWeightDecay' and args_opt.device_target == 'GPU': |
|
|
|
net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer, |
|
|
|
scale_update_cell=update_cell) |
|
|
|
else: |
|
|
|
|