|
|
@@ -56,8 +56,13 @@ class TrainModule(pl.LightningModule): |
|
|
return input, label, pred |
|
|
return input, label, pred |
|
|
|
|
|
|
|
|
def configure_optimizers(self):
    """Configure the optimizer and learning-rate schedule for training.

    Returns:
        ([optimizer], [lr_scheduler]) — the list pair expected by
        PyTorch Lightning's ``configure_optimizers`` contract.

    Intended per-epoch schedule (base lr = 0.1, decay factor 0.1):
        epoch 0        : 0.01   (warmup: 0.1x for the first epoch only)
        epochs 1-90    : 0.1
        epochs 91-136  : 0.01   (first milestone)
        epochs >= 137  : 0.001  (second milestone)
    """
    base_lr = 0.1
    optimizer = torch.optim.SGD(
        self.parameters(), lr=base_lr, momentum=0.9, weight_decay=1e-4
    )

    # NOTE(review): the original code built a MultiStepLR(milestones=[91, 137])
    # and then manually set every param_group's lr to base_lr * 0.1, intending
    # (per the original comment) to use 0.01 only for the first epoch.  But
    # MultiStepLR never restores the LR, so training would have stayed at 0.01
    # until epoch 91.  A LambdaLR expresses the intended schedule directly:
    # one warmup epoch, then the same milestone decays as MultiStepLR(gamma=0.1).
    def _lr_factor(epoch):
        if epoch < 1:
            return 0.1   # warmup: lr = 0.01 for the first epoch only
        if epoch < 91:
            return 1.0   # full base lr
        if epoch < 137:
            return 0.1   # first milestone decay
        return 0.01      # second milestone decay

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_lr_factor)
    return [optimizer], [lr_scheduler]
|
|
|
|
|
|
|
|
def load_pretrain_parameters(self): |
|
|
def load_pretrain_parameters(self): |
|
|
""" |
|
|
""" |
|
|
|