def configure_optimizers(self):
    """Build the SGD optimizer and its per-epoch learning-rate schedule.

    Schedule (epoch index is the number of scheduler steps taken so far):
      * epoch 0          -> 0.01  (one-epoch warm-up at a tenth of the base LR)
      * epochs 1..90     -> 0.1   (base LR)
      * epochs 91..136   -> 0.01
      * epochs 137+      -> 0.001

    Returns:
        A pair ``([optimizer], [lr_scheduler])`` holding one optimizer and
        one scheduler.
    """
    base_lr = 0.1
    warmup_factor = 0.1
    decay_gamma = 0.1
    milestones = (91, 137)

    optimizer = torch.optim.SGD(
        self.parameters(), lr=base_lr, momentum=0.9, weight_decay=1e-4
    )

    def lr_factor(epoch):
        # Use the reduced learning rate (0.01) for the first epoch only.
        if epoch == 0:
            return warmup_factor
        # Afterwards, multiply by gamma once for every milestone reached.
        passed = sum(epoch >= milestone for milestone in milestones)
        return decay_gamma ** passed

    # The warm-up is folded into the scheduler on purpose. Overwriting
    # param_group['lr'] by hand after creating a MultiStepLR is not undone
    # by chainable scheduler implementations: between milestones they keep
    # whatever LR the optimizer currently holds, so the run would stay at
    # the warm-up LR instead of returning to base_lr. LambdaLR always
    # derives the LR from base_lr, and also applies the epoch-0 factor as
    # soon as it is constructed.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)
    return [optimizer], [lr_scheduler]