| @@ -411,7 +411,7 @@ class BasicModel: | |||
| data_loader = self._data_loader(X) | |||
| return self._predict(data_loader).softmax(axis=1).cpu().numpy() | |||
| def _val(self, data_loader): | |||
| def _score(self, data_loader): | |||
| model = self.model | |||
| criterion = self.criterion | |||
| device = self.device | |||
| @@ -441,7 +441,7 @@ class BasicModel: | |||
| return mean_loss, accuracy | |||
| def val( | |||
| def score( | |||
| self, | |||
| data_loader: DataLoader = None, | |||
| X: List[Any] = None, | |||
| @@ -454,7 +454,7 @@ class BasicModel: | |||
| Parameters | |||
| ---------- | |||
| data_loader : DataLoader, optional | |||
| The data loader used for validation, by default None | |||
| The data loader used for scoring, by default None | |||
| X : List[Any], optional | |||
| The input data, by default None | |||
| y : List[int], optional | |||
| @@ -468,44 +468,16 @@ class BasicModel: | |||
| The accuracy of the model. | |||
| """ | |||
| recorder = self.recorder | |||
| recorder.print("Start val ", print_prefix) | |||
| recorder.print("Start validation ", print_prefix) | |||
| if data_loader is None: | |||
| data_loader = self._data_loader(X, y) | |||
| mean_loss, accuracy = self._val(data_loader) | |||
| mean_loss, accuracy = self._score(data_loader) | |||
| recorder.print( | |||
| "[%s] Val loss: %f, accuray: %f" % (print_prefix, mean_loss, accuracy) | |||
| "[%s] mean loss: %f, accuray: %f" % (print_prefix, mean_loss, accuracy) | |||
| ) | |||
| return accuracy | |||
| def score( | |||
| self, | |||
| data_loader: DataLoader = None, | |||
| X: List[Any] = None, | |||
| y: List[int] = None, | |||
| print_prefix: str = "", | |||
| ) -> float: | |||
| """ | |||
| Score the model. | |||
| Parameters | |||
| ---------- | |||
| data_loader : DataLoader, optional | |||
| The data loader used for scoring, by default None | |||
| X : List[Any], optional | |||
| The input data, by default None | |||
| y : List[int], optional | |||
| The target data, by default None | |||
| print_prefix : str, optional | |||
| The prefix used for printing, by default "" | |||
| Returns | |||
| ------- | |||
| float | |||
| The accuracy of the model. | |||
| """ | |||
| return self.val(data_loader, X, y, print_prefix) | |||
| def _data_loader( | |||
| self, | |||
| X: List[Any], | |||