diff --git a/abl/models/basic_model.py b/abl/models/basic_model.py index 9f07e2c..f26a11a 100644 --- a/abl/models/basic_model.py +++ b/abl/models/basic_model.py @@ -411,7 +411,7 @@ class BasicModel: data_loader = self._data_loader(X) return self._predict(data_loader).softmax(axis=1).cpu().numpy() - def _val(self, data_loader): + def _score(self, data_loader): model = self.model criterion = self.criterion device = self.device @@ -441,7 +441,7 @@ class BasicModel: return mean_loss, accuracy - def val( + def score( self, data_loader: DataLoader = None, X: List[Any] = None, @@ -454,7 +454,7 @@ class BasicModel: Parameters ---------- data_loader : DataLoader, optional - The data loader used for validation, by default None + The data loader used for scoring, by default None X : List[Any], optional The input data, by default None y : List[int], optional @@ -468,44 +468,16 @@ class BasicModel: The accuracy of the model. """ recorder = self.recorder - recorder.print("Start val ", print_prefix) + recorder.print("Start validation ", print_prefix) if data_loader is None: data_loader = self._data_loader(X, y) - mean_loss, accuracy = self._val(data_loader) + mean_loss, accuracy = self._score(data_loader) recorder.print( - "[%s] Val loss: %f, accuray: %f" % (print_prefix, mean_loss, accuracy) + "[%s] mean loss: %f, accuray: %f" % (print_prefix, mean_loss, accuracy) ) return accuracy - def score( - self, - data_loader: DataLoader = None, - X: List[Any] = None, - y: List[int] = None, - print_prefix: str = "", - ) -> float: - """ - Score the model. - - Parameters - ---------- - data_loader : DataLoader, optional - The data loader used for scoring, by default None - X : List[Any], optional - The input data, by default None - y : List[int], optional - The target data, by default None - print_prefix : str, optional - The prefix used for printing, by default "" - - Returns - ------- - float - The accuracy of the model. - """ - return self.val(data_loader, X, y, print_prefix) - def _data_loader( self, X: List[Any],