Browse Source

!11066 GPU add restrict for bert script

From: @VectorSL
Reviewed-by: @gaoxiong1,@dylangeng,@anyrenwei
Signed-off-by: @gaoxiong1
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
30560be800
1 changed file with 2 additions and 2 deletions
  1. +2
    -2
      model_zoo/official/nlp/bert/run_pretrain.py

+ 2
- 2
model_zoo/official/nlp/bert/run_pretrain.py View File

@@ -86,7 +86,7 @@ def _get_optimizer(args_opt, network):
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]
if args_opt.enable_lossscale == "true":
if args_opt.enable_lossscale == "true" and args_opt.device_target == 'GPU':
optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
else:
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
@@ -214,7 +214,7 @@ def run_pretrain():
accumulation_steps = args_opt.accumulation_steps
enable_global_norm = cfg.enable_global_norm
if accumulation_steps <= 1:
if cfg.optimizer == 'AdamWeightDecay':
if cfg.optimizer == 'AdamWeightDecay' and args_opt.device_target == 'GPU':
net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer,
scale_update_cell=update_cell)
else:


Loading…
Cancel
Save