@@ -28,7 +28,8 @@ from mindspore.context import ParallelMode
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
+from mindspore.train.train_thor import ConvertModelUtils
+from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay, THOR
 from mindspore import log as logger
 from mindspore.common import set_seed
 from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
@@ -90,8 +91,27 @@ def _get_optimizer(args_opt, network):
             optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
         else:
             optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
+    elif cfg.optimizer == "Thor":
+        from src.utils import get_bert_thor_lr, get_bert_thor_damping
+        lr = get_bert_thor_lr()
+        damping = get_bert_thor_damping()
+        split_indices = None
+        if bert_net_cfg.num_hidden_layers == 12:
+            if bert_net_cfg.use_relative_positions:
+                split_indices = [29, 58, 87, 116, 145, 174, 203, 217]
+            else:
+                split_indices = [28, 55, 82, 109, 136, 163, 190, 205]
+        elif bert_net_cfg.num_hidden_layers == 24:
+            if bert_net_cfg.use_relative_positions:
+                split_indices = [30, 90, 150, 210, 270, 330, 390, 421]
+            else:
+                split_indices = [38, 93, 148, 203, 258, 313, 368, 397]
+        optimizer = THOR(network, lr, damping, cfg.Thor.momentum,
+                         cfg.Thor.weight_decay, cfg.Thor.loss_scale, cfg.batch_size,
+                         decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
+                         split_indices=split_indices)
     else:
-        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
+        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
                         format(cfg.optimizer))
     return optimizer
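Two details of the new THOR branch are worth unpacking. The `decay_filter` lambda keeps weight decay off LayerNorm and bias parameters, the usual BERT convention for exempting normalization and bias terms, and `split_indices` sets THOR's allreduce fusion boundaries for distributed training, which is why the lists differ between the 12- and 24-layer configurations and with `use_relative_positions` (relative-position handling changes the per-layer parameter count). A minimal, MindSpore-free sketch of what the filter selects; `_Param` and the parameter names below are hypothetical stand-ins, not the project's real names:

# Sketch only: _Param stands in for mindspore.Parameter, which exposes a
# .name attribute; the names are hypothetical BERT-style parameter names.
decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()

class _Param:
    def __init__(self, name):
        self.name = name

params = [
    _Param('bert.encoder.layer.0.attention.self.query.weight'),
    _Param('bert.encoder.layer.0.attention.output.LayerNorm.gamma'),
    _Param('bert.encoder.layer.0.attention.self.query.bias'),
]

# Only the dense weight passes the filter, so only it receives weight decay.
print([p.name for p in params if decay_filter(p)])
# -> ['bert.encoder.layer.0.attention.self.query.weight']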
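The branch also assumes the training config exposes a `Thor` section next to the existing optimizer sections. A hypothetical sketch of the fields it reads; the values here are placeholders for illustration, not the repo's defaults:

from types import SimpleNamespace

# Placeholder values only; the real settings live in the repo's cfg object.
cfg = SimpleNamespace(
    optimizer="Thor",
    batch_size=32,
    Thor=SimpleNamespace(
        momentum=0.9,       # read as cfg.Thor.momentum
        weight_decay=5e-4,  # read as cfg.Thor.weight_decay
        loss_scale=1.0,     # read as cfg.Thor.loss_scale
        frequency=100,      # read later by convert_to_thor_model
    ),
)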
@@ -244,6 +264,8 @@ def run_pretrain():
             net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
 
     model = Model(net_with_grads)
+    model = ConvertModelUtils().convert_to_thor_model(model, network=net_with_grads, optimizer=optimizer,
+                                                      frequency=cfg.Thor.frequency)
     model.train(new_repeat_count, ds, callbacks=callback,
                 dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
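The conversion is applied unconditionally after `Model(net_with_grads)`, which presumably relies on `convert_to_thor_model` passing the model through unchanged when the optimizer is not a THOR instance, so the existing Lamb/Momentum/AdamWeightDecay paths are unaffected. The `frequency` argument is THOR's update interval: the expensive second-order factors are recomputed every `frequency` steps and the cached factors are reused in between. A toy, MindSpore-free sketch of that schedule:

def thor_like_schedule(total_steps, frequency):
    """Yield (step, refresh) pairs; refresh is True on the steps where the
    curvature factors would be recomputed and inverted."""
    for step in range(total_steps):
        yield step, step % frequency == 0

# With frequency=100, factors are rebuilt on steps 0, 100 and 200, and
# reused on every other step.
print([s for s, refresh in thor_like_schedule(300, 100) if refresh])
# -> [0, 100, 200]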