Browse Source

[ENH] add scheduler to BasicNN

pull/1/head
Gao Enhao 2 years ago
parent
commit
3a3bf801e8
1 changed files with 11 additions and 1 deletions
  1. +11
    -1
      abl/learning/basic_nn.py

+ 11
- 1
abl/learning/basic_nn.py View File

@@ -54,6 +54,7 @@ class BasicNN:
model: torch.nn.Module,
loss_fn: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
device: torch.device = torch.device("cpu"),
batch_size: int = 32,
num_epochs: int = 1,
@@ -71,6 +72,10 @@ class BasicNN:
raise TypeError("loss_fn must be an instance of torch.nn.Module")
if not isinstance(optimizer, torch.optim.Optimizer):
raise TypeError("optimizer must be an instance of torch.optim.Optimizer")
if scheduler is not None and not isinstance(
scheduler, torch.optim.lr_scheduler.LRScheduler
):
raise TypeError("scheduler must be an instance of torch.optim.lr_scheduler.LRScheduler")
if not isinstance(device, torch.device):
raise TypeError("device must be an instance of torch.device")
if not isinstance(batch_size, int):
@@ -95,6 +100,7 @@ class BasicNN:
self.model = model.to(device)
self.loss_fn = loss_fn
self.optimizer = optimizer
self.scheduler = scheduler
self.device = device
self.batch_size = batch_size
self.num_epochs = num_epochs
@@ -144,6 +150,8 @@ class BasicNN:
self.save(epoch + 1)
if self.stop_loss is not None and loss_value < self.stop_loss:
break
if self.scheduler is not None:
self.scheduler.step()
print_log(f"model loss: {loss_value:.5f}", logger="current")
return self

@@ -208,7 +216,9 @@ class BasicNN:
for data, target in data_loader:
data, target = data.to(device), target.to(device)
out = model(data)
loss = loss_fn(out, target)
proba = torch.nn.functional.softmax(out, dim=1)
entropy = -torch.sum(proba * torch.log(proba + 1e-5), dim=1).mean()
loss = loss_fn(out, target) - 0.3 * entropy

optimizer.zero_grad()
loss.backward()


Loading…
Cancel
Save