
add order_params for BERT to improve performance

tags/v0.7.0-beta
shibeiji 5 years ago
commit 29e35a31c0
2 changed files with 7 additions and 4 deletions
  1. model_zoo/official/nlp/bert/run_pretrain.py  +5 -2
  2. model_zoo/official/nlp/bert/src/config.py  +2 -2

model_zoo/official/nlp/bert/run_pretrain.py  (+5 -2)

@@ -106,6 +106,7 @@ def run_pretrain():
         new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
     else:
         args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
+    logger.info("train steps: {}".format(args_opt.train_steps))
 
     if cfg.optimizer == 'Lamb':
         lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
@@ -117,7 +118,8 @@ def run_pretrain():
         decay_params = list(filter(cfg.Lamb.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
-                        {'params': other_params}]
+                        {'params': other_params},
+                        {'order_params': params}]
         optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
     elif cfg.optimizer == 'Momentum':
         optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
@@ -132,7 +134,8 @@ def run_pretrain():
         decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
-                        {'params': other_params, 'weight_decay': 0.0}]
+                        {'params': other_params, 'weight_decay': 0.0},
+                        {'order_params': params}]
 
         optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
     else:
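For context on the change above: in MindSpore, adding an 'order_params' entry to grouped parameters applies the per-group settings while keeping the parameters in their original network order, which is presumably the performance improvement the commit message refers to. A minimal sketch of the same pattern outside the BERT scripts (the toy network and hyperparameter values are placeholders, not taken from this repo):

    import mindspore.nn as nn

    # Stand-in for the BERT network; any Cell with trainable parameters works.
    net = nn.Dense(16, 2)
    params = net.trainable_params()

    # Mirror the decay_filter split in run_pretrain.py: weights get decay, biases do not.
    decay_params = [p for p in params if 'bias' not in p.name.lower()]
    other_params = [p for p in params if p not in decay_params]

    group_params = [{'params': decay_params, 'weight_decay': 0.01},
                    {'params': other_params},
                    # Keep the flattened parameter list in the original network order.
                    {'order_params': params}]

    optimizer = nn.Lamb(group_params, learning_rate=3e-5, eps=1e-6)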


model_zoo/official/nlp/bert/src/config.py  (+2 -2)

@@ -26,7 +26,7 @@ cfg = edict({
     'optimizer': 'Lamb',
     'AdamWeightDecay': edict({
         'learning_rate': 3e-5,
-        'end_learning_rate': 1e-10,
+        'end_learning_rate': 0.0,
         'power': 5.0,
         'weight_decay': 1e-5,
         'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
@@ -35,7 +35,7 @@ cfg = edict({
     }),
     'Lamb': edict({
         'learning_rate': 3e-5,
-        'end_learning_rate': 1e-10,
+        'end_learning_rate': 0.0,
         'power': 10.0,
         'warmup_steps': 10000,
         'weight_decay': 0.01,
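The 'end_learning_rate' change above lowers the floor of the decay schedule to zero, assuming BertLearningRate wraps a warmup phase plus a polynomial decay, as the learning_rate/end_learning_rate/power fields suggest. A minimal sketch of the decay part in isolation, using MindSpore's nn.PolynomialDecayLR; the decay_steps value is a placeholder:

    import mindspore
    import mindspore.nn as nn
    from mindspore import Tensor

    # lr(step) = (learning_rate - end_learning_rate) * (1 - step/decay_steps)**power + end_learning_rate
    decay_lr = nn.PolynomialDecayLR(learning_rate=3e-5,
                                    end_learning_rate=0.0,
                                    decay_steps=100000,
                                    power=10.0)

    print(decay_lr(Tensor(50000, mindspore.int32)))  # learning rate at step 50000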

