Browse Source

[MNT] resolve all comments in basic_nn.py

pull/3/head
Gao Enhao 2 years ago
parent
commit
bc6dbfeb99
2 changed files with 126 additions and 152 deletions
  1. +2
    -2
      abl/bridge/simple_bridge.py
  2. +124
    -150
      abl/learning/basic_nn.py

+ 2
- 2
abl/bridge/simple_bridge.py View File

@@ -74,10 +74,10 @@ class SimpleBridge(BaseBridge):
pred_prob, pred_pseudo_label, Y pred_prob, pred_pseudo_label, Y
) )
abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label) abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label)
min_loss = self.model.train(X, abduced_label)
loss = self.model.train(X, abduced_label)


print_log( print_log(
f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] minimal_loss is {min_loss:.5f}",
f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] model loss is {loss:.5f}",
logger="current", logger="current",
) )




+ 124
- 150
abl/learning/basic_nn.py View File

@@ -10,10 +10,6 @@
# #
# ================================================================# # ================================================================#


import sys

sys.path.append("..")

import torch import torch
import numpy import numpy
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -21,12 +17,12 @@ from ..utils.logger import print_log
from ..dataset import ClassificationDataset from ..dataset import ClassificationDataset


import os import os
from typing import List, Any, T, Optional, Callable
from typing import List, Any, T, Optional, Callable, Tuple




class BasicNN: class BasicNN:
""" """
Wrap NN models into the form of an sklearn estimator
Wrap NN models into the form of an sklearn estimator.


Parameters Parameters
---------- ----------
@@ -34,83 +30,35 @@ class BasicNN:
The PyTorch model to be trained or used for prediction. The PyTorch model to be trained or used for prediction.
criterion : torch.nn.Module criterion : torch.nn.Module
The loss function used for training. The loss function used for training.
optimizer : torch.nn.Module
optimizer : torch.optim.Optimizer
The optimizer used for training. The optimizer used for training.
device : torch.device, optional device : torch.device, optional
The device on which the model will be trained or used for prediction, by default torch.decive("cpu").
The device on which the model will be trained or used for prediction, by default torch.device("cpu").
batch_size : int, optional batch_size : int, optional
The batch size used for training, by default 1.
The batch size used for training, by default 32.
num_epochs : int, optional num_epochs : int, optional
The number of epochs used for training, by default 1. The number of epochs used for training, by default 1.
stop_loss : Optional[float], optional stop_loss : Optional[float], optional
The loss value at which to stop training, by default 0.01. The loss value at which to stop training, by default 0.01.
num_workers : int, optional
num_workers : int
The number of workers used for loading data, by default 0. The number of workers used for loading data, by default 0.
save_interval : Optional[int], optional save_interval : Optional[int], optional
The interval at which to save the model during training, by default None. The interval at which to save the model during training, by default None.
save_dir : Optional[str], optional save_dir : Optional[str], optional
The directory in which to save the model during training, by default None. The directory in which to save the model during training, by default None.
transform : Callable[..., Any], optional transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version. Defaults to None.
A function/transform that takes in an object and returns a transformed version, by default None.
collate_fn : Callable[[List[T]], Any], optional collate_fn : Callable[[List[T]], Any], optional
The function used to collate data, by default None. The function used to collate data, by default None.

Attributes
----------
model : torch.nn.Module
The PyTorch model to be trained or used for prediction.
batch_size : int
The batch size used for training.
num_epochs : int
The number of epochs used for training.
stop_loss : Optional[float]
The loss value at which to stop training.
num_workers : int
The number of workers used for loading data.
criterion : torch.nn.Module
The loss function used for training.
optimizer : torch.nn.Module
The optimizer used for training.
transform : Callable[..., Any]
The transformation function used for data augmentation.
device : torch.device
The device on which the model will be trained or used for prediction.
save_interval : Optional[int]
The interval at which to save the model during training.
save_dir : Optional[str]
The directory in which to save the model during training.
collate_fn : Callable[[List[T]], Any]
The function used to collate data.

Methods
-------
fit(data_loader=None, X=None, y=None)
Train the model.
train_epoch(data_loader)
Train the model for one epoch.
predict(data_loader=None, X=None, print_prefix="")
Predict the class of the input data.
predict_proba(data_loader=None, X=None, print_prefix="")
Predict the probability of each class for the input data.
val(data_loader=None, X=None, y=None, print_prefix="")
Validate the model.
score(data_loader=None, X=None, y=None, print_prefix="")
Score the model.
_data_loader(X, y=None)
Generate the data_loader.
save(epoch_id, save_dir="")
Save the model.
load(epoch_id, load_dir="")
Load the model.
""" """


def __init__( def __init__(
self, self,
model: torch.nn.Module, model: torch.nn.Module,
criterion: torch.nn.Module, criterion: torch.nn.Module,
optimizer: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device: torch.device = torch.device("cpu"), device: torch.device = torch.device("cpu"),
batch_size: int = 1,
batch_size: int = 32,
num_epochs: int = 1, num_epochs: int = 1,
stop_loss: Optional[float] = 0.01, stop_loss: Optional[float] = 0.01,
num_workers: int = 0, num_workers: int = 0,
@@ -118,38 +66,46 @@ class BasicNN:
save_dir: Optional[str] = None, save_dir: Optional[str] = None,
transform: Callable[..., Any] = None, transform: Callable[..., Any] = None,
collate_fn: Callable[[List[T]], Any] = None, collate_fn: Callable[[List[T]], Any] = None,
):
) -> None:
self.model = model.to(device) self.model = model.to(device)

self.criterion = criterion
self.optimizer = optimizer
self.device = device
self.batch_size = batch_size self.batch_size = batch_size
self.num_epochs = num_epochs self.num_epochs = num_epochs
self.stop_loss = stop_loss self.stop_loss = stop_loss
self.num_workers = num_workers self.num_workers = num_workers

self.criterion = criterion
self.optimizer = optimizer
self.transform = transform
self.device = device

self.save_interval = save_interval self.save_interval = save_interval
self.save_dir = save_dir self.save_dir = save_dir
self.transform = transform
self.collate_fn = collate_fn self.collate_fn = collate_fn


def _fit(self, data_loader, n_epoch, stop_loss):
min_loss = 1e10
for epoch in range(n_epoch):
def _fit(self, data_loader) -> float:
"""
Internal method to fit the model for up to `num_epochs` epochs, stopping early once the epoch loss drops below `stop_loss` (if it is not None).

Parameters
----------
data_loader : DataLoader
Data loader providing training samples.

Returns
-------
float
The loss value of the trained model.
"""
loss_value = 1e9
for epoch in range(self.num_epochs):
loss_value = self.train_epoch(data_loader) loss_value = self.train_epoch(data_loader)
if min_loss < 0 or loss_value < min_loss:
min_loss = loss_value
if self.save_interval is not None and (epoch + 1) % self.save_interval == 0: if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
if self.save_dir is None: if self.save_dir is None:
raise ValueError( raise ValueError(
"save_dir should not be None if save_interval is not None"
"save_dir should not be None if save_interval is not None."
) )
self.save(epoch + 1, self.save_dir)
if stop_loss is not None and loss_value < stop_loss:
self.save(epoch + 1)
if self.stop_loss is not None and loss_value < self.stop_loss:
break break
return min_loss
return loss_value


def fit( def fit(
self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
@@ -160,11 +116,11 @@ class BasicNN:
Parameters Parameters
---------- ----------
data_loader : DataLoader, optional data_loader : DataLoader, optional
The data loader used for training, by default None
The data loader used for training, by default None.
X : List[Any], optional X : List[Any], optional
The input data, by default None
The input data, by default None.
y : List[int], optional y : List[int], optional
The target data, by default None
The target data, by default None.


Returns Returns
------- -------
@@ -172,10 +128,13 @@ class BasicNN:
The loss value of the trained model. The loss value of the trained model.
""" """
if data_loader is None: if data_loader is None:
data_loader = self._data_loader(X, y)
return self._fit(data_loader, self.num_epochs, self.stop_loss)
if X is None:
raise ValueError("data_loader and X can not be None simultaneously.")
else:
data_loader = self._data_loader(X, y)
return self._fit(data_loader)


def train_epoch(self, data_loader: DataLoader):
def train_epoch(self, data_loader: DataLoader) -> float:
""" """
Train the model for one epoch. Train the model for one epoch.


@@ -211,7 +170,20 @@ class BasicNN:


return total_loss / total_num return total_loss / total_num


def _predict(self, data_loader):
def _predict(self, data_loader) -> torch.Tensor:
"""
Internal method to predict the outputs given a DataLoader.

Parameters
----------
data_loader : DataLoader
The DataLoader providing input samples.

Returns
-------
torch.Tensor
Raw output from the model.
"""
model = self.model model = self.model
device = self.device device = self.device


@@ -227,10 +199,7 @@ class BasicNN:
return torch.cat(results, axis=0) return torch.cat(results, axis=0)


def predict( def predict(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
print_prefix: str = "",
self, data_loader: DataLoader = None, X: List[Any] = None
) -> numpy.ndarray: ) -> numpy.ndarray:
""" """
Predict the class of the input data. Predict the class of the input data.
@@ -238,11 +207,9 @@ class BasicNN:
Parameters Parameters
---------- ----------
data_loader : DataLoader, optional data_loader : DataLoader, optional
The data loader used for prediction, by default None
The data loader used for prediction, by default None.
X : List[Any], optional X : List[Any], optional
The input data, by default None
print_prefix : str, optional
The prefix used for printing, by default ""
The input data, by default None.


Returns Returns
------- -------
@@ -255,10 +222,7 @@ class BasicNN:
return self._predict(data_loader).argmax(axis=1).cpu().numpy() return self._predict(data_loader).argmax(axis=1).cpu().numpy()


def predict_proba( def predict_proba(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
print_prefix: str = "",
self, data_loader: DataLoader = None, X: List[Any] = None
) -> numpy.ndarray: ) -> numpy.ndarray:
""" """
Predict the probability of each class for the input data. Predict the probability of each class for the input data.
@@ -266,11 +230,9 @@ class BasicNN:
Parameters Parameters
---------- ----------
data_loader : DataLoader, optional data_loader : DataLoader, optional
The data loader used for prediction, by default None
The data loader used for prediction, by default None.
X : List[Any], optional X : List[Any], optional
The input data, by default None
print_prefix : str, optional
The prefix used for printing, by default ""
The input data, by default None.


Returns Returns
------- -------
@@ -282,7 +244,21 @@ class BasicNN:
data_loader = self._data_loader(X) data_loader = self._data_loader(X)
return self._predict(data_loader).softmax(axis=1).cpu().numpy() return self._predict(data_loader).softmax(axis=1).cpu().numpy()


def _score(self, data_loader):
def _score(self, data_loader) -> Tuple[float, float]:
"""
Internal method to compute loss and accuracy for the data provided through a DataLoader.

Parameters
----------
data_loader : DataLoader
Data loader to use for evaluation.

Returns
-------
Tuple[float, float]
mean_loss: float, The mean loss of the model on the provided data.
accuracy: float, The accuracy of the model on the provided data.
"""
model = self.model model = self.model
criterion = self.criterion criterion = self.criterion
device = self.device device = self.device
@@ -298,9 +274,9 @@ class BasicNN:
out = model(data) out = model(data)


if len(out.shape) > 1: if len(out.shape) > 1:
correct_num = sum(target == out.argmax(axis=1)).item()
correct_num = (target == out.argmax(axis=1)).sum().item()
else: else:
correct_num = sum(target == (out > 0.5)).item()
correct_num = (target == (out > 0.5)).sum().item()
loss = criterion(out, target) loss = criterion(out, target)
total_loss += loss.item() * data.size(0) total_loss += loss.item() * data.size(0)


@@ -313,11 +289,7 @@ class BasicNN:
return mean_loss, accuracy return mean_loss, accuracy


def score( def score(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
y: List[int] = None,
print_prefix: str = "",
self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
) -> float: ) -> float:
""" """
Validate the model. Validate the model.
@@ -325,25 +297,25 @@ class BasicNN:
Parameters Parameters
---------- ----------
data_loader : DataLoader, optional data_loader : DataLoader, optional
The data loader used for scoring, by default None
The data loader used for scoring, by default None.
X : List[Any], optional X : List[Any], optional
The input data, by default None
The input data, by default None.
y : List[int], optional y : List[int], optional
The target data, by default None
print_prefix : str, optional
The prefix used for printing, by default ""
The target data, by default None.


Returns Returns
------- -------
float float
The accuracy of the model. The accuracy of the model.
""" """
print_log(f"Start machine learning model validation", logger="current")
print_log("Start machine learning model validation", logger="current")


if data_loader is None: if data_loader is None:
data_loader = self._data_loader(X, y) data_loader = self._data_loader(X, y)
mean_loss, accuracy = self._score(data_loader) mean_loss, accuracy = self._score(data_loader)
print_log(f"{print_prefix} mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current")
print_log(
f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current"
)
return accuracy return accuracy


def _data_loader( def _data_loader(
@@ -352,38 +324,39 @@ class BasicNN:
y: List[int] = None, y: List[int] = None,
) -> DataLoader: ) -> DataLoader:
""" """
Generate data_loader for user provided data.
Generate a DataLoader for user-provided input and target data.


Parameters Parameters
---------- ----------
X : List[Any] X : List[Any]
The input data.
Input samples.
y : List[int], optional y : List[int], optional
The target data, by default None
Target labels, by default None. If None, dummy all-zero labels are created.


Returns Returns
------- -------
DataLoader DataLoader
The data loader.
A DataLoader providing batches of (X, y) pairs.
""" """
collate_fn = self.collate_fn
transform = self.transform


if X is None:
raise ValueError("X should not be None.")
if y is None: if y is None:
y = [0] * len(X) y = [0] * len(X)
dataset = ClassificationDataset(X, y, transform=transform)
sampler = None
if not (len(y) == len(X)):
raise ValueError("X and y should have equal length.")

dataset = ClassificationDataset(X, y, transform=self.transform)
data_loader = DataLoader( data_loader = DataLoader(
dataset, dataset,
batch_size=self.batch_size, batch_size=self.batch_size,
shuffle=False,
sampler=sampler,
shuffle=True,
num_workers=int(self.num_workers), num_workers=int(self.num_workers),
collate_fn=collate_fn,
collate_fn=self.collate_fn,
) )
return data_loader return data_loader


def save(self, epoch_id: int = 0, save_dir: str = None, save_path: str = None):
def save(self, epoch_id: int = 0, save_path: str = None) -> None:
""" """
Save the model and the optimizer. Save the model and the optimizer.


@@ -391,15 +364,20 @@ class BasicNN:
---------- ----------
epoch_id : int epoch_id : int
The epoch id. The epoch id.
save_dir : str, optional
The directory to save the model, by default ""
save_path : str, optional
The path to save the model, by default None.
""" """
if save_dir and (not os.path.exists(save_dir)):
os.makedirs(save_dir)
print_log(f"Checkpoints will be saved to {save_dir}", logger="current")
if self.save_dir is None and save_path is None:
raise ValueError(
"'save_dir' and 'save_path' should not be None simultaneously."
)

if save_path is None: if save_path is None:
save_path = os.path.join(save_dir, str(epoch_id) + ".pth")
save_path = os.path.join(
self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth"
)
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)


print_log(f"Checkpoints will be saved to {save_path}", logger="current") print_log(f"Checkpoints will be saved to {save_path}", logger="current")


@@ -410,29 +388,25 @@ class BasicNN:


torch.save(save_parma_dic, save_path) torch.save(save_parma_dic, save_path)


def load(self, epoch_id: int = 0, load_dir: str = "", load_path: str = None):
def load(self, load_path: str = "") -> None:
""" """
Load the model and the optimizer. Load the model and the optimizer.


Parameters Parameters
---------- ----------
epoch_id : int
The epoch id.
load_dir : str, optional
The directory to load the model, by default ""
load_path : str
The path to load the model from, by default "".
""" """


if load_path is not None:
print_log(f"Loads checkpoint by local backend from path: {load_path}", logger="current")
else:
print_log(f"Loads checkpoint by local backend from dir: {load_dir}", logger="current")
load_path = os.path.join(load_dir, str(epoch_id) + ".pth")
if load_path is None:
raise ValueError("Load path should not be None.")

print_log(
f"Loads checkpoint by local backend from path: {load_path}",
logger="current",
)

param_dic = torch.load(load_path) param_dic = torch.load(load_path)
self.model.load_state_dict(param_dic["model"]) self.model.load_state_dict(param_dic["model"])
if "optimizer" in param_dic.keys(): if "optimizer" in param_dic.keys():
self.optimizer.load_state_dict(param_dic["optimizer"]) self.optimizer.load_state_dict(param_dic["optimizer"])


if __name__ == "__main__":
pass

Loading…
Cancel
Save