Browse Source

[FIX] delete redundant method val in BasicModel

pull/3/head
Gao Enhao 2 years ago
parent
commit
e2faec4740
1 changed files with 6 additions and 34 deletions
  1. +6
    -34
      abl/models/basic_model.py

+ 6
- 34
abl/models/basic_model.py View File

@@ -411,7 +411,7 @@ class BasicModel:
data_loader = self._data_loader(X)
return self._predict(data_loader).softmax(axis=1).cpu().numpy()

def _val(self, data_loader):
def _score(self, data_loader):
model = self.model
criterion = self.criterion
device = self.device
@@ -441,7 +441,7 @@ class BasicModel:

return mean_loss, accuracy

def val(
def score(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
@@ -454,7 +454,7 @@ class BasicModel:
Parameters
----------
data_loader : DataLoader, optional
The data loader used for validation, by default None
The data loader used for scoring, by default None
X : List[Any], optional
The input data, by default None
y : List[int], optional
@@ -468,44 +468,16 @@ class BasicModel:
The accuracy of the model.
"""
recorder = self.recorder
recorder.print("Start val ", print_prefix)
recorder.print("Start validation ", print_prefix)

if data_loader is None:
data_loader = self._data_loader(X, y)
mean_loss, accuracy = self._val(data_loader)
mean_loss, accuracy = self._score(data_loader)
recorder.print(
"[%s] Val loss: %f, accuray: %f" % (print_prefix, mean_loss, accuracy)
"[%s] mean loss: %f, accuray: %f" % (print_prefix, mean_loss, accuracy)
)
return accuracy

def score(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
y: List[int] = None,
print_prefix: str = "",
) -> float:
"""
Score the model.

Parameters
----------
data_loader : DataLoader, optional
The data loader used for scoring, by default None
X : List[Any], optional
The input data, by default None
y : List[int], optional
The target data, by default None
print_prefix : str, optional
The prefix used for printing, by default ""

Returns
-------
float
The accuracy of the model.
"""
return self.val(data_loader, X, y, print_prefix)

def _data_loader(
self,
X: List[Any],


Loading…
Cancel
Save