diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py index 7233b51..83a6499 100644 --- a/abl/learning/basic_nn.py +++ b/abl/learning/basic_nn.py @@ -106,6 +106,9 @@ class BasicNN: self.test_transform = test_transform self.collate_fn = collate_fn + if self.save_interval is not None and self.save_dir is None: + raise ValueError("save_dir should not be None if save_interval is not None.") + if self.train_transform is not None and self.test_transform is None: print_log( "Transform used in the training phase will be used in prediction.", @@ -138,8 +141,6 @@ class BasicNN: for epoch in range(self.num_epochs): loss_value = self.train_epoch(data_loader) if self.save_interval is not None and (epoch + 1) % self.save_interval == 0: - if self.save_dir is None: - raise ValueError("save_dir should not be None if save_interval is not None.") self.save(epoch + 1) if self.stop_loss is not None and loss_value < self.stop_loss: break