|
|
|
|
|
|
|
|
|
|
import os
from typing import Any, Callable, List, Optional, Tuple, TypeVar

import numpy
import torch
from torch.utils.data import DataLoader

from ..utils.logger import print_log
from ..dataset import ClassificationDataset

# Generic element type used by the ``collate_fn`` annotations below.
T = TypeVar("T")
|
|
|
|
|
|
|
|
|
|
|
class BasicNN: |
|
|
|
""" |
|
|
|
    Wrap NN models into the form of an sklearn estimator.
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
    model : torch.nn.Module
|
|
|
The PyTorch model to be trained or used for prediction. |
|
|
|
criterion : torch.nn.Module |
|
|
|
The loss function used for training. |
|
|
|
    optimizer : torch.optim.Optimizer
|
|
|
The optimizer used for training. |
|
|
|
device : torch.device, optional |
|
|
|
        The device on which the model will be trained or used for prediction, by default torch.device("cpu").
|
|
|
batch_size : int, optional |
|
|
|
        The batch size used for training, by default 32.
|
|
|
num_epochs : int, optional |
|
|
|
The number of epochs used for training, by default 1. |
|
|
|
stop_loss : Optional[float], optional |
|
|
|
The loss value at which to stop training, by default 0.01. |
|
|
|
    num_workers : int, optional
|
|
|
The number of workers used for loading data, by default 0. |
|
|
|
save_interval : Optional[int], optional |
|
|
|
The interval at which to save the model during training, by default None. |
|
|
|
save_dir : Optional[str], optional |
|
|
|
The directory in which to save the model during training, by default None. |
|
|
|
transform : Callable[..., Any], optional |
|
|
|
        A function/transform that takes in an object and returns a transformed version, by default None.
|
|
|
collate_fn : Callable[[List[T]], Any], optional |
|
|
|
The function used to collate data, by default None. |
|
|
|
|
|
|
|
Attributes |
|
|
|
---------- |
|
|
|
model : torch.nn.Module |
|
|
|
The PyTorch model to be trained or used for prediction. |
|
|
|
batch_size : int |
|
|
|
The batch size used for training. |
|
|
|
num_epochs : int |
|
|
|
The number of epochs used for training. |
|
|
|
stop_loss : Optional[float] |
|
|
|
The loss value at which to stop training. |
|
|
|
num_workers : int |
|
|
|
The number of workers used for loading data. |
|
|
|
criterion : torch.nn.Module |
|
|
|
The loss function used for training. |
|
|
|
    optimizer : torch.optim.Optimizer
|
|
|
The optimizer used for training. |
|
|
|
transform : Callable[..., Any] |
|
|
|
The transformation function used for data augmentation. |
|
|
|
device : torch.device |
|
|
|
The device on which the model will be trained or used for prediction. |
|
|
|
save_interval : Optional[int] |
|
|
|
The interval at which to save the model during training. |
|
|
|
save_dir : Optional[str] |
|
|
|
The directory in which to save the model during training. |
|
|
|
collate_fn : Callable[[List[T]], Any] |
|
|
|
The function used to collate data. |
|
|
|
|
|
|
|
Methods |
|
|
|
------- |
|
|
|
fit(data_loader=None, X=None, y=None) |
|
|
|
Train the model. |
|
|
|
train_epoch(data_loader) |
|
|
|
Train the model for one epoch. |
|
|
|
    predict(data_loader=None, X=None)
|
|
|
Predict the class of the input data. |
|
|
|
    predict_proba(data_loader=None, X=None)
|
|
|
Predict the probability of each class for the input data. |
|
|
|
    score(data_loader=None, X=None, y=None)
        Validate the model and return its accuracy.
|
|
|
_data_loader(X, y=None) |
|
|
|
Generate the data_loader. |
|
|
|
    save(epoch_id=0, save_path=None)
|
|
|
Save the model. |
|
|
|
    load(load_path="")
|
|
|
Load the model. |
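
    Examples
    --------
    A minimal usage sketch, not part of the original module: the toy model, data,
    and hyper-parameters below are illustrative assumptions, and ``ClassificationDataset``
    is assumed to yield ``(input, label)`` pairs that the default collate function can
    batch.

    >>> import torch
    >>> from torch import nn, optim
    >>> model = nn.Sequential(nn.Linear(4, 2))
    >>> estimator = BasicNN(
    ...     model,
    ...     criterion=nn.CrossEntropyLoss(),
    ...     optimizer=optim.Adam(model.parameters(), lr=1e-3),
    ...     batch_size=4,
    ...     num_epochs=3,
    ... )
    >>> X = [torch.randn(4) for _ in range(8)]  # toy feature vectors
    >>> y = [i % 2 for i in range(8)]           # toy binary labels
    >>> loss = estimator.fit(X=X, y=y)          # doctest: +SKIP
    >>> probs = estimator.predict_proba(X=X)    # doctest: +SKIP
    >>> labels = estimator.predict(X=X)         # doctest: +SKIP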
|
|
|
""" |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
model: torch.nn.Module, |
|
|
|
criterion: torch.nn.Module, |
|
|
|
        optimizer: torch.optim.Optimizer,
|
|
|
device: torch.device = torch.device("cpu"), |
|
|
|
        batch_size: int = 32,
|
|
|
num_epochs: int = 1, |
|
|
|
stop_loss: Optional[float] = 0.01, |
|
|
|
num_workers: int = 0, |
|
|
|
        save_interval: Optional[int] = None,
|
|
|
save_dir: Optional[str] = None, |
|
|
|
transform: Callable[..., Any] = None, |
|
|
|
collate_fn: Callable[[List[T]], Any] = None, |
|
|
|
    ) -> None:
|
|
|
        self.model = model.to(device)
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.stop_loss = stop_loss
        self.num_workers = num_workers
        self.criterion = criterion
        self.optimizer = optimizer
        self.transform = transform
        self.device = device
        self.save_interval = save_interval
        self.save_dir = save_dir
        self.collate_fn = collate_fn
|
|
|
|
|
|
|
    def _fit(self, data_loader) -> float:
|
|
|
""" |
|
|
|
        Internal method to fit the model for ``num_epochs`` epochs, stopping early once the loss drops below ``stop_loss``.
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
data_loader : DataLoader |
|
|
|
Data loader providing training samples. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
float |
|
|
|
The loss value of the trained model. |
|
|
|
""" |
|
|
|
        loss_value = 1e9
        for epoch in range(self.num_epochs):
            loss_value = self.train_epoch(data_loader)
            if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
                if self.save_dir is None:
                    raise ValueError(
                        "save_dir should not be None if save_interval is not None."
                    )
                self.save(epoch + 1)
            if self.stop_loss is not None and loss_value < self.stop_loss:
                break
        return loss_value
|
|
|
|
|
|
|
def fit( |
|
|
|
self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None |
|
|
|
    ) -> float:
        """
        Train the model.
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
data_loader : DataLoader, optional |
|
|
|
            The data loader used for training, by default None.
|
|
|
X : List[Any], optional |
|
|
|
            The input data, by default None.
|
|
|
y : List[int], optional |
|
|
|
            The target data, by default None.
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
        float
|
|
|
The loss value of the trained model. |
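
        Examples
        --------
        Illustrative only; ``estimator`` is assumed to be an already constructed
        ``BasicNN`` and ``train_loader`` an assumed ``torch.utils.data.DataLoader``
        yielding ``(input, label)`` batches:

        >>> loss = estimator.fit(data_loader=train_loader)  # doctest: +SKIP
        >>> loss = estimator.fit(X=X_train, y=y_train)      # doctest: +SKIP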
|
|
|
""" |
|
|
|
        if data_loader is None:
            if X is None:
                raise ValueError("data_loader and X cannot be None simultaneously.")
            else:
                data_loader = self._data_loader(X, y)
        return self._fit(data_loader)
|
|
|
|
|
|
|
    def train_epoch(self, data_loader: DataLoader) -> float:
|
|
|
""" |
|
|
|
Train the model for one epoch. |
|
|
|
|
|
|
|
        Parameters
        ----------
        data_loader : DataLoader
            The data loader used for training.

        Returns
        -------
        float
            The mean loss over one epoch of training.
        """
        model = self.model
        criterion = self.criterion
        optimizer = self.optimizer
        device = self.device

        model.train()
        total_loss, total_num = 0.0, 0
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            out = model(data)
            loss = criterion(out, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight each batch loss by its size so the epoch mean is exact.
            total_loss += loss.item() * data.size(0)
            total_num += data.size(0)
        return total_loss / total_num
|
|
|
|
|
|
|
    def _predict(self, data_loader) -> torch.Tensor:
|
|
|
""" |
|
|
|
Internal method to predict the outputs given a DataLoader. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
data_loader : DataLoader |
|
|
|
The DataLoader providing input samples. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
torch.Tensor |
|
|
|
Raw output from the model. |
|
|
|
""" |
|
|
|
model = self.model |
|
|
|
device = self.device |
|
|
|
|
|
|
|
        model.eval()
        results = []
        with torch.no_grad():
            # Loaders built by ``_data_loader`` yield (input, dummy-label) pairs,
            # so only the input part is used here.
            for data, _ in data_loader:
                data = data.to(device)
                out = model(data)
                results.append(out)
        return torch.cat(results, axis=0)
|
|
|
|
|
|
|
def predict( |
|
|
|
        self, data_loader: DataLoader = None, X: List[Any] = None
|
|
|
) -> numpy.ndarray: |
|
|
|
""" |
|
|
|
Predict the class of the input data. |
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
data_loader : DataLoader, optional |
|
|
|
            The data loader used for prediction, by default None.
|
|
|
X : List[Any], optional |
|
|
|
            The input data, by default None.
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
        numpy.ndarray
            The predicted class of the input data.
        """
        if data_loader is None:
            data_loader = self._data_loader(X)
        return self._predict(data_loader).argmax(axis=1).cpu().numpy()
|
|
|
|
|
|
|
def predict_proba( |
|
|
|
        self, data_loader: DataLoader = None, X: List[Any] = None
|
|
|
) -> numpy.ndarray: |
|
|
|
""" |
|
|
|
Predict the probability of each class for the input data. |
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
data_loader : DataLoader, optional |
|
|
|
            The data loader used for prediction, by default None.
|
|
|
X : List[Any], optional |
|
|
|
            The input data, by default None.
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
        numpy.ndarray
            The predicted probability of each class for the input data.
        """
        if data_loader is None:
            data_loader = self._data_loader(X)
        return self._predict(data_loader).softmax(axis=1).cpu().numpy()
|
|
|
|
|
|
|
    def _score(self, data_loader) -> Tuple[float, float]:
|
|
|
""" |
|
|
|
Internal method to compute loss and accuracy for the data provided through a DataLoader. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
data_loader : DataLoader |
|
|
|
Data loader to use for evaluation. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
Tuple[float, float] |
|
|
|
mean_loss: float, The mean loss of the model on the provided data. |
|
|
|
accuracy: float, The accuracy of the model on the provided data. |
|
|
|
""" |
|
|
|
model = self.model |
|
|
|
criterion = self.criterion |
|
|
|
device = self.device |
|
|
|
        model.eval()
        total_loss, total_correct_num, total_num = 0.0, 0, 0
        with torch.no_grad():
            for data, target in data_loader:
                data, target = data.to(device), target.to(device)
                out = model(data)

                if len(out.shape) > 1:
                    correct_num = (target == out.argmax(axis=1)).sum().item()
                else:
                    correct_num = (target == (out > 0.5)).sum().item()
                loss = criterion(out, target)
                total_loss += loss.item() * data.size(0)
                total_correct_num += correct_num
                total_num += data.size(0)

        mean_loss = total_loss / total_num
        accuracy = total_correct_num / total_num
        return mean_loss, accuracy
|
|
|
|
|
|
|
def score( |
|
|
|
        self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
|
|
|
) -> float: |
|
|
|
""" |
|
|
|
Validate the model. |
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
data_loader : DataLoader, optional |
|
|
|
            The data loader used for scoring, by default None.
|
|
|
X : List[Any], optional |
|
|
|
            The input data, by default None.
|
|
|
y : List[int], optional |
|
|
|
            The target data, by default None.
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
float |
|
|
|
The accuracy of the model. |
|
|
|
""" |
|
|
|
print_log(f"Start machine learning model validation", logger="current") |
|
|
|
print_log("Start machine learning model validation", logger="current") |
|
|
|
|
|
|
|
if data_loader is None: |
|
|
|
data_loader = self._data_loader(X, y) |
|
|
|
mean_loss, accuracy = self._score(data_loader) |
|
|
|
print_log(f"{print_prefix} mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current") |
|
|
|
print_log( |
|
|
|
f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current" |
|
|
|
) |
|
|
|
return accuracy |
|
|
|
|
|
|
|
def _data_loader( |
|
|
|
        self,
        X: List[Any] = None,
|
|
|
y: List[int] = None, |
|
|
|
) -> DataLoader: |
|
|
|
""" |
|
|
|
        Generate a DataLoader for user-provided input and target data.
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
X : List[Any] |
|
|
|
            Input samples.
|
|
|
y : List[int], optional |
|
|
|
            Target labels. If None, dummy labels are created, by default None.
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
DataLoader |
|
|
|
            A DataLoader providing batches of (X, y) pairs.
|
|
|
""" |
|
|
|
        if X is None:
            raise ValueError("X should not be None.")
        if y is None:
            y = [0] * len(X)
        if len(X) != len(y):
            raise ValueError("X and y should have equal length.")

        dataset = ClassificationDataset(X, y, transform=self.transform)
|
|
|
        data_loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=int(self.num_workers),
            collate_fn=self.collate_fn,
        )
        return data_loader
|
|
|
|
|
|
|
    def save(self, epoch_id: int = 0, save_path: str = None) -> None:
|
|
|
""" |
|
|
|
Save the model and the optimizer. |
|
|
|
|
|
|
|
        Parameters
|
|
|
---------- |
|
|
|
epoch_id : int |
|
|
|
The epoch id. |
|
|
|
|
|
|
save_path : str, optional |
|
|
|
The path to save the model, by default None. |
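
        Examples
        --------
        A hedged sketch of a save/load round trip; the checkpoint file name is an
        arbitrary choice for illustration:

        >>> estimator.save(epoch_id=3, save_path="model_epoch_3.pth")  # doctest: +SKIP
        >>> estimator.load(load_path="model_epoch_3.pth")              # doctest: +SKIP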
|
|
|
""" |
|
|
|
        if self.save_dir is None and save_path is None:
            raise ValueError(
                "'save_dir' and 'save_path' should not be None simultaneously."
            )

        if save_path is None:
            save_path = os.path.join(
                self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth"
            )
            if not os.path.exists(self.save_dir):
                os.makedirs(self.save_dir)

        print_log(f"Checkpoints will be saved to {save_path}", logger="current")

        # Persist both the model and the optimizer state so training can resume.
        save_param_dic = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(save_param_dic, save_path)
|
|
|
|
|
|
|
    def load(self, load_path: str = "") -> None:
|
|
|
""" |
|
|
|
Load the model and the optimizer. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
        load_path : str, optional
            The path from which to load the model, by default "".
|
|
|
""" |
|
|
|
|
|
|
|
        if load_path is None:
            raise ValueError("Load path should not be None.")

        print_log(
            f"Loads checkpoint by local backend from path: {load_path}",
            logger="current",
        )
|
|
|
|
|
|
|
param_dic = torch.load(load_path) |
|
|
|
self.model.load_state_dict(param_dic["model"]) |
|
|
|
if "optimizer" in param_dic.keys(): |
|
|
|
self.optimizer.load_state_dict(param_dic["optimizer"]) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
pass |