| @@ -411,7 +411,7 @@ class BasicModel: | |||||
| data_loader = self._data_loader(X) | data_loader = self._data_loader(X) | ||||
| return self._predict(data_loader).softmax(axis=1).cpu().numpy() | return self._predict(data_loader).softmax(axis=1).cpu().numpy() | ||||
| def _val(self, data_loader): | |||||
| def _score(self, data_loader): | |||||
| model = self.model | model = self.model | ||||
| criterion = self.criterion | criterion = self.criterion | ||||
| device = self.device | device = self.device | ||||
| @@ -441,7 +441,7 @@ class BasicModel: | |||||
| return mean_loss, accuracy | return mean_loss, accuracy | ||||
| def val( | |||||
| def score( | |||||
| self, | self, | ||||
| data_loader: DataLoader = None, | data_loader: DataLoader = None, | ||||
| X: List[Any] = None, | X: List[Any] = None, | ||||
| @@ -454,7 +454,7 @@ class BasicModel: | |||||
| Parameters | Parameters | ||||
| ---------- | ---------- | ||||
| data_loader : DataLoader, optional | data_loader : DataLoader, optional | ||||
| The data loader used for validation, by default None | |||||
| The data loader used for scoring, by default None | |||||
| X : List[Any], optional | X : List[Any], optional | ||||
| The input data, by default None | The input data, by default None | ||||
| y : List[int], optional | y : List[int], optional | ||||
| @@ -468,44 +468,16 @@ class BasicModel: | |||||
| The accuracy of the model. | The accuracy of the model. | ||||
| """ | """ | ||||
| recorder = self.recorder | recorder = self.recorder | ||||
| recorder.print("Start val ", print_prefix) | |||||
| recorder.print("Start validation ", print_prefix) | |||||
| if data_loader is None: | if data_loader is None: | ||||
| data_loader = self._data_loader(X, y) | data_loader = self._data_loader(X, y) | ||||
| mean_loss, accuracy = self._val(data_loader) | |||||
| mean_loss, accuracy = self._score(data_loader) | |||||
| recorder.print( | recorder.print( | ||||
| "[%s] Val loss: %f, accuray: %f" % (print_prefix, mean_loss, accuracy) | |||||
| "[%s] mean loss: %f, accuray: %f" % (print_prefix, mean_loss, accuracy) | |||||
| ) | ) | ||||
| return accuracy | return accuracy | ||||
| def score( | |||||
| self, | |||||
| data_loader: DataLoader = None, | |||||
| X: List[Any] = None, | |||||
| y: List[int] = None, | |||||
| print_prefix: str = "", | |||||
| ) -> float: | |||||
| """ | |||||
| Score the model. | |||||
| Parameters | |||||
| ---------- | |||||
| data_loader : DataLoader, optional | |||||
| The data loader used for scoring, by default None | |||||
| X : List[Any], optional | |||||
| The input data, by default None | |||||
| y : List[int], optional | |||||
| The target data, by default None | |||||
| print_prefix : str, optional | |||||
| The prefix used for printing, by default "" | |||||
| Returns | |||||
| ------- | |||||
| float | |||||
| The accuracy of the model. | |||||
| """ | |||||
| return self.val(data_loader, X, y, print_prefix) | |||||
| def _data_loader( | def _data_loader( | ||||
| self, | self, | ||||
| X: List[Any], | X: List[Any], | ||||