diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py index 8bd92cb..eae6d9e 100644 --- a/abl/learning/basic_nn.py +++ b/abl/learning/basic_nn.py @@ -54,6 +54,7 @@ class BasicNN: model: torch.nn.Module, loss_fn: torch.nn.Module, optimizer: torch.optim.Optimizer, + scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, device: torch.device = torch.device("cpu"), batch_size: int = 32, num_epochs: int = 1, @@ -71,6 +72,10 @@ class BasicNN: raise TypeError("loss_fn must be an instance of torch.nn.Module") if not isinstance(optimizer, torch.optim.Optimizer): raise TypeError("optimizer must be an instance of torch.optim.Optimizer") + if scheduler is not None and not isinstance( + scheduler, torch.optim.lr_scheduler.LRScheduler + ): + raise TypeError("scheduler must be an instance of torch.optim.lr_scheduler.LRScheduler") if not isinstance(device, torch.device): raise TypeError("device must be an instance of torch.device") if not isinstance(batch_size, int): @@ -95,6 +100,7 @@ class BasicNN: self.model = model.to(device) self.loss_fn = loss_fn self.optimizer = optimizer + self.scheduler = scheduler self.device = device self.batch_size = batch_size self.num_epochs = num_epochs @@ -144,6 +150,8 @@ class BasicNN: self.save(epoch + 1) if self.stop_loss is not None and loss_value < self.stop_loss: break + if self.scheduler is not None: + self.scheduler.step() print_log(f"model loss: {loss_value:.5f}", logger="current") return self @@ -208,7 +216,9 @@ class BasicNN: for data, target in data_loader: data, target = data.to(device), target.to(device) out = model(data) - loss = loss_fn(out, target) + proba = torch.nn.functional.softmax(out, dim=1) + entropy = -torch.sum(proba * torch.log(proba + 1e-5), dim=1).mean() + loss = loss_fn(out, target) - 0.3 * entropy optimizer.zero_grad() loss.backward()