|
|
|
@@ -106,6 +106,9 @@ class BasicNN: |
|
|
|
self.test_transform = test_transform |
|
|
|
self.collate_fn = collate_fn |
|
|
|
|
|
|
|
if self.save_interval is not None and self.save_dir is None: |
|
|
|
raise ValueError("save_dir should not be None if save_interval is not None.") |
|
|
|
|
|
|
|
if self.train_transform is not None and self.test_transform is None: |
|
|
|
print_log( |
|
|
|
"Transform used in the training phase will be used in prediction.", |
|
|
|
@@ -138,8 +141,6 @@ class BasicNN: |
|
|
|
for epoch in range(self.num_epochs): |
|
|
|
loss_value = self.train_epoch(data_loader) |
|
|
|
if self.save_interval is not None and (epoch + 1) % self.save_interval == 0: |
|
|
|
if self.save_dir is None: |
|
|
|
raise ValueError("save_dir should not be None if save_interval is not None.") |
|
|
|
self.save(epoch + 1) |
|
|
|
if self.stop_loss is not None and loss_value < self.stop_loss: |
|
|
|
break |
|
|
|
|