|
|
|
@@ -54,6 +54,7 @@ class BasicNN: |
|
|
|
model: torch.nn.Module, |
|
|
|
loss_fn: torch.nn.Module, |
|
|
|
optimizer: torch.optim.Optimizer, |
|
|
|
scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, |
|
|
|
device: torch.device = torch.device("cpu"), |
|
|
|
batch_size: int = 32, |
|
|
|
num_epochs: int = 1, |
|
|
|
@@ -71,6 +72,10 @@ class BasicNN: |
|
|
|
raise TypeError("loss_fn must be an instance of torch.nn.Module") |
|
|
|
if not isinstance(optimizer, torch.optim.Optimizer): |
|
|
|
raise TypeError("optimizer must be an instance of torch.optim.Optimizer") |
|
|
|
if scheduler is not None and not isinstance( |
|
|
|
scheduler, torch.optim.lr_scheduler.LRScheduler |
|
|
|
): |
|
|
|
raise TypeError("scheduler must be an instance of torch.optim.lr_scheduler.LRScheduler") |
|
|
|
if not isinstance(device, torch.device): |
|
|
|
raise TypeError("device must be an instance of torch.device") |
|
|
|
if not isinstance(batch_size, int): |
|
|
|
@@ -95,6 +100,7 @@ class BasicNN: |
|
|
|
self.model = model.to(device) |
|
|
|
self.loss_fn = loss_fn |
|
|
|
self.optimizer = optimizer |
|
|
|
self.scheduler = scheduler |
|
|
|
self.device = device |
|
|
|
self.batch_size = batch_size |
|
|
|
self.num_epochs = num_epochs |
|
|
|
@@ -144,6 +150,8 @@ class BasicNN: |
|
|
|
self.save(epoch + 1) |
|
|
|
if self.stop_loss is not None and loss_value < self.stop_loss: |
|
|
|
break |
|
|
|
if self.scheduler is not None: |
|
|
|
self.scheduler.step() |
|
|
|
print_log(f"model loss: {loss_value:.5f}", logger="current") |
|
|
|
return self |
|
|
|
|
|
|
|
@@ -208,7 +216,9 @@ class BasicNN: |
|
|
|
for data, target in data_loader: |
|
|
|
data, target = data.to(device), target.to(device) |
|
|
|
out = model(data) |
|
|
|
loss = loss_fn(out, target) |
|
|
|
proba = torch.nn.functional.softmax(out, dim=1) |
|
|
|
entropy = -torch.sum(proba * torch.log(proba + 1e-5), dim=1).mean() |
|
|
|
loss = loss_fn(out, target) - 0.3 * entropy |
|
|
|
|
|
|
|
optimizer.zero_grad() |
|
|
|
loss.backward() |
|
|
|
|