|
|
@@ -106,6 +106,7 @@ def run_pretrain(): |
|
|
new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps) |
|
|
new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps) |
|
|
else: |
|
|
else: |
|
|
args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size() |
|
|
args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size() |
|
|
|
|
|
logger.info("train steps: {}".format(args_opt.train_steps)) |
|
|
|
|
|
|
|
|
if cfg.optimizer == 'Lamb': |
|
|
if cfg.optimizer == 'Lamb': |
|
|
lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate, |
|
|
lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate, |
|
|
@@ -117,7 +118,8 @@ def run_pretrain(): |
|
|
decay_params = list(filter(cfg.Lamb.decay_filter, params)) |
|
|
decay_params = list(filter(cfg.Lamb.decay_filter, params)) |
|
|
other_params = list(filter(lambda x: x not in decay_params, params)) |
|
|
other_params = list(filter(lambda x: x not in decay_params, params)) |
|
|
group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay}, |
|
|
group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay}, |
|
|
{'params': other_params}] |
|
|
|
|
|
|
|
|
{'params': other_params}, |
|
|
|
|
|
{'order_params': params}] |
|
|
optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps) |
|
|
optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps) |
|
|
elif cfg.optimizer == 'Momentum': |
|
|
elif cfg.optimizer == 'Momentum': |
|
|
optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate, |
|
|
optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate, |
|
|
@@ -132,7 +134,8 @@ def run_pretrain(): |
|
|
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) |
|
|
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) |
|
|
other_params = list(filter(lambda x: x not in decay_params, params)) |
|
|
other_params = list(filter(lambda x: x not in decay_params, params)) |
|
|
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, |
|
|
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, |
|
|
{'params': other_params, 'weight_decay': 0.0}] |
|
|
|
|
|
|
|
|
{'params': other_params, 'weight_decay': 0.0}, |
|
|
|
|
|
{'order_params': params}] |
|
|
|
|
|
|
|
|
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) |
|
|
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) |
|
|
else: |
|
|
else: |
|
|
|