From 1a12a8c27be2b0afecfb36548738ddc767183359 Mon Sep 17 00:00:00 2001
From: shenyan <23357320@qq.com>
Date: Thu, 21 Oct 2021 16:46:45 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BD=BF=E7=94=A8warm=20up=E7=9A=84=E5=AD=A6?=
 =?UTF-8?q?=E4=B9=A0=E7=8E=87=E8=B0=83=E6=95=B4=E7=AD=96=E7=95=A5?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 train_model.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/train_model.py b/train_model.py
index 631fc84..e37ccd0 100644
--- a/train_model.py
+++ b/train_model.py
@@ -56,8 +56,18 @@ class TrainModule(pl.LightningModule):
         return input, label, pred
 
     def configure_optimizers(self):
-        optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
-        return optimizer
+        lr = 0.1
+        optimizer = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
+        # Warm up: use lr * 0.1 for the first epoch only, then the full lr,
+        # decayed by 10x at epochs 91 and 137. A LambdaLR is required here:
+        # overriding param_group['lr'] directly is never undone, because
+        # MultiStepLR only rescales the *current* lr at milestone epochs,
+        # so the warm-up rate would otherwise persist for the whole run.
+        def lr_lambda(epoch):
+            if epoch == 0:
+                return 0.1
+            return 0.1 ** sum(epoch >= m for m in (91, 137))
+        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+        return [optimizer], [lr_scheduler]
 
     def load_pretrain_parameters(self):
         """