@@ -25,12 +25,11 @@ from src.config import cfg
 from src.dataset import create_bert_dataset
 from src.lr_generator import get_bert_lr, get_bert_damping
 from src.model_thor import Model
-from src.utils import LossCallBack, BertLearningRate
+from src.utils import LossCallBack
 import mindspore.common.dtype as mstype
 import mindspore.communication.management as D
 from mindspore import context
 from mindspore import log as logger
-from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.context import ParallelMode
@@ -68,38 +67,8 @@ def _set_bert_all_reduce_split():
 
 
 def _get_optimizer(args_opt, network):
-    """get bert optimizer, support Lamb, Momentum, AdamWeightDecay and Thor."""
-    if cfg.optimizer == 'Lamb':
-        lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
-                                       end_learning_rate=cfg.Lamb.end_learning_rate,
-                                       warmup_steps=cfg.Lamb.warmup_steps,
-                                       decay_steps=args_opt.train_steps,
-                                       power=cfg.Lamb.power)
-        params = network.trainable_params()
-        decay_params = list(filter(cfg.Lamb.decay_filter, params))
-        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
-        group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
-                        {'params': other_params},
-                        {'order_params': params}]
-        optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
-    elif cfg.optimizer == 'Momentum':
-        optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
-                             momentum=cfg.Momentum.momentum)
-    elif cfg.optimizer == 'AdamWeightDecay':
-        lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
-                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
-                                       warmup_steps=cfg.AdamWeightDecay.warmup_steps,
-                                       decay_steps=args_opt.train_steps,
-                                       power=cfg.AdamWeightDecay.power)
-        params = network.trainable_params()
-        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
-        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
-        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
-                        {'params': other_params, 'weight_decay': 0.0},
-                        {'order_params': params}]
-
-        optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
-    elif cfg.optimizer == "Thor":
+    """get thor optimizer."""
+    if cfg.optimizer == "Thor":
         if args_opt.distribute == "true":
             from src.thor_for_bert_arg import THOR
         else:
@@ -112,8 +81,7 @@ def _get_optimizer(args_opt, network):
                          cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
                          bert_net_cfg.batch_size, damping)
     else:
-        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
-                         format(cfg.optimizer))
+        raise ValueError("Don't support optimizer {}, only support [Thor]".format(cfg.optimizer))
     return optimizer
 
 
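Read together, the two hunks leave a Thor-only branch whose middle lines (the non-distributed import, the learning-rate and damping schedules, and the start of the THOR constructor call) are unchanged context that the diff does not display. The sketch below is only an orientation aid under stated assumptions: the src.thor_for_bert fallback import, the no-argument get_bert_lr()/get_bert_damping() calls, and the position of the elided THOR arguments are inferred from the imports and hunk context above, not taken from this diff.

def _get_optimizer(args_opt, network):
    """get thor optimizer."""
    if cfg.optimizer == "Thor":
        if args_opt.distribute == "true":
            from src.thor_for_bert_arg import THOR
        else:
            from src.thor_for_bert import THOR  # assumed single-device counterpart; not shown in the diff
        lr = get_bert_lr()            # assumed no-argument schedule helpers from src.lr_generator
        damping = get_bert_damping()  # damping feeds the last positional argument of THOR below
        optimizer = THOR(
            # leading THOR arguments are elided by the diff and are not reconstructed here;
            # only the trailing arguments appear in the hunk context
            cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
            bert_net_cfg.batch_size, damping)
    else:
        raise ValueError("Don't support optimizer {}, only support [Thor]".format(cfg.optimizer))
    return optimizer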