|
|
|
@@ -0,0 +1,560 @@ |
|
|
|
# coding: utf-8 |
|
|
|
# ================================================================# |
|
|
|
# Copyright (C) 2020 Freecss All rights reserved. |
|
|
|
# |
|
|
|
# File Name :basic_model.py |
|
|
|
# Author :freecss |
|
|
|
# Email :karlfreecss@gmail.com |
|
|
|
# Created Date :2020/11/21 |
|
|
|
# Description : |
|
|
|
# |
|
|
|
# ================================================================# |
|
|
|
|
|
|
|
import sys |
|
|
|
|
|
|
|
sys.path.append("..") |
|
|
|
|
|
|
|
import torch |
|
|
|
import numpy |
|
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
|
|
import os |
|
|
|
from multiprocessing import Pool |
|
|
|
from typing import List, Any, T, Tuple, Optional, Callable |
|
|
|
|
|
|
|
|
|
|
|
class BasicDataset(Dataset): |
|
|
|
def __init__(self, X: List[Any], Y: List[Any]): |
|
|
|
"""Initialize a basic dataset. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
X : List[Any] |
|
|
|
A list of objects representing the input data. |
|
|
|
Y : List[Any] |
|
|
|
A list of objects representing the output data. |
|
|
|
""" |
|
|
|
self.X = X |
|
|
|
self.Y = Y |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
"""Return the length of the dataset. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
int |
|
|
|
The length of the dataset. |
|
|
|
""" |
|
|
|
return len(self.X) |
|
|
|
|
|
|
|
def __getitem__(self, index: int) -> Tuple[Any, Any]: |
|
|
|
"""Get an item from the dataset. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
index : int |
|
|
|
The index of the item to retrieve. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
Tuple[Any, Any] |
|
|
|
A tuple containing the input and output data at the specified index. |
|
|
|
""" |
|
|
|
if index >= len(self): |
|
|
|
raise ValueError("index range error") |
|
|
|
|
|
|
|
img = self.X[index] |
|
|
|
label = self.Y[index] |
|
|
|
|
|
|
|
return (img, label) |
|
|
|
|
|
|
|
|
|
|
|
class XYDataset(Dataset): |
|
|
|
def __init__(self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None): |
|
|
|
""" |
|
|
|
Initialize the dataset used for classification task. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
X : List[Any] |
|
|
|
The input data. |
|
|
|
Y : List[int] |
|
|
|
The target data. |
|
|
|
transform : Callable[..., Any], optional |
|
|
|
A function/transform that takes in an object and returns a transformed version. Defaults to None. |
|
|
|
""" |
|
|
|
self.X = X |
|
|
|
self.Y = torch.LongTensor(Y) |
|
|
|
|
|
|
|
self.n_sample = len(X) |
|
|
|
self.transform = transform |
|
|
|
|
|
|
|
def __len__(self) -> int: |
|
|
|
""" |
|
|
|
Return the length of the dataset. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
int |
|
|
|
The length of the dataset. |
|
|
|
""" |
|
|
|
return len(self.X) |
|
|
|
|
|
|
|
def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]: |
|
|
|
""" |
|
|
|
Get the item at the given index. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
index : int |
|
|
|
The index of the item to get. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
Tuple[Any, torch.Tensor] |
|
|
|
A tuple containing the object and its label. |
|
|
|
""" |
|
|
|
if index >= len(self): |
|
|
|
raise ValueError("index range error") |
|
|
|
|
|
|
|
img = self.X[index] |
|
|
|
if self.transform is not None: |
|
|
|
img = self.transform(img) |
|
|
|
|
|
|
|
label = self.Y[index] |
|
|
|
|
|
|
|
return (img, label) |
|
|
|
|
|
|
|
|
|
|
|
class FakeRecorder: |
|
|
|
def __init__(self): |
|
|
|
pass |
|
|
|
|
|
|
|
def print(self, *x): |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
class BasicNN: |
|
|
|
""" |
|
|
|
Wrap NN models into the form of an sklearn estimator |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
model : torch.nn.Module |
|
|
|
The PyTorch model to be trained or used for prediction. |
|
|
|
criterion : torch.nn.Module |
|
|
|
The loss function used for training. |
|
|
|
optimizer : torch.nn.Module |
|
|
|
The optimizer used for training. |
|
|
|
device : torch.device, optional |
|
|
|
The device on which the model will be trained or used for prediction, by default torch.decive("cpu"). |
|
|
|
batch_size : int, optional |
|
|
|
The batch size used for training, by default 1. |
|
|
|
num_epochs : int, optional |
|
|
|
The number of epochs used for training, by default 1. |
|
|
|
stop_loss : Optional[float], optional |
|
|
|
The loss value at which to stop training, by default 0.01. |
|
|
|
num_workers : int, optional |
|
|
|
The number of workers used for loading data, by default 0. |
|
|
|
save_interval : Optional[int], optional |
|
|
|
The interval at which to save the model during training, by default None. |
|
|
|
save_dir : Optional[str], optional |
|
|
|
The directory in which to save the model during training, by default None. |
|
|
|
transform : Callable[..., Any], optional |
|
|
|
A function/transform that takes in an object and returns a transformed version. Defaults to None. |
|
|
|
collate_fn : Callable[[List[T]], Any], optional |
|
|
|
The function used to collate data, by default None. |
|
|
|
recorder : Any, optional |
|
|
|
The recorder used to record training progress, by default None. |
|
|
|
|
|
|
|
Attributes |
|
|
|
---------- |
|
|
|
model : torch.nn.Module |
|
|
|
The PyTorch model to be trained or used for prediction. |
|
|
|
batch_size : int |
|
|
|
The batch size used for training. |
|
|
|
num_epochs : int |
|
|
|
The number of epochs used for training. |
|
|
|
stop_loss : Optional[float] |
|
|
|
The loss value at which to stop training. |
|
|
|
num_workers : int |
|
|
|
The number of workers used for loading data. |
|
|
|
criterion : torch.nn.Module |
|
|
|
The loss function used for training. |
|
|
|
optimizer : torch.nn.Module |
|
|
|
The optimizer used for training. |
|
|
|
transform : Callable[..., Any] |
|
|
|
The transformation function used for data augmentation. |
|
|
|
device : torch.device |
|
|
|
The device on which the model will be trained or used for prediction. |
|
|
|
recorder : Any |
|
|
|
The recorder used to record training progress. |
|
|
|
save_interval : Optional[int] |
|
|
|
The interval at which to save the model during training. |
|
|
|
save_dir : Optional[str] |
|
|
|
The directory in which to save the model during training. |
|
|
|
collate_fn : Callable[[List[T]], Any] |
|
|
|
The function used to collate data. |
|
|
|
|
|
|
|
Methods |
|
|
|
------- |
|
|
|
fit(data_loader=None, X=None, y=None) |
|
|
|
Train the model. |
|
|
|
train_epoch(data_loader) |
|
|
|
Train the model for one epoch. |
|
|
|
predict(data_loader=None, X=None, print_prefix="") |
|
|
|
Predict the class of the input data. |
|
|
|
predict_proba(data_loader=None, X=None, print_prefix="") |
|
|
|
Predict the probability of each class for the input data. |
|
|
|
val(data_loader=None, X=None, y=None, print_prefix="") |
|
|
|
Validate the model. |
|
|
|
score(data_loader=None, X=None, y=None, print_prefix="") |
|
|
|
Score the model. |
|
|
|
_data_loader(X, y=None) |
|
|
|
Generate the data_loader. |
|
|
|
save(epoch_id, save_dir="") |
|
|
|
Save the model. |
|
|
|
load(epoch_id, load_dir="") |
|
|
|
Load the model. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
model: torch.nn.Module, |
|
|
|
criterion: torch.nn.Module, |
|
|
|
optimizer: torch.nn.Module, |
|
|
|
device: torch.device = torch.device("cpu"), |
|
|
|
batch_size: int = 1, |
|
|
|
num_epochs: int = 1, |
|
|
|
stop_loss: Optional[float] = 0.01, |
|
|
|
num_workers: int = 0, |
|
|
|
save_interval: Optional[int] = None, |
|
|
|
save_dir: Optional[str] = None, |
|
|
|
transform: Callable[..., Any] = None, |
|
|
|
collate_fn: Callable[[List[T]], Any] = None, |
|
|
|
recorder=None, |
|
|
|
): |
|
|
|
|
|
|
|
self.model = model.to(device) |
|
|
|
|
|
|
|
self.batch_size = batch_size |
|
|
|
self.num_epochs = num_epochs |
|
|
|
self.stop_loss = stop_loss |
|
|
|
self.num_workers = num_workers |
|
|
|
|
|
|
|
self.criterion = criterion |
|
|
|
self.optimizer = optimizer |
|
|
|
self.transform = transform |
|
|
|
self.device = device |
|
|
|
|
|
|
|
if recorder is None: |
|
|
|
recorder = FakeRecorder() |
|
|
|
self.recorder = recorder |
|
|
|
|
|
|
|
self.save_interval = save_interval |
|
|
|
self.save_dir = save_dir |
|
|
|
self.collate_fn = collate_fn |
|
|
|
|
|
|
|
def _fit(self, data_loader, n_epoch, stop_loss): |
|
|
|
recorder = self.recorder |
|
|
|
recorder.print("model fitting") |
|
|
|
|
|
|
|
min_loss = 1e10 |
|
|
|
for epoch in range(n_epoch): |
|
|
|
loss_value = self.train_epoch(data_loader) |
|
|
|
recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}") |
|
|
|
if min_loss < 0 or loss_value < min_loss: |
|
|
|
min_loss = loss_value |
|
|
|
if self.save_interval is not None and (epoch + 1) % self.save_interval == 0: |
|
|
|
if self.save_dir is None: |
|
|
|
raise ValueError( |
|
|
|
"save_dir should not be None if save_interval is not None" |
|
|
|
) |
|
|
|
self.save(epoch + 1, self.save_dir) |
|
|
|
if stop_loss is not None and loss_value < stop_loss: |
|
|
|
break |
|
|
|
recorder.print("Model fitted, minimal loss is ", min_loss) |
|
|
|
return loss_value |
|
|
|
|
|
|
|
def fit( |
|
|
|
self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None |
|
|
|
) -> float: |
|
|
|
""" |
|
|
|
Train the model. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
data_loader : DataLoader, optional |
|
|
|
The data loader used for training, by default None |
|
|
|
X : List[Any], optional |
|
|
|
The input data, by default None |
|
|
|
y : List[int], optional |
|
|
|
The target data, by default None |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
float |
|
|
|
The loss value of the trained model. |
|
|
|
""" |
|
|
|
if data_loader is None: |
|
|
|
data_loader = self._data_loader(X, y) |
|
|
|
return self._fit(data_loader, self.num_epochs, self.stop_loss) |
|
|
|
|
|
|
|
def train_epoch(self, data_loader: DataLoader): |
|
|
|
""" |
|
|
|
Train the model for one epoch. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
data_loader : DataLoader |
|
|
|
The data loader used for training. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
float |
|
|
|
The loss value of the trained model. |
|
|
|
""" |
|
|
|
model = self.model |
|
|
|
criterion = self.criterion |
|
|
|
optimizer = self.optimizer |
|
|
|
device = self.device |
|
|
|
|
|
|
|
model.train() |
|
|
|
|
|
|
|
total_loss, total_num = 0.0, 0 |
|
|
|
for data, target in data_loader: |
|
|
|
data, target = data.to(device), target.to(device) |
|
|
|
out = model(data) |
|
|
|
loss = criterion(out, target) |
|
|
|
|
|
|
|
optimizer.zero_grad() |
|
|
|
loss.backward() |
|
|
|
optimizer.step() |
|
|
|
|
|
|
|
total_loss += loss.item() * data.size(0) |
|
|
|
total_num += data.size(0) |
|
|
|
|
|
|
|
return total_loss / total_num |
|
|
|
|
|
|
|
def _predict(self, data_loader): |
|
|
|
model = self.model |
|
|
|
device = self.device |
|
|
|
|
|
|
|
model.eval() |
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
results = [] |
|
|
|
for data, _ in data_loader: |
|
|
|
data = data.to(device) |
|
|
|
out = model(data) |
|
|
|
results.append(out) |
|
|
|
|
|
|
|
return torch.cat(results, axis=0) |
|
|
|
|
|
|
|
def predict( |
|
|
|
self, |
|
|
|
data_loader: DataLoader = None, |
|
|
|
X: List[Any] = None, |
|
|
|
print_prefix: str = "", |
|
|
|
) -> numpy.ndarray: |
|
|
|
""" |
|
|
|
Predict the class of the input data. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
data_loader : DataLoader, optional |
|
|
|
The data loader used for prediction, by default None |
|
|
|
X : List[Any], optional |
|
|
|
The input data, by default None |
|
|
|
print_prefix : str, optional |
|
|
|
The prefix used for printing, by default "" |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
numpy.ndarray |
|
|
|
The predicted class of the input data. |
|
|
|
""" |
|
|
|
recorder = self.recorder |
|
|
|
recorder.print("Start Predict Class ", print_prefix) |
|
|
|
|
|
|
|
if data_loader is None: |
|
|
|
data_loader = self._data_loader(X) |
|
|
|
return self._predict(data_loader).argmax(axis=1).cpu().numpy() |
|
|
|
|
|
|
|
def predict_proba( |
|
|
|
self, |
|
|
|
data_loader: DataLoader = None, |
|
|
|
X: List[Any] = None, |
|
|
|
print_prefix: str = "", |
|
|
|
) -> numpy.ndarray: |
|
|
|
""" |
|
|
|
Predict the probability of each class for the input data. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
data_loader : DataLoader, optional |
|
|
|
The data loader used for prediction, by default None |
|
|
|
X : List[Any], optional |
|
|
|
The input data, by default None |
|
|
|
print_prefix : str, optional |
|
|
|
The prefix used for printing, by default "" |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
numpy.ndarray |
|
|
|
The predicted probability of each class for the input data. |
|
|
|
""" |
|
|
|
recorder = self.recorder |
|
|
|
recorder.print("Start Predict Probability ", print_prefix) |
|
|
|
|
|
|
|
if data_loader is None: |
|
|
|
data_loader = self._data_loader(X) |
|
|
|
return self._predict(data_loader).softmax(axis=1).cpu().numpy() |
|
|
|
|
|
|
|
def _score(self, data_loader): |
|
|
|
model = self.model |
|
|
|
criterion = self.criterion |
|
|
|
device = self.device |
|
|
|
|
|
|
|
model.eval() |
|
|
|
|
|
|
|
total_correct_num, total_num, total_loss = 0, 0, 0.0 |
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
for data, target in data_loader: |
|
|
|
data, target = data.to(device), target.to(device) |
|
|
|
|
|
|
|
out = model(data) |
|
|
|
|
|
|
|
if len(out.shape) > 1: |
|
|
|
correct_num = sum(target == out.argmax(axis=1)).item() |
|
|
|
else: |
|
|
|
correct_num = sum(target == (out > 0.5)).item() |
|
|
|
loss = criterion(out, target) |
|
|
|
total_loss += loss.item() * data.size(0) |
|
|
|
|
|
|
|
total_correct_num += correct_num |
|
|
|
total_num += data.size(0) |
|
|
|
|
|
|
|
mean_loss = total_loss / total_num |
|
|
|
accuracy = total_correct_num / total_num |
|
|
|
|
|
|
|
return mean_loss, accuracy |
|
|
|
|
|
|
|
def score( |
|
|
|
self, |
|
|
|
data_loader: DataLoader = None, |
|
|
|
X: List[Any] = None, |
|
|
|
y: List[int] = None, |
|
|
|
print_prefix: str = "", |
|
|
|
) -> float: |
|
|
|
""" |
|
|
|
Validate the model. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
data_loader : DataLoader, optional |
|
|
|
The data loader used for scoring, by default None |
|
|
|
X : List[Any], optional |
|
|
|
The input data, by default None |
|
|
|
y : List[int], optional |
|
|
|
The target data, by default None |
|
|
|
print_prefix : str, optional |
|
|
|
The prefix used for printing, by default "" |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
float |
|
|
|
The accuracy of the model. |
|
|
|
""" |
|
|
|
recorder = self.recorder |
|
|
|
recorder.print("Start validation ", print_prefix) |
|
|
|
|
|
|
|
if data_loader is None: |
|
|
|
data_loader = self._data_loader(X, y) |
|
|
|
mean_loss, accuracy = self._score(data_loader) |
|
|
|
recorder.print( |
|
|
|
"[%s] mean loss: %f, accuray: %f" % (print_prefix, mean_loss, accuracy) |
|
|
|
) |
|
|
|
return accuracy |
|
|
|
|
|
|
|
def _data_loader( |
|
|
|
self, |
|
|
|
X: List[Any], |
|
|
|
y: List[int] = None, |
|
|
|
) -> DataLoader: |
|
|
|
""" |
|
|
|
Generate data_loader for user provided data. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
X : List[Any] |
|
|
|
The input data. |
|
|
|
y : List[int], optional |
|
|
|
The target data, by default None |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
DataLoader |
|
|
|
The data loader. |
|
|
|
""" |
|
|
|
collate_fn = self.collate_fn |
|
|
|
transform = self.transform |
|
|
|
|
|
|
|
if y is None: |
|
|
|
y = [0] * len(X) |
|
|
|
dataset = XYDataset(X, y, transform=transform) |
|
|
|
sampler = None |
|
|
|
data_loader = DataLoader( |
|
|
|
dataset, |
|
|
|
batch_size=self.batch_size, |
|
|
|
shuffle=False, |
|
|
|
sampler=sampler, |
|
|
|
num_workers=int(self.num_workers), |
|
|
|
collate_fn=collate_fn, |
|
|
|
) |
|
|
|
return data_loader |
|
|
|
|
|
|
|
def save(self, epoch_id: int, save_dir: str = ""): |
|
|
|
""" |
|
|
|
Save the model and the optimizer. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
epoch_id : int |
|
|
|
The epoch id. |
|
|
|
save_dir : str, optional |
|
|
|
The directory to save the model, by default "" |
|
|
|
""" |
|
|
|
recorder = self.recorder |
|
|
|
if not os.path.exists(save_dir): |
|
|
|
os.makedirs(save_dir) |
|
|
|
recorder.print("Saving model and opter") |
|
|
|
save_path = os.path.join(save_dir, str(epoch_id) + "_net.pth") |
|
|
|
torch.save(self.model.state_dict(), save_path) |
|
|
|
|
|
|
|
save_path = os.path.join(save_dir, str(epoch_id) + "_opt.pth") |
|
|
|
torch.save(self.optimizer.state_dict(), save_path) |
|
|
|
|
|
|
|
def load(self, epoch_id: int, load_dir: str = ""): |
|
|
|
""" |
|
|
|
Load the model and the optimizer. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
epoch_id : int |
|
|
|
The epoch id. |
|
|
|
load_dir : str, optional |
|
|
|
The directory to load the model, by default "" |
|
|
|
""" |
|
|
|
recorder = self.recorder |
|
|
|
recorder.print("Loading model and opter") |
|
|
|
load_path = os.path.join(load_dir, str(epoch_id) + "_net.pth") |
|
|
|
self.model.load_state_dict(torch.load(load_path)) |
|
|
|
|
|
|
|
load_path = os.path.join(load_dir, str(epoch_id) + "_opt.pth") |
|
|
|
self.optimizer.load_state_dict(torch.load(load_path)) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
pass |