Browse Source

[MNT] resolve all comments in basic_nn.py

pull/3/head
Gao Enhao 2 years ago
parent
commit
bc6dbfeb99
2 changed files with 126 additions and 152 deletions
  1. +2
    -2
      abl/bridge/simple_bridge.py
  2. +124
    -150
      abl/learning/basic_nn.py

+ 2
- 2
abl/bridge/simple_bridge.py View File

@@ -74,10 +74,10 @@ class SimpleBridge(BaseBridge):
pred_prob, pred_pseudo_label, Y
)
abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label)
min_loss = self.model.train(X, abduced_label)
loss = self.model.train(X, abduced_label)

print_log(
f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] minimal_loss is {min_loss:.5f}",
f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] model loss is {loss:.5f}",
logger="current",
)



+ 124
- 150
abl/learning/basic_nn.py View File

@@ -10,10 +10,6 @@
#
# ================================================================#

import sys

sys.path.append("..")

import torch
import numpy
from torch.utils.data import DataLoader
@@ -21,12 +17,12 @@ from ..utils.logger import print_log
from ..dataset import ClassificationDataset

import os
from typing import List, Any, T, Optional, Callable
from typing import List, Any, T, Optional, Callable, Tuple


class BasicNN:
"""
Wrap NN models into the form of an sklearn estimator
Wrap NN models into the form of an sklearn estimator.

Parameters
----------
@@ -34,83 +30,35 @@ class BasicNN:
The PyTorch model to be trained or used for prediction.
criterion : torch.nn.Module
The loss function used for training.
optimizer : torch.nn.Module
optimizer : torch.optim.Optimizer
The optimizer used for training.
device : torch.device, optional
The device on which the model will be trained or used for prediction, by default torch.decive("cpu").
The device on which the model will be trained or used for prediction, by default torch.device("cpu").
batch_size : int, optional
The batch size used for training, by default 1.
The batch size used for training, by default 32.
num_epochs : int, optional
The number of epochs used for training, by default 1.
stop_loss : Optional[float], optional
The loss value at which to stop training, by default 0.01.
num_workers : int, optional
num_workers : int
The number of workers used for loading data, by default 0.
save_interval : Optional[int], optional
The interval at which to save the model during training, by default None.
save_dir : Optional[str], optional
The directory in which to save the model during training, by default None.
transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version. Defaults to None.
A function/transform that takes in an object and returns a transformed version, by default None.
collate_fn : Callable[[List[T]], Any], optional
The function used to collate data, by default None.

Attributes
----------
model : torch.nn.Module
The PyTorch model to be trained or used for prediction.
batch_size : int
The batch size used for training.
num_epochs : int
The number of epochs used for training.
stop_loss : Optional[float]
The loss value at which to stop training.
num_workers : int
The number of workers used for loading data.
criterion : torch.nn.Module
The loss function used for training.
optimizer : torch.nn.Module
The optimizer used for training.
transform : Callable[..., Any]
The transformation function used for data augmentation.
device : torch.device
The device on which the model will be trained or used for prediction.
save_interval : Optional[int]
The interval at which to save the model during training.
save_dir : Optional[str]
The directory in which to save the model during training.
collate_fn : Callable[[List[T]], Any]
The function used to collate data.

Methods
-------
fit(data_loader=None, X=None, y=None)
Train the model.
train_epoch(data_loader)
Train the model for one epoch.
predict(data_loader=None, X=None, print_prefix="")
Predict the class of the input data.
predict_proba(data_loader=None, X=None, print_prefix="")
Predict the probability of each class for the input data.
val(data_loader=None, X=None, y=None, print_prefix="")
Validate the model.
score(data_loader=None, X=None, y=None, print_prefix="")
Score the model.
_data_loader(X, y=None)
Generate the data_loader.
save(epoch_id, save_dir="")
Save the model.
load(epoch_id, load_dir="")
Load the model.
"""

def __init__(
self,
model: torch.nn.Module,
criterion: torch.nn.Module,
optimizer: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device: torch.device = torch.device("cpu"),
batch_size: int = 1,
batch_size: int = 32,
num_epochs: int = 1,
stop_loss: Optional[float] = 0.01,
num_workers: int = 0,
@@ -118,38 +66,46 @@ class BasicNN:
save_dir: Optional[str] = None,
transform: Callable[..., Any] = None,
collate_fn: Callable[[List[T]], Any] = None,
):
) -> None:
self.model = model.to(device)

self.criterion = criterion
self.optimizer = optimizer
self.device = device
self.batch_size = batch_size
self.num_epochs = num_epochs
self.stop_loss = stop_loss
self.num_workers = num_workers

self.criterion = criterion
self.optimizer = optimizer
self.transform = transform
self.device = device

self.save_interval = save_interval
self.save_dir = save_dir
self.transform = transform
self.collate_fn = collate_fn

def _fit(self, data_loader, n_epoch, stop_loss):
min_loss = 1e10
for epoch in range(n_epoch):
def _fit(self, data_loader) -> float:
"""
Internal method to fit the model on data for n epochs, with early stopping.

Parameters
----------
data_loader : DataLoader
Data loader providing training samples.

Returns
-------
float
The loss value of the trained model.
"""
loss_value = 1e9
for epoch in range(self.num_epochs):
loss_value = self.train_epoch(data_loader)
if min_loss < 0 or loss_value < min_loss:
min_loss = loss_value
if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
if self.save_dir is None:
raise ValueError(
"save_dir should not be None if save_interval is not None"
"save_dir should not be None if save_interval is not None."
)
self.save(epoch + 1, self.save_dir)
if stop_loss is not None and loss_value < stop_loss:
self.save(epoch + 1)
if self.stop_loss is not None and loss_value < self.stop_loss:
break
return min_loss
return loss_value

def fit(
self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
@@ -160,11 +116,11 @@ class BasicNN:
Parameters
----------
data_loader : DataLoader, optional
The data loader used for training, by default None
The data loader used for training, by default None.
X : List[Any], optional
The input data, by default None
The input data, by default None.
y : List[int], optional
The target data, by default None
The target data, by default None.

Returns
-------
@@ -172,10 +128,13 @@ class BasicNN:
The loss value of the trained model.
"""
if data_loader is None:
data_loader = self._data_loader(X, y)
return self._fit(data_loader, self.num_epochs, self.stop_loss)
if X is None:
raise ValueError("data_loader and X can not be None simultaneously.")
else:
data_loader = self._data_loader(X, y)
return self._fit(data_loader)

def train_epoch(self, data_loader: DataLoader):
def train_epoch(self, data_loader: DataLoader) -> float:
"""
Train the model for one epoch.

@@ -211,7 +170,20 @@ class BasicNN:

return total_loss / total_num

def _predict(self, data_loader):
def _predict(self, data_loader) -> torch.Tensor:
"""
Internal method to predict the outputs given a DataLoader.

Parameters
----------
data_loader : DataLoader
The DataLoader providing input samples.

Returns
-------
torch.Tensor
Raw output from the model.
"""
model = self.model
device = self.device

@@ -227,10 +199,7 @@ class BasicNN:
return torch.cat(results, axis=0)

def predict(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
print_prefix: str = "",
self, data_loader: DataLoader = None, X: List[Any] = None
) -> numpy.ndarray:
"""
Predict the class of the input data.
@@ -238,11 +207,9 @@ class BasicNN:
Parameters
----------
data_loader : DataLoader, optional
The data loader used for prediction, by default None
The data loader used for prediction, by default None.
X : List[Any], optional
The input data, by default None
print_prefix : str, optional
The prefix used for printing, by default ""
The input data, by default None.

Returns
-------
@@ -255,10 +222,7 @@ class BasicNN:
return self._predict(data_loader).argmax(axis=1).cpu().numpy()

def predict_proba(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
print_prefix: str = "",
self, data_loader: DataLoader = None, X: List[Any] = None
) -> numpy.ndarray:
"""
Predict the probability of each class for the input data.
@@ -266,11 +230,9 @@ class BasicNN:
Parameters
----------
data_loader : DataLoader, optional
The data loader used for prediction, by default None
The data loader used for prediction, by default None.
X : List[Any], optional
The input data, by default None
print_prefix : str, optional
The prefix used for printing, by default ""
The input data, by default None.

Returns
-------
@@ -282,7 +244,21 @@ class BasicNN:
data_loader = self._data_loader(X)
return self._predict(data_loader).softmax(axis=1).cpu().numpy()

def _score(self, data_loader):
def _score(self, data_loader) -> Tuple[float, float]:
"""
Internal method to compute loss and accuracy for the data provided through a DataLoader.

Parameters
----------
data_loader : DataLoader
Data loader to use for evaluation.

Returns
-------
Tuple[float, float]
mean_loss: float, The mean loss of the model on the provided data.
accuracy: float, The accuracy of the model on the provided data.
"""
model = self.model
criterion = self.criterion
device = self.device
@@ -298,9 +274,9 @@ class BasicNN:
out = model(data)

if len(out.shape) > 1:
correct_num = sum(target == out.argmax(axis=1)).item()
correct_num = (target == out.argmax(axis=1)).sum().item()
else:
correct_num = sum(target == (out > 0.5)).item()
correct_num = (target == (out > 0.5)).sum().item()
loss = criterion(out, target)
total_loss += loss.item() * data.size(0)

@@ -313,11 +289,7 @@ class BasicNN:
return mean_loss, accuracy

def score(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
y: List[int] = None,
print_prefix: str = "",
self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
) -> float:
"""
Validate the model.
@@ -325,25 +297,25 @@ class BasicNN:
Parameters
----------
data_loader : DataLoader, optional
The data loader used for scoring, by default None
The data loader used for scoring, by default None.
X : List[Any], optional
The input data, by default None
The input data, by default None.
y : List[int], optional
The target data, by default None
print_prefix : str, optional
The prefix used for printing, by default ""
The target data, by default None.

Returns
-------
float
The accuracy of the model.
"""
print_log(f"Start machine learning model validation", logger="current")
print_log("Start machine learning model validation", logger="current")

if data_loader is None:
data_loader = self._data_loader(X, y)
mean_loss, accuracy = self._score(data_loader)
print_log(f"{print_prefix} mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current")
print_log(
f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current"
)
return accuracy

def _data_loader(
@@ -352,38 +324,39 @@ class BasicNN:
y: List[int] = None,
) -> DataLoader:
"""
Generate data_loader for user provided data.
Generate a DataLoader for user-provided input and target data.

Parameters
----------
X : List[Any]
The input data.
Input samples.
y : List[int], optional
The target data, by default None
Target labels. If None, dummy labels are created, by default None.

Returns
-------
DataLoader
The data loader.
A DataLoader providing batches of (X, y) pairs.
"""
collate_fn = self.collate_fn
transform = self.transform

if X is None:
raise ValueError("X should not be None.")
if y is None:
y = [0] * len(X)
dataset = ClassificationDataset(X, y, transform=transform)
sampler = None
if not (len(y) == len(X)):
raise ValueError("X and y should have equal length.")

dataset = ClassificationDataset(X, y, transform=self.transform)
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
sampler=sampler,
shuffle=True,
num_workers=int(self.num_workers),
collate_fn=collate_fn,
collate_fn=self.collate_fn,
)
return data_loader

def save(self, epoch_id: int = 0, save_dir: str = None, save_path: str = None):
def save(self, epoch_id: int = 0, save_path: str = None) -> None:
"""
Save the model and the optimizer.

@@ -391,15 +364,20 @@ class BasicNN:
----------
epoch_id : int
The epoch id.
save_dir : str, optional
The directory to save the model, by default ""
save_path : str, optional
The path to save the model, by default None.
"""
if save_dir and (not os.path.exists(save_dir)):
os.makedirs(save_dir)
print_log(f"Checkpoints will be saved to {save_dir}", logger="current")
if self.save_dir is None and save_path is None:
raise ValueError(
"'save_dir' and 'save_path' should not be None simultaneously."
)

if save_path is None:
save_path = os.path.join(save_dir, str(epoch_id) + ".pth")
save_path = os.path.join(
self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth"
)
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)

print_log(f"Checkpoints will be saved to {save_path}", logger="current")

@@ -410,29 +388,25 @@ class BasicNN:

torch.save(save_parma_dic, save_path)

def load(self, epoch_id: int = 0, load_dir: str = "", load_path: str = None):
def load(self, load_path: str = "") -> None:
"""
Load the model and the optimizer.

Parameters
----------
epoch_id : int
The epoch id.
load_dir : str, optional
The directory to load the model, by default ""
load_path : str
The directory to load the model, by default "".
"""

if load_path is not None:
print_log(f"Loads checkpoint by local backend from path: {load_path}", logger="current")
else:
print_log(f"Loads checkpoint by local backend from dir: {load_dir}", logger="current")
load_path = os.path.join(load_dir, str(epoch_id) + ".pth")
if load_path is None:
raise ValueError("Load path should not be None.")

print_log(
f"Loads checkpoint by local backend from path: {load_path}",
logger="current",
)

param_dic = torch.load(load_path)
self.model.load_state_dict(param_dic["model"])
if "optimizer" in param_dic.keys():
self.optimizer.load_state_dict(param_dic["optimizer"])


if __name__ == "__main__":
pass

Loading…
Cancel
Save