Browse Source

[MNT] resolve comments in basic_nn.py cont

pull/1/head
Gao Enhao 2 years ago
parent
commit
b7e2c94ffa
1 changed files with 3 additions and 2 deletions
  1. +3
    -2
      abl/learning/basic_nn.py

+ 3
- 2
abl/learning/basic_nn.py View File

@@ -106,6 +106,9 @@ class BasicNN:
self.test_transform = test_transform
self.collate_fn = collate_fn

if self.save_interval is not None and self.save_dir is None:
raise ValueError("save_dir should not be None if save_interval is not None.")

if self.train_transform is not None and self.test_transform is None:
print_log(
"Transform used in the training phase will be used in prediction.",
@@ -138,8 +141,6 @@ class BasicNN:
for epoch in range(self.num_epochs):
loss_value = self.train_epoch(data_loader)
if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
if self.save_dir is None:
raise ValueError("save_dir should not be None if save_interval is not None.")
self.save(epoch + 1)
if self.stop_loss is not None and loss_value < self.stop_loss:
break


Loading…
Cancel
Save